Skip to content

Commit 1660cc7

Browse files
committed
Do not lose provisioning gateways on restart
- Split processing `submitted` and `provisioning` gateways into two processing iterations, so that gateways in the `provisioning` status are not lost on server restarts. - Decrease the number of gateway configuration attempts during server startup. 50 attempts are only necessary when processing a provisioning gateway. - Make gateway processing logging more comprehensive and consistent.
1 parent db55a99 commit 1660cc7

File tree

5 files changed

+111
-81
lines changed

5 files changed

+111
-81
lines changed

src/dstack/_internal/server/background/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from dstack._internal.server import settings
55
from dstack._internal.server.background.tasks.process_fleets import process_fleets
66
from dstack._internal.server.background.tasks.process_gateways import (
7+
process_gateways,
78
process_gateways_connections,
8-
process_submitted_gateways,
99
)
1010
from dstack._internal.server.background.tasks.process_instances import (
1111
process_instances,
@@ -70,9 +70,7 @@ def start_background_tasks() -> AsyncIOScheduler:
7070
)
7171
_scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1)
7272
_scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15))
73-
_scheduler.add_job(
74-
process_submitted_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5
75-
)
73+
_scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5)
7674
_scheduler.add_job(
7775
process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5
7876
)

src/dstack/_internal/server/background/tasks/process_gateways.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
gateway_connections_pool,
1717
)
1818
from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker
19+
from dstack._internal.server.services.logging import fmt
1920
from dstack._internal.utils.common import get_current_datetime
2021
from dstack._internal.utils.logging import get_logger
2122

@@ -27,14 +28,14 @@ async def process_gateways_connections():
2728
await _process_active_connections()
2829

2930

30-
async def process_submitted_gateways():
31+
async def process_gateways():
3132
lock, lockset = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__)
3233
async with get_session_ctx() as session:
3334
async with lock:
3435
res = await session.execute(
3536
select(GatewayModel)
3637
.where(
37-
GatewayModel.status == GatewayStatus.SUBMITTED,
38+
GatewayModel.status.in_([GatewayStatus.SUBMITTED, GatewayStatus.PROVISIONING]),
3839
GatewayModel.id.not_in(lockset),
3940
)
4041
.options(lazyload(GatewayModel.gateway_compute))
@@ -48,7 +49,25 @@ async def process_submitted_gateways():
4849
lockset.add(gateway_model.id)
4950
try:
5051
gateway_model_id = gateway_model.id
51-
await _process_submitted_gateway(session=session, gateway_model=gateway_model)
52+
initial_status = gateway_model.status
53+
if initial_status == GatewayStatus.SUBMITTED:
54+
await _process_submitted_gateway(session=session, gateway_model=gateway_model)
55+
elif initial_status == GatewayStatus.PROVISIONING:
56+
await _process_provisioning_gateway(session=session, gateway_model=gateway_model)
57+
else:
58+
logger.error(
59+
"%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper()
60+
)
61+
if gateway_model.status != initial_status:
62+
logger.info(
63+
"%s: gateway status has changed %s -> %s%s",
64+
fmt(gateway_model),
65+
initial_status.upper(),
66+
gateway_model.status.upper(),
67+
f": {gateway_model.status_message}" if gateway_model.status_message else "",
68+
)
69+
gateway_model.last_processed_at = get_current_datetime()
70+
await session.commit()
5271
finally:
5372
lockset.difference_update([gateway_model_id])
5473

@@ -89,7 +108,7 @@ async def _process_connection(conn: GatewayConnection):
89108

90109

91110
async def _process_submitted_gateway(session: AsyncSession, gateway_model: GatewayModel):
92-
logger.info("Started gateway %s provisioning", gateway_model.name)
111+
logger.info("%s: started gateway provisioning", fmt(gateway_model))
93112
# Refetch to load related attributes.
94113
# joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
95114
res = await session.execute(
@@ -110,8 +129,6 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
110129
except BackendNotAvailable:
111130
gateway_model.status = GatewayStatus.FAILED
112131
gateway_model.status_message = "Backend not available"
113-
gateway_model.last_processed_at = get_current_datetime()
114-
await session.commit()
115132
return
116133

117134
try:
@@ -123,53 +140,54 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
123140
)
124141
session.add(gateway_model)
125142
gateway_model.status = GatewayStatus.PROVISIONING
126-
await session.commit()
127-
await session.refresh(gateway_model)
128143
except BackendError as e:
129-
logger.info(
130-
"Failed to create gateway compute for gateway %s: %s", gateway_model.name, repr(e)
131-
)
144+
logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e)
132145
gateway_model.status = GatewayStatus.FAILED
133146
status_message = f"Backend error: {repr(e)}"
134147
if len(e.args) > 0:
135148
status_message = str(e.args[0])
136149
gateway_model.status_message = status_message
137-
gateway_model.last_processed_at = get_current_datetime()
138-
await session.commit()
139-
return
140150
except Exception as e:
141-
logger.exception(
142-
"Got exception when creating gateway compute for gateway %s", gateway_model.name
143-
)
151+
logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model))
144152
gateway_model.status = GatewayStatus.FAILED
145153
gateway_model.status_message = f"Unexpected error: {repr(e)}"
146-
gateway_model.last_processed_at = get_current_datetime()
147-
await session.commit()
148-
return
149154

