Skip to content

Commit c022fe4

Browse files
authored
Allow concurrent run and TERMINATING jobs processing (#3641)
* Allow concurrent run and TERMINATING jobs processing * Get run_termination_reason once
1 parent 9ceafab commit c022fe4

5 files changed

Lines changed: 40 additions & 10 deletions

File tree

src/dstack/_internal/core/models/runs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ class RunTerminationReason(str, Enum):
9595
SERVER_ERROR = "server_error"
9696

9797
def to_job_termination_reason(self) -> "JobTerminationReason":
98+
"""
99+
Converts run termination reason to job termination reason.
100+
Used to set job termination reason for non-terminated jobs on run termination.
101+
"""
98102
mapping = {
99103
self.ALL_JOBS_DONE: JobTerminationReason.DONE_BY_RUNNER,
100104
self.JOB_FAILED: JobTerminationReason.TERMINATED_BY_SERVER,

src/dstack/_internal/server/background/scheduled_tasks/runs.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@
6666
logger = get_logger(__name__)
6767

6868
MIN_PROCESSING_INTERVAL = datetime.timedelta(seconds=5)
69+
70+
# No need to lock finished or terminating jobs since run processing does not update such jobs.
71+
JOB_STATUSES_EXCLUDED_FOR_LOCKING = JobStatus.finished_statuses() + [JobStatus.TERMINATING]
72+
6973
ROLLING_DEPLOYMENT_MAX_SURGE = 1 # at most one extra replica during rolling deployment
7074

7175

@@ -121,10 +125,9 @@ async def _process_next_run():
121125
)
122126
.options(
123127
joinedload(RunModel.jobs).load_only(JobModel.id),
124-
# No need to lock finished jobs
125128
with_loader_criteria(
126129
JobModel,
127-
JobModel.status.not_in(JobStatus.finished_statuses()),
130+
JobModel.status.not_in(JOB_STATUSES_EXCLUDED_FOR_LOCKING),
128131
include_aliases=True,
129132
),
130133
)
@@ -146,7 +149,7 @@ async def _process_next_run():
146149
load_only(JobModel.id),
147150
with_loader_criteria(
148151
JobModel,
149-
JobModel.status.not_in(JobStatus.finished_statuses()),
152+
JobModel.status.not_in(JOB_STATUSES_EXCLUDED_FOR_LOCKING),
150153
include_aliases=True,
151154
),
152155
)

src/dstack/_internal/server/services/jobs/__init__.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
JobSubmission,
3232
JobTerminationReason,
3333
RunSpec,
34+
RunTerminationReason,
3435
)
3536
from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus
3637
from dstack._internal.server import settings
@@ -349,6 +350,7 @@ async def process_terminating_job(
349350
if len(volume_models) > 0:
350351
logger.info("Detaching volumes: %s", [v.name for v in volume_models])
351352
all_volumes_detached = await _detach_volumes_from_job_instance(
353+
session=session,
352354
project=instance_model.project,
353355
job_model=job_model,
354356
jpd=jpd,
@@ -432,6 +434,7 @@ async def process_volumes_detaching(
432434
)
433435
logger.info("Detaching volumes: %s", [v.name for v in volume_models])
434436
all_volumes_detached = await _detach_volumes_from_job_instance(
437+
session=session,
435438
project=instance_model.project,
436439
job_model=job_model,
437440
jpd=jpd,
@@ -523,6 +526,7 @@ def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, Li
523526

524527

525528
async def _detach_volumes_from_job_instance(
529+
session: AsyncSession,
526530
project: ProjectModel,
527531
job_model: JobModel,
528532
jpd: JobProvisioningData,
@@ -542,6 +546,7 @@ async def _detach_volumes_from_job_instance(
542546

543547
all_detached = True
544548
detached_volumes = []
549+
run_termination_reason = await _get_run_termination_reason(session, job_model)
545550
for volume_model in volume_models:
546551
detached = await _detach_volume_from_job_instance(
547552
backend=backend,
@@ -550,6 +555,7 @@ async def _detach_volumes_from_job_instance(
550555
job_spec=job_spec,
551556
instance_model=instance_model,
552557
volume_model=volume_model,
558+
run_termination_reason=run_termination_reason,
553559
)
554560
if detached:
555561
detached_volumes.append(volume_model)
@@ -572,6 +578,7 @@ async def _detach_volume_from_job_instance(
572578
job_spec: JobSpec,
573579
instance_model: InstanceModel,
574580
volume_model: VolumeModel,
581+
run_termination_reason: Optional[RunTerminationReason],
575582
) -> bool:
576583
detached = True
577584
volume = volume_model_to_volume(volume_model)
@@ -601,7 +608,11 @@ async def _detach_volume_from_job_instance(
601608
volume=volume,
602609
provisioning_data=jpd,
603610
)
604-
if not detached and _should_force_detach_volume(job_model, job_spec.stop_duration):
611+
if not detached and _should_force_detach_volume(
612+
job_model,
613+
run_termination_reason=run_termination_reason,
614+
stop_duration=job_spec.stop_duration,
615+
):
605616
logger.info(
606617
"Force detaching volume %s from %s",
607618
volume_model.name,
@@ -633,13 +644,27 @@ async def _detach_volume_from_job_instance(
633644
MIN_FORCE_DETACH_WAIT_PERIOD = timedelta(seconds=60)
634645

635646

636-
def _should_force_detach_volume(job_model: JobModel, stop_duration: Optional[int]) -> bool:
647+
async def _get_run_termination_reason(
648+
session: AsyncSession, job_model: JobModel
649+
) -> Optional[RunTerminationReason]:
650+
res = await session.execute(
651+
select(RunModel.termination_reason).where(RunModel.id == job_model.run_id)
652+
)
653+
return res.scalar_one_or_none()
654+
655+
656+
def _should_force_detach_volume(
657+
job_model: JobModel,
658+
run_termination_reason: Optional[RunTerminationReason],
659+
stop_duration: Optional[int],
660+
) -> bool:
637661
return (
638662
job_model.volumes_detached_at is not None
639663
and common.get_current_datetime()
640664
> job_model.volumes_detached_at + MIN_FORCE_DETACH_WAIT_PERIOD
641665
and (
642666
job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER
667+
or run_termination_reason == RunTerminationReason.ABORTED_BY_USER
643668
or stop_duration is not None
644669
and common.get_current_datetime()
645670
> job_model.volumes_detached_at + timedelta(seconds=stop_duration)

src/dstack/_internal/server/services/runs/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,10 +1003,6 @@ async def process_terminating_run(session: AsyncSession, run_model: RunModel):
10031003
continue
10041004
unfinished_jobs_count += 1
10051005
if job_model.status == JobStatus.TERMINATING:
1006-
if job_termination_reason == JobTerminationReason.ABORTED_BY_USER:
1007-
# Override termination reason so that
1008-
# abort actions such as volume force detach are triggered
1009-
job_model.termination_reason = job_termination_reason
10101006
continue
10111007

10121008
if job_model.status == JobStatus.RUNNING and job_termination_reason not in {

src/dstack/_internal/server/testing/common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ async def create_run(
305305
repo: RepoModel,
306306
user: UserModel,
307307
fleet: Optional[FleetModel] = None,
308-
run_name: str = "test-run",
308+
run_name: Optional[str] = None,
309309
status: RunStatus = RunStatus.SUBMITTED,
310310
termination_reason: Optional[RunTerminationReason] = None,
311311
submitted_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
@@ -317,6 +317,8 @@ async def create_run(
317317
resubmission_attempt: int = 0,
318318
next_triggered_at: Optional[datetime] = None,
319319
) -> RunModel:
320+
if run_name is None:
321+
run_name = "test-run"
320322
if run_spec is None:
321323
run_spec = get_run_spec(
322324
run_name=run_name,

0 commit comments

Comments
 (0)