Skip to content

Commit b4c6f17

Browse files
authored
Add gateway lifecycle events (#3500)
- Gateway created
- Gateway status changed
- Gateway deleted
- Gateway set as default
- Gateway unset as default
- Gateway wildcard domain changed
1 parent 5efca70 commit b4c6f17

6 files changed

Lines changed: 211 additions & 57 deletions

File tree

src/dstack/_internal/server/background/tasks/process_gateways.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
GatewayConnection,
1515
create_gateway_compute,
1616
gateway_connections_pool,
17+
switch_gateway_status,
1718
)
1819
from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker
1920
from dstack._internal.server.services.logging import fmt
@@ -60,14 +61,6 @@ async def process_gateways():
6061
logger.error(
6162
"%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper()
6263
)
63-
if gateway_model.status != initial_status:
64-
logger.info(
65-
"%s: gateway status has changed %s -> %s%s",
66-
fmt(gateway_model),
67-
initial_status.upper(),
68-
gateway_model.status.upper(),
69-
f": {gateway_model.status_message}" if gateway_model.status_message else "",
70-
)
7164
gateway_model.last_processed_at = get_current_datetime()
7265
await session.commit()
7366
finally:
@@ -128,8 +121,8 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
128121
project=gateway_model.project, backend_type=configuration.backend
129122
)
130123
except BackendNotAvailable:
131-
gateway_model.status = GatewayStatus.FAILED
132124
gateway_model.status_message = "Backend not available"
125+
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
133126
return
134127

135128
try:
@@ -140,18 +133,17 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
140133
backend_id=backend_model.id,
141134
)
142135
session.add(gateway_model)
143-
gateway_model.status = GatewayStatus.PROVISIONING
136+
switch_gateway_status(session, gateway_model, GatewayStatus.PROVISIONING)
144137
except BackendError as e:
145-
logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e)
146-
gateway_model.status = GatewayStatus.FAILED
147138
status_message = f"Backend error: {repr(e)}"
148139
if len(e.args) > 0:
149140
status_message = str(e.args[0])
150141
gateway_model.status_message = status_message
142+
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
151143
except Exception as e:
152144
logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model))
153-
gateway_model.status = GatewayStatus.FAILED
154145
gateway_model.status_message = f"Unexpected error: {repr(e)}"
146+
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
155147

156148