155+
156+
async def _process_provisioning_gateway(
157+
session: AsyncSession, gateway_model: GatewayModel
158+
) -> None:
159+
# Refetch to load related attributes.
160+
# joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
161+
res = await session.execute(
162+
select(GatewayModel)
163+
.where(GatewayModel.id == gateway_model.id)
164+
.execution_options(populate_existing=True)
165+
)
166+
gateway_model = res.unique().scalar_one()
167+
168+
# FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway:
169+
# - cannot delete the gateway before it is provisioned because the DB model is locked
170+
# - connection retry counter is reset on server restart
171+
# - only one server replica is processing the gateway
172+
# Easy to fix by doing only one connection/configuration attempt per processing iteration. The
173+
# main challenge is applying the same provisioning model to the dstack Sky gateway to avoid
174+
# maintaining a different model for Sky.
150175
connection = await gateways_services.connect_to_gateway_with_retry(
151176
gateway_model.gateway_compute
152177
)
153178
if connection is None:
154179
gateway_model.status = GatewayStatus.FAILED
155180
gateway_model.status_message = "Failed to connect to gateway"
156-
gateway_model.last_processed_at = get_current_datetime()
157181
gateway_model.gateway_compute.deleted = True
158-
await session.commit()
159182
return
160-
161183
try:
162184
await gateways_services.configure_gateway(connection)
163185
except Exception:
164-
logger.exception("Failed to configure gateway %s", gateway_model.name)
186+
logger.exception("%s: failed to configure gateway", fmt(gateway_model))
165187
gateway_model.status = GatewayStatus.FAILED
166188
gateway_model.status_message = "Failed to configure gateway"
167-
gateway_model.last_processed_at = get_current_datetime()
168189
await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
169190
gateway_model.gateway_compute.active = False
170-
await session.commit()
171191
return
172192

173193
gateway_model.status = GatewayStatus.RUNNING
174-
gateway_model.last_processed_at = get_current_datetime()
175-
await session.commit()

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import datetime
33
import uuid
44
from datetime import timedelta, timezone
5+
from functools import partial
56
from typing import List, Optional, Sequence
67

78
import httpx
@@ -186,6 +187,7 @@ async def create_gateway(
186187
return gateway_model_to_gateway(gateway)
187188

188189

190+
# NOTE: dstack Sky imports and uses this function
189191
async def connect_to_gateway_with_retry(
190192
gateway_compute: GatewayComputeModel,
191193
) -> Optional[GatewayConnection]:
@@ -380,6 +382,8 @@ async def get_or_add_gateway_connection(
380382
async def init_gateways(session: AsyncSession):
381383
res = await session.execute(
382384
select(GatewayComputeModel).where(
385+
# FIXME: should not include computes related to gateways in the `provisioning` status.
386+
# Causes warnings and delays when restarting the server during gateway provisioning.
383387
GatewayComputeModel.active == True,
384388
GatewayComputeModel.deleted == False,
385389
)
@@ -421,7 +425,8 @@ async def init_gateways(session: AsyncSession):
421425

422426
for gateway_compute, error in await gather_map_async(
423427
await gateway_connections_pool.all(),
424-
configure_gateway,
428+
# Need several attempts to handle short gateway downtime after update
429+
partial(configure_gateway, attempts=7),
425430
return_exceptions=True,
426431
):
427432
if isinstance(error, Exception):
@@ -461,15 +466,19 @@ def _recently_updated(gateway_compute_model: GatewayComputeModel) -> bool:
461466
) > get_current_datetime() - timedelta(seconds=60)
462467

463468

464-
async def configure_gateway(connection: GatewayConnection) -> None:
469+
# NOTE: dstack Sky imports and uses this function
470+
async def configure_gateway(
471+
connection: GatewayConnection,
472+
attempts: int = GATEWAY_CONFIGURE_ATTEMPTS,
473+
) -> None:
465474
"""
466475
Try submitting gateway config several times in case gateway's HTTP server is not
467476
running yet
468477
"""
469478

470479
logger.debug("Configuring gateway %s", connection.ip_address)
471480

472-
for attempt in range(GATEWAY_CONFIGURE_ATTEMPTS - 1):
481+
for attempt in range(attempts - 1):
473482
try:
474483
async with connection.client() as client:
475484
await client.submit_gateway_config()
@@ -478,7 +487,7 @@ async def configure_gateway(connection: GatewayConnection) -> None:
478487
logger.debug(
479488
"Failed attempt %s/%s at configuring gateway %s: %r",
480489
attempt + 1,
481-
GATEWAY_CONFIGURE_ATTEMPTS,
490+
attempts,
482491
connection.ip_address,
483492
e,
484493
)
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from typing import Union
22

3-
from dstack._internal.server.models import JobModel, RunModel
3+
from dstack._internal.server.models import GatewayModel, JobModel, RunModel
44

55

6-
def fmt(model: Union[RunModel, JobModel]) -> str:
6+
def fmt(model: Union[RunModel, JobModel, GatewayModel]) -> str:
77
"""Consistent string representation of a model for logging."""
88
if isinstance(model, RunModel):
99
return f"run({model.id.hex[:6]}){model.run_name}"
1010
if isinstance(model, JobModel):
1111
return f"job({model.id.hex[:6]}){model.job_name}"
12+
if isinstance(model, GatewayModel):
13+
return f"gateway({model.id.hex[:6]}){model.name}"
1214
return str(model)

0 commit comments

Comments
 (0)