Skip to content

Commit 06fdac8

Browse files
committed
Only register service replicas after probes pass
To avoid service disruptions during rolling deployments, only register a newly started replica to receive requests after all its probes pass, not immediately after it becomes `running`.
1 parent 5fb7af3 commit 06fdac8

File tree

12 files changed

+361
-300
lines changed

12 files changed

+361
-300
lines changed

src/dstack/_internal/proxy/gateway/services/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ async def register_replica(
152152
)
153153

154154
if old_service.find_replica(replica_id) is not None:
155+
# NOTE: as of 0.19.25, the dstack server relies on the exact text of this error.
156+
# See dstack._internal.server.services.services.register_replica
155157
raise ProxyError(f"Replica {replica_id} already exists in service {old_service.fmt()}")
156158

157159
service = old_service.with_replicas(old_service.replicas + (replica,))

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 65 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
JobSpec,
3333
JobStatus,
3434
JobTerminationReason,
35+
ProbeSpec,
3536
Run,
3637
RunSpec,
3738
RunStatus,
@@ -70,6 +71,7 @@
7071
from dstack._internal.server.services.runner import client
7172
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
7273
from dstack._internal.server.services.runs import (
74+
is_job_ready,
7375
run_model_to_run,
7476
)
7577
from dstack._internal.server.services.secrets import get_project_secrets_mapping
@@ -140,6 +142,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
140142
select(JobModel)
141143
.where(JobModel.id == job_model.id)
142144
.options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
145+
.options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak))
143146
.execution_options(populate_existing=True)
144147
)
145148
job_model = res.unique().scalar_one()
@@ -382,52 +385,21 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
382385
job_submission.age,
383386
)
384387

385-
if (
386-
initial_status != job_model.status
387-
and job_model.status == JobStatus.RUNNING
388-
and job_model.job_num == 0 # gateway connects only to the first node
389-
and run.run_spec.configuration.type == "service"
390-
):
391-
ssh_head_proxy: Optional[SSHConnectionParams] = None
392-
ssh_head_proxy_private_key: Optional[str] = None
393-
instance = common_utils.get_or_error(job_model.instance)
394-
if instance.remote_connection_info is not None:
395-
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
396-
if rci.ssh_proxy is not None:
397-
ssh_head_proxy = rci.ssh_proxy
398-
ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
399-
ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
400-
try:
401-
await services.register_replica(
402-
session,
403-
run_model.gateway_id,
404-
run,
405-
job_model,
406-
ssh_head_proxy,
407-
ssh_head_proxy_private_key,
408-
)
409-
except GatewayError as e:
410-
logger.warning(
411-
"%s: failed to register service replica: %s, age=%s",
412-
fmt(job_model),
413-
e,
414-
job_submission.age,
415-
)
416-
job_model.status = JobStatus.TERMINATING
417-
job_model.termination_reason = JobTerminationReason.GATEWAY_ERROR
418-
else:
419-
for probe_num in range(len(job.job_spec.probes)):
420-
session.add(
421-
ProbeModel(
422-
name=f"{job_model.job_name}-{probe_num}",
423-
job=job_model,
424-
probe_num=probe_num,
425-
due=common_utils.get_current_datetime(),
426-
success_streak=0,
427-
active=True,
428-
)
388+
if initial_status != job_model.status and job_model.status == JobStatus.RUNNING:
389+
job_model.probes = []
390+
for probe_num in range(len(job.job_spec.probes)):
391+
job_model.probes.append(
392+
ProbeModel(
393+
name=f"{job_model.job_name}-{probe_num}",
394+
probe_num=probe_num,
395+
due=common_utils.get_current_datetime(),
396+
success_streak=0,
397+
active=True,
429398
)
399+
)
430400

401+
if job_model.status == JobStatus.RUNNING:
402+
await _maybe_register_replica(session, run_model, run, job_model, job.job_spec.probes)
431403
if job_model.status == JobStatus.RUNNING:
432404
await _check_gpu_utilization(session, job_model, job)
433405

@@ -822,6 +794,55 @@ def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool:
822794
)
823795

824796