157149
async def _process_provisioning_gateway(
@@ -179,18 +171,18 @@ async def _process_provisioning_gateway(
179171
gateway_model.gateway_compute
180172
)
181173
if connection is None:
182-
gateway_model.status = GatewayStatus.FAILED
183174
gateway_model.status_message = "Failed to connect to gateway"
175+
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
184176
gateway_model.gateway_compute.deleted = True
185177
return
186178
try:
187179
await gateways_services.configure_gateway(connection)
188180
except Exception:
189181
logger.exception("%s: failed to configure gateway", fmt(gateway_model))
190-
gateway_model.status = GatewayStatus.FAILED
191182
gateway_model.status_message = "Failed to configure gateway"
183+
switch_gateway_status(session, gateway_model, GatewayStatus.FAILED)
192184
await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
193185
gateway_model.gateway_compute.active = False
194186
return
195187

196-
gateway_model.status = GatewayStatus.RUNNING
188+
switch_gateway_status(session, gateway_model, GatewayStatus.RUNNING)

src/dstack/_internal/server/routers/gateways.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ async def delete_gateways(
7272
session: AsyncSession = Depends(get_session),
7373
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
7474
):
75-
_, project = user_project
75+
user, project = user_project
7676
await gateways.delete_gateways(
7777
session=session,
7878
project=project,
7979
gateways_names=body.names,
80+
user=user,
8081
)
8182

8283

@@ -86,8 +87,8 @@ async def set_default_gateway(
8687
session: AsyncSession = Depends(get_session),
8788
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
8889
):
89-
_, project = user_project
90-
await gateways.set_default_gateway(session=session, project=project, name=body.name)
90+
user, project = user_project
91+
await gateways.set_default_gateway(session=session, project=project, name=body.name, user=user)
9192

9293

9394
@router.post("/set_wildcard_domain", response_model=models.Gateway)
@@ -96,9 +97,13 @@ async def set_gateway_wildcard_domain(
9697
session: AsyncSession = Depends(get_session),
9798
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
9899
):
99-
_, project = user_project
100+
user, project = user_project
100101
return CustomORJSONResponse(
101102
await gateways.set_gateway_wildcard_domain(
102-
session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain
103+
session=session,
104+
project=project,
105+
name=body.name,
106+
wildcard_domain=body.wildcard_domain,
107+
user=user,
103108
)
104109
)

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 111 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
import datetime
33
import uuid
4+
from collections.abc import AsyncGenerator
5+
from contextlib import asynccontextmanager
46
from datetime import timedelta
57
from functools import partial
68
from typing import List, Optional, Sequence
@@ -45,6 +47,7 @@
4547
ProjectModel,
4648
UserModel,
4749
)
50+
from dstack._internal.server.services import events
4851
from dstack._internal.server.services.backends import (
4952
check_backend_type_available,
5053
get_project_backend_by_type_or_error,
@@ -66,6 +69,24 @@
6669
logger = get_logger(__name__)
6770

6871

72+
def switch_gateway_status(
73+
session: AsyncSession,
74+
gateway_model: GatewayModel,
75+
new_status: GatewayStatus,
76+
actor: events.AnyActor = events.SystemActor(),
77+
):
78+
old_status = gateway_model.status
79+
if old_status == new_status:
80+
return
81+
82+
gateway_model.status = new_status
83+
84+
msg = f"Gateway status changed {old_status.upper()} -> {new_status.upper()}"
85+
if gateway_model.status_message is not None:
86+
msg += f" ({gateway_model.status_message})"
87+
events.emit(session, msg, actor=actor, targets=[events.Target.from_model(gateway_model)])
88+
89+
6990
GATEWAY_CONNECT_ATTEMPTS = 30
7091
GATEWAY_CONNECT_DELAY = 10
7192
GATEWAY_CONFIGURE_ATTEMPTS = 50
@@ -163,6 +184,7 @@ async def create_gateway(
163184
configuration.name = await generate_gateway_name(session=session, project=project)
164185

165186
gateway = GatewayModel(
187+
id=uuid.uuid4(),
166188
name=configuration.name,
167189
region=configuration.region,
168190
project_id=project.id,
@@ -173,11 +195,19 @@ async def create_gateway(
173195
last_processed_at=get_current_datetime(),
174196
)
175197
session.add(gateway)
198+
events.emit(
199+
session,
200+
f"Gateway created. Status: {gateway.status.upper()}",
201+
actor=events.UserActor.from_user(user),
202+
targets=[events.Target.from_model(gateway)],
203+
)
176204
await session.commit()
177205

178206
default_gateway = await get_project_default_gateway_model(session=session, project=project)
179207
if default_gateway is None or configuration.default:
180-
await set_default_gateway(session=session, project=project, name=configuration.name)
208+
await set_default_gateway(
209+
session=session, project=project, name=configuration.name, user=user
210+
)
181211
return gateway_model_to_gateway(gateway)
182212

183213

@@ -214,6 +244,7 @@ async def delete_gateways(
214244
session: AsyncSession,
215245
project: ProjectModel,
216246
gateways_names: List[str],
247+
user: UserModel,
217248
):
218249
res = await session.execute(
219250
select(GatewayModel).where(
@@ -273,46 +304,51 @@ async def delete_gateways(
273304
gateway_model.gateway_compute.deleted = True
274305
session.add(gateway_model.gateway_compute)
275306
await session.delete(gateway_model)
307+
events.emit(
308+
session,
309+
"Gateway deleted",
310+
actor=events.UserActor.from_user(user),
311+
targets=[events.Target.from_model(gateway_model)],
312+
)
276313
await session.commit()
277314

278315

279316
async def set_gateway_wildcard_domain(
280-
session: AsyncSession, project: ProjectModel, name: str, wildcard_domain: Optional[str]
317+
session: AsyncSession,
318+
project: ProjectModel,
319+
name: str,
320+
wildcard_domain: Optional[str],
321+
user: UserModel,
281322
) -> Gateway:
282-
gateway = await get_project_gateway_model_by_name(
283-
session=session,
284-
project=project,
285-
name=name,
286-
)
287-
if gateway is None:
288-
raise ResourceNotExistsError()
289-
if gateway.backend.type == BackendType.DSTACK:
290-
raise ServerClientError("Custom domains for dstack Sky gateway are not supported")
291-
await session.execute(
292-
update(GatewayModel)
293-
.where(
294-
GatewayModel.project_id == project.id,
295-
GatewayModel.name == name,
296-
)
297-
.values(
298-
wildcard_domain=wildcard_domain,
299-
)
300-
)
301-
await session.commit()
302-
gateway = await get_project_gateway_model_by_name(
303-
session=session,
304-
project=project,
305-
name=name,
306-
)
307-
if gateway is None:
308-
raise ResourceNotExistsError()
323+
async with get_project_gateway_model_by_name_for_update(
324+
session=session, project=project, name=name
325+
) as gateway:
326+
if gateway is None:
327+
raise ResourceNotExistsError()
328+
if gateway.backend.type == BackendType.DSTACK:
329+
raise ServerClientError("Custom domains for dstack Sky gateway are not supported")
330+
old_domain = gateway.wildcard_domain
331+
if old_domain != wildcard_domain:
332+
gateway.wildcard_domain = wildcard_domain
333+
events.emit(
334+
session,
335+
f"Gateway wildcard domain changed {old_domain!r} -> {gateway.wildcard_domain!r}",
336+
actor=events.UserActor.from_user(user),
337+
targets=[events.Target.from_model(gateway)],
338+
)
339+
await session.commit()
309340
return gateway_model_to_gateway(gateway)
310341

311342

312-
async def set_default_gateway(session: AsyncSession, project: ProjectModel, name: str):
343+
async def set_default_gateway(
344+
session: AsyncSession, project: ProjectModel, name: str, user: Optional[UserModel]
345+
):
313346
gateway = await get_project_gateway_model_by_name(session=session, project=project, name=name)
314347
if gateway is None:
315348
raise ResourceNotExistsError()
349+
if project.default_gateway_id == gateway.id:
350+
return
351+
previous_gateway = await get_project_default_gateway_model(session, project)
316352
await session.execute(
317353
update(ProjectModel)
318354
.where(
@@ -322,6 +358,19 @@ async def set_default_gateway(session: AsyncSession, project: ProjectModel, name
322358
default_gateway_id=gateway.id,
323359
)
324360
)
361+
if previous_gateway is not None:
362+
events.emit(
363+
session,
364+
"Gateway unset as default",
365+
actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(),
366+
targets=[events.Target.from_model(previous_gateway)],
367+
)
368+
events.emit(
369+
session,
370+
"Gateway set as default",
371+
actor=events.UserActor.from_user(user) if user is not None else events.SystemActor(),
372+
targets=[events.Target.from_model(gateway)],
373+
)
325374
await session.commit()
326375

327376

@@ -343,6 +392,38 @@ async def get_project_gateway_model_by_name(
343392
return res.scalar()
344393

345394

395+
@asynccontextmanager
396+
async def get_project_gateway_model_by_name_for_update(
397+
session: AsyncSession, project: ProjectModel, name: str
398+
) -> AsyncGenerator[Optional[GatewayModel], None]:
399+
"""
400+
Fetch the gateway from the database and lock it for update.
401+
402+
**NOTE**: commit changes to the database before exiting from this context manager,
403+
so that in-memory locks are only released after commit.
404+
"""
405+
406+
filters = [
407+
GatewayModel.project_id == project.id,
408+
GatewayModel.name == name,
409+
]
410+
res = await session.execute(select(GatewayModel.id).where(*filters))
411+
gateway_id = res.scalar_one_or_none()
412+
if gateway_id is None:
413+
yield None
414+
else:
415+
async with get_locker(get_db().dialect_name).lock_ctx(
416+
GatewayModel.__tablename__, [gateway_id]
417+
):
418+
# Refetch after lock
419+
res = await session.execute(
420+
select(GatewayModel)
421+
.where(GatewayModel.id.in_([gateway_id]), *filters)
422+
.with_for_update(key_share=True, of=GatewayModel)
423+
)
424+
yield res.scalar_one_or_none()
425+
426+
346427
async def get_project_default_gateway_model(
347428
session: AsyncSession, project: ProjectModel
348429
) -> Optional[GatewayModel]:

src/dstack/_internal/server/testing/common.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from uuid import UUID
88

99
import gpuhunt
10-
from sqlalchemy import select
10+
from sqlalchemy import delete, select
1111
from sqlalchemy.ext.asyncio import AsyncSession
12+
from sqlalchemy.orm import joinedload
1213

1314
from dstack._internal.core.backends.base.compute import (
1415
Compute,
@@ -1114,8 +1115,16 @@ async def create_secret(
11141115

11151116

11161117
async def list_events(session: AsyncSession) -> list[EventModel]:
1117-
res = await session.execute(select(EventModel).order_by(EventModel.recorded_at, EventModel.id))
1118-
return list(res.scalars().all())
1118+
res = await session.execute(
1119+
select(EventModel)
1120+
.order_by(EventModel.recorded_at, EventModel.id)
1121+
.options(joinedload(EventModel.targets))
1122+
)
1123+
return list(res.scalars().unique().all())
1124+
1125+
1126+
async def clear_events(session: AsyncSession) -> None:
    """Execute a DELETE over all `EventModel` rows on `session`; does not commit."""
    delete_all_events = delete(EventModel)
    await session.execute(delete_all_events)
11191128

11201129

11211130
def get_private_key_string() -> str:

0 commit comments

Comments
 (0)