11import asyncio
22import datetime
33import uuid
4+ from collections .abc import AsyncGenerator
5+ from contextlib import asynccontextmanager
46from datetime import timedelta
57from functools import partial
68from typing import List , Optional , Sequence
4547 ProjectModel ,
4648 UserModel ,
4749)
50+ from dstack ._internal .server .services import events
4851from dstack ._internal .server .services .backends import (
4952 check_backend_type_available ,
5053 get_project_backend_by_type_or_error ,
6669logger = get_logger (__name__ )
6770
6871
72+ def switch_gateway_status (
73+ session : AsyncSession ,
74+ gateway_model : GatewayModel ,
75+ new_status : GatewayStatus ,
76+ actor : events .AnyActor = events .SystemActor (),
77+ ):
78+ old_status = gateway_model .status
79+ if old_status == new_status :
80+ return
81+
82+ gateway_model .status = new_status
83+
84+ msg = f"Gateway status changed { old_status .upper ()} -> { new_status .upper ()} "
85+ if gateway_model .status_message is not None :
86+ msg += f" ({ gateway_model .status_message } )"
87+ events .emit (session , msg , actor = actor , targets = [events .Target .from_model (gateway_model )])
88+
89+
6990GATEWAY_CONNECT_ATTEMPTS = 30
7091GATEWAY_CONNECT_DELAY = 10
7192GATEWAY_CONFIGURE_ATTEMPTS = 50
@@ -163,6 +184,7 @@ async def create_gateway(
163184 configuration .name = await generate_gateway_name (session = session , project = project )
164185
165186 gateway = GatewayModel (
187+ id = uuid .uuid4 (),
166188 name = configuration .name ,
167189 region = configuration .region ,
168190 project_id = project .id ,
@@ -173,11 +195,19 @@ async def create_gateway(
173195 last_processed_at = get_current_datetime (),
174196 )
175197 session .add (gateway )
198+ events .emit (
199+ session ,
200+ f"Gateway created. Status: { gateway .status .upper ()} " ,
201+ actor = events .UserActor .from_user (user ),
202+ targets = [events .Target .from_model (gateway )],
203+ )
176204 await session .commit ()
177205
178206 default_gateway = await get_project_default_gateway_model (session = session , project = project )
179207 if default_gateway is None or configuration .default :
180- await set_default_gateway (session = session , project = project , name = configuration .name )
208+ await set_default_gateway (
209+ session = session , project = project , name = configuration .name , user = user
210+ )
181211 return gateway_model_to_gateway (gateway )
182212
183213
@@ -214,6 +244,7 @@ async def delete_gateways(
214244 session : AsyncSession ,
215245 project : ProjectModel ,
216246 gateways_names : List [str ],
247+ user : UserModel ,
217248):
218249 res = await session .execute (
219250 select (GatewayModel ).where (
@@ -273,46 +304,51 @@ async def delete_gateways(
273304 gateway_model .gateway_compute .deleted = True
274305 session .add (gateway_model .gateway_compute )
275306 await session .delete (gateway_model )
307+ events .emit (
308+ session ,
309+ "Gateway deleted" ,
310+ actor = events .UserActor .from_user (user ),
311+ targets = [events .Target .from_model (gateway_model )],
312+ )
276313 await session .commit ()
277314
278315
279316async def set_gateway_wildcard_domain (
280- session : AsyncSession , project : ProjectModel , name : str , wildcard_domain : Optional [str ]
317+ session : AsyncSession ,
318+ project : ProjectModel ,
319+ name : str ,
320+ wildcard_domain : Optional [str ],
321+ user : UserModel ,
281322) -> Gateway :
282- gateway = await get_project_gateway_model_by_name (
283- session = session ,
284- project = project ,
285- name = name ,
286- )
287- if gateway is None :
288- raise ResourceNotExistsError ()
289- if gateway .backend .type == BackendType .DSTACK :
290- raise ServerClientError ("Custom domains for dstack Sky gateway are not supported" )
291- await session .execute (
292- update (GatewayModel )
293- .where (
294- GatewayModel .project_id == project .id ,
295- GatewayModel .name == name ,
296- )
297- .values (
298- wildcard_domain = wildcard_domain ,
299- )
300- )
301- await session .commit ()
302- gateway = await get_project_gateway_model_by_name (
303- session = session ,
304- project = project ,
305- name = name ,
306- )
307- if gateway is None :
308- raise ResourceNotExistsError ()
323+ async with get_project_gateway_model_by_name_for_update (
324+ session = session , project = project , name = name
325+ ) as gateway :
326+ if gateway is None :
327+ raise ResourceNotExistsError ()
328+ if gateway .backend .type == BackendType .DSTACK :
329+ raise ServerClientError ("Custom domains for dstack Sky gateway are not supported" )
330+ old_domain = gateway .wildcard_domain
331+ if old_domain != wildcard_domain :
332+ gateway .wildcard_domain = wildcard_domain
333+ events .emit (
334+ session ,
335+ f"Gateway wildcard domain changed { old_domain !r} -> { gateway .wildcard_domain !r} " ,
336+ actor = events .UserActor .from_user (user ),
337+ targets = [events .Target .from_model (gateway )],
338+ )
339+ await session .commit ()
309340 return gateway_model_to_gateway (gateway )
310341
311342
312- async def set_default_gateway (session : AsyncSession , project : ProjectModel , name : str ):
343+ async def set_default_gateway (
344+ session : AsyncSession , project : ProjectModel , name : str , user : Optional [UserModel ]
345+ ):
313346 gateway = await get_project_gateway_model_by_name (session = session , project = project , name = name )
314347 if gateway is None :
315348 raise ResourceNotExistsError ()
349+ if project .default_gateway_id == gateway .id :
350+ return
351+ previous_gateway = await get_project_default_gateway_model (session , project )
316352 await session .execute (
317353 update (ProjectModel )
318354 .where (
@@ -322,6 +358,19 @@ async def set_default_gateway(session: AsyncSession, project: ProjectModel, name
322358 default_gateway_id = gateway .id ,
323359 )
324360 )
361+ if previous_gateway is not None :
362+ events .emit (
363+ session ,
364+ "Gateway unset as default" ,
365+ actor = events .UserActor .from_user (user ) if user is not None else events .SystemActor (),
366+ targets = [events .Target .from_model (previous_gateway )],
367+ )
368+ events .emit (
369+ session ,
370+ "Gateway set as default" ,
371+ actor = events .UserActor .from_user (user ) if user is not None else events .SystemActor (),
372+ targets = [events .Target .from_model (gateway )],
373+ )
325374 await session .commit ()
326375
327376
@@ -343,6 +392,38 @@ async def get_project_gateway_model_by_name(
343392 return res .scalar ()
344393
345394
395+ @asynccontextmanager
396+ async def get_project_gateway_model_by_name_for_update (
397+ session : AsyncSession , project : ProjectModel , name : str
398+ ) -> AsyncGenerator [Optional [GatewayModel ], None ]:
399+ """
400+ Fetch the gateway from the database and lock it for update.
401+
402+ **NOTE**: commit changes to the database before exiting from this context manager,
403+ so that in-memory locks are only released after commit.
404+ """
405+
406+ filters = [
407+ GatewayModel .project_id == project .id ,
408+ GatewayModel .name == name ,
409+ ]
410+ res = await session .execute (select (GatewayModel .id ).where (* filters ))
411+ gateway_id = res .scalar_one_or_none ()
412+ if gateway_id is None :
413+ yield None
414+ else :
415+ async with get_locker (get_db ().dialect_name ).lock_ctx (
416+ GatewayModel .__tablename__ , [gateway_id ]
417+ ):
418+ # Refetch after lock
419+ res = await session .execute (
420+ select (GatewayModel )
421+ .where (GatewayModel .id .in_ ([gateway_id ]), * filters )
422+ .with_for_update (key_share = True , of = GatewayModel )
423+ )
424+ yield res .scalar_one_or_none ()
425+
426+
346427async def get_project_default_gateway_model (
347428 session : AsyncSession , project : ProjectModel
348429) -> Optional [GatewayModel ]:
0 commit comments