797+
async def _maybe_register_replica(
798+
session: AsyncSession,
799+
run_model: RunModel,
800+
run: Run,
801+
job_model: JobModel,
802+
probe_specs: Iterable[ProbeSpec],
803+
) -> None:
804+
"""
805+
Register the replica represented by this job to receive service requests if it is ready.
806+
"""
807+
808+
if (
809+
run.run_spec.configuration.type != "service"
810+
or job_model.registered
811+
or job_model.job_num != 0 # only the first job in the replica receives service requests
812+
or not is_job_ready(job_model.probes, probe_specs)
813+
):
814+
return
815+
816+
ssh_head_proxy: Optional[SSHConnectionParams] = None
817+
ssh_head_proxy_private_key: Optional[str] = None
818+
instance = common_utils.get_or_error(job_model.instance)
819+
if instance.remote_connection_info is not None:
820+
rci: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw(
821+
instance.remote_connection_info
822+
)
823+
if rci.ssh_proxy is not None:
824+
ssh_head_proxy = rci.ssh_proxy
825+
ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
826+
ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
827+
try:
828+
await services.register_replica(
829+
session,
830+
run_model.gateway_id,
831+
run,
832+
job_model,
833+
ssh_head_proxy,
834+
ssh_head_proxy_private_key,
835+
)
836+
except GatewayError as e:
837+
logger.warning(
838+
"%s: failed to register service replica: %s",
839+
fmt(job_model),
840+
e,
841+
)
842+
job_model.status = JobStatus.TERMINATING
843+
job_model.termination_reason = JobTerminationReason.GATEWAY_ERROR
844+
845+
825846
async def _check_gpu_utilization(session: AsyncSession, job_model: JobModel, job: Job) -> None:
826847
policy = job.job_spec.utilization_policy
827848
if policy is None:

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from dstack._internal.server.models import (
2424
InstanceModel,
2525
JobModel,
26-
ProbeModel,
2726
ProjectModel,
2827
RunModel,
2928
UserModel,
@@ -37,7 +36,7 @@
3736
from dstack._internal.server.services.prometheus.client_metrics import run_metrics
3837
from dstack._internal.server.services.runs import (
3938
fmt,
40-
is_replica_ready,
39+
is_replica_registered,
4140
process_terminating_run,
4241
retry_run_replica_jobs,
4342
run_model_to_run,
@@ -151,11 +150,6 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
151150
.joinedload(JobModel.instance)
152151
.load_only(InstanceModel.fleet_id)
153152
)
154-
.options(
155-
selectinload(RunModel.jobs)
156-
.joinedload(JobModel.probes)
157-
.load_only(ProbeModel.success_streak)
158-
)
159153
.execution_options(populate_existing=True)
160154
)
161155
run_model = res.unique().scalar_one()
@@ -465,6 +459,9 @@ async def _handle_run_replicas(
465459
run_spec=run_spec,
466460
)
467461
if _has_out_of_date_replicas(run_model):
462+
assert run_spec.configuration.type == "service", (
463+
"Rolling deployment is only supported for services"
464+
)
468465
non_terminated_replica_count = len(
469466
{j.replica_num for j in run_model.jobs if not j.status.is_finished()}
470467
)
@@ -479,22 +476,24 @@ async def _handle_run_replicas(
479476
)
480477

481478
replicas_to_stop_count = 0
482-
# stop any out-of-date replicas that are not ready
479+
# stop any out-of-date replicas that are not registered
483480
replicas_to_stop_count += sum(
484481
any(j.deployment_num < run_model.deployment_num for j in jobs)
485482
and any(
486483
j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses()
487484
for j in jobs
488485
)
489-
and not is_replica_ready(jobs)
486+
and not is_replica_registered(jobs)
487+
for _, jobs in group_jobs_by_replica_latest(run_model.jobs)
488+
)
489+
# stop excessive registered out-of-date replicas, except those that are already `terminating`
490+
non_terminating_registered_replicas_count = sum(
491+
is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs)
490492
for _, jobs in group_jobs_by_replica_latest(run_model.jobs)
491493
)
492-
ready_replica_count = sum(
493-
is_replica_ready(jobs) for _, jobs in group_jobs_by_replica_latest(run_model.jobs)
494+
replicas_to_stop_count += max(
495+
0, non_terminating_registered_replicas_count - run_model.desired_replica_count
494496
)
495-
if ready_replica_count > run_model.desired_replica_count:
496-
# stop excessive ready out-of-date replicas
497-
replicas_to_stop_count += ready_replica_count - run_model.desired_replica_count
498497
if replicas_to_stop_count:
499498
await scale_run_replicas(
500499
session,
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Add JobModel.registered
2+
3+
Revision ID: 3d7f6c2ec000
4+
Revises: 74a1f55209bd
5+
Create Date: 2025-08-11 13:23:39.530103
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "3d7f6c2ec000"
14+
down_revision = "74a1f55209bd"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
with op.batch_alter_table("jobs", schema=None) as batch_op:
21+
batch_op.add_column(
22+
sa.Column("registered", sa.Boolean(), server_default=sa.false(), nullable=False)
23+
)
24+
25+
26+
def downgrade() -> None:
27+
with op.batch_alter_table("jobs", schema=None) as batch_op:
28+
batch_op.drop_column("registered")

src/dstack/_internal/server/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ class JobModel(BaseModel):
430430
probes: Mapped[list["ProbeModel"]] = relationship(
431431
back_populates="job", order_by="ProbeModel.probe_num"
432432
)
433+
# Whether the replica is registered to receive service requests.
434+
# Always `False` for non-service runs.
435+
registered: Mapped[bool] = mapped_column(Boolean, server_default=false())
433436

434437

435438
class GatewayModel(BaseModel):
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
from dstack._internal.core.models.runs import Probe
1+
from dstack._internal.core.models.runs import Probe, ProbeSpec
22
from dstack._internal.server.models import ProbeModel
33

44

55
def probe_model_to_probe(probe_model: ProbeModel) -> Probe:
66
return Probe(success_streak=probe_model.success_streak)
7+
8+
9+
def is_probe_ready(probe: ProbeModel, spec: ProbeSpec) -> bool:
10+
return probe.success_streak >= spec.ready_after

src/dstack/_internal/server/services/proxy/repo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
5454
RunModel.gateway_id.is_(None),
5555
JobModel.run_name == run_name,
5656
JobModel.status == JobStatus.RUNNING,
57+
JobModel.registered == True,
5758
JobModel.job_num == 0,
5859
)
5960
.options(

src/dstack/_internal/server/services/runs.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
JobStatus,
4242
JobSubmission,
4343
JobTerminationReason,
44+
ProbeSpec,
4445
Run,
4546
RunPlan,
4647
RunSpec,
@@ -58,6 +59,7 @@
5859
from dstack._internal.server.db import get_db
5960
from dstack._internal.server.models import (
6061
JobModel,
62+
ProbeModel,
6163
ProjectModel,
6264
RepoModel,
6365
RunModel,
@@ -86,6 +88,7 @@
8688
from dstack._internal.server.services.logging import fmt
8789
from dstack._internal.server.services.offers import get_offers_by_requirements
8890
from dstack._internal.server.services.plugins import apply_plugin_policies
91+
from dstack._internal.server.services.probes import is_probe_ready
8992
from dstack._internal.server.services.projects import list_user_project_models
9093
from dstack._internal.server.services.resources import set_resources_defaults
9194
from dstack._internal.server.services.secrets import get_project_secrets_mapping
@@ -1185,8 +1188,8 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
11851188
elif {JobStatus.PROVISIONING, JobStatus.PULLING} & statuses:
11861189
# if there are any provisioning or pulling jobs, the replica is active and has the importance of 1
11871190
active_replicas.append((1, is_out_of_date, replica_num, replica_jobs))
1188-
elif not is_replica_ready(replica_jobs):
1189-
# all jobs are running, but probes are failing, the replica is active and has the importance of 2
1191+
elif not is_replica_registered(replica_jobs):
1192+
# all jobs are running, but not receiving traffic, the replica is active and has the importance of 2
11901193
active_replicas.append((2, is_out_of_date, replica_num, replica_jobs))
11911194
else:
11921195
# all jobs are running and ready, the replica is active and has the importance of 3
@@ -1273,15 +1276,13 @@ async def retry_run_replica_jobs(
12731276
session.add(new_job_model)
12741277

12751278

1276-
def is_replica_ready(jobs: Iterable[JobModel]) -> bool:
1277-
if not all(job.status == JobStatus.RUNNING for job in jobs):
1278-
return False
1279-
for job in jobs:
1280-
job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
1281-
for probe_spec, probe in zip(job_spec.probes, job.probes):
1282-
if probe.success_streak < probe_spec.ready_after:
1283-
return False
1284-
return True
1279+
def is_job_ready(probes: Iterable[ProbeModel], probe_specs: Iterable[ProbeSpec]) -> bool:
1280+
return all(is_probe_ready(probe, probe_spec) for probe, probe_spec in zip(probes, probe_specs))
1281+
1282+
1283+
def is_replica_registered(jobs: list[JobModel]) -> bool:
1284+
# Only job_num=0 is supposed to receive service requests
1285+
return jobs[0].registered
12851286

12861287

12871288
def _remove_job_spec_sensitive_info(spec: JobSpec):

0 commit comments

Comments
 (0)