Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/dstack/_internal/server/background/pipeline_tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ async def heartbeat(self):

class Fetcher(Generic[ItemT], ABC):
_DEFAULT_FETCH_DELAYS = [0.5, 1, 2, 5]
"""Increasing fetch delays on empty fetches to avoid frequent selects on low-activity/low-resource servers."""

def __init__(
self,
Expand Down Expand Up @@ -319,7 +320,15 @@ async def fetch(self, limit: int) -> list[ItemT]:
pass

def _next_fetch_delay(self, empty_fetch_count: int) -> float:
    """Return the delay in seconds before the next fetch attempt.

    The delay grows with the number of consecutive empty fetches, capped at
    the last entry of `self._fetch_delays`, so that low-activity servers do
    not hammer the DB with selects. A +/-20% jitter is applied so multiple
    fetchers do not query in lockstep.

    Args:
        empty_fetch_count: Number of consecutive fetches that returned no items.
    """
    effective_empty_fetch_count = empty_fetch_count
    if random.random() < 0.1:
        # Empty fetch count can be 0 not because there are no items in the DB,
        # but for other reasons such as waiting parent resource processing.
        # From time to time, force the minimal next delay to avoid empty
        # results caused by fetching too rarely.
        effective_empty_fetch_count = 0
    next_delay = self._fetch_delays[
        min(effective_empty_fetch_count, len(self._fetch_delays) - 1)
    ]
    # Uniform jitter in [-0.2, 0.2): spreads fetch times across workers.
    jitter = random.random() * 0.4 - 0.2
    return next_delay * (1 + jitter)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
workers_num: int = 20,
queue_lower_limit_factor: float = 0.5,
queue_upper_limit_factor: float = 2.0,
min_processing_interval: timedelta = timedelta(seconds=10),
min_processing_interval: timedelta = timedelta(seconds=5),
lock_timeout: timedelta = timedelta(seconds=30),
heartbeat_trigger: timedelta = timedelta(seconds=15),
) -> None:
Expand Down Expand Up @@ -196,7 +196,19 @@ async def fetch(self, limit: int) -> list[JobRunningPipelineItem]:
[JobStatus.PROVISIONING, JobStatus.PULLING, JobStatus.RUNNING]
),
RunModel.status.not_in([RunStatus.TERMINATING]),
JobModel.last_processed_at <= now - self._min_processing_interval,
or_(
# Process provisioning and pulling jobs quicker for low-latency provisioning.
# Active jobs processing can be less frequent to minimize contention with `RunPipeline`.
and_(
JobModel.status.in_([JobStatus.PROVISIONING, JobStatus.PULLING]),
JobModel.last_processed_at <= now - self._min_processing_interval,
),
and_(
JobModel.status.in_([JobStatus.RUNNING]),
JobModel.last_processed_at
<= now - self._min_processing_interval * 2,
),
),
or_(
and_(
# Do not try to lock jobs if the run is waiting for the lock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
get_job_provisioning_data,
get_job_runtime_data,
get_job_spec,
stop_runner,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.logging import fmt
Expand Down Expand Up @@ -265,8 +266,10 @@ class _JobUpdateMap(ItemUpdateMap, total=False):
termination_reason: Optional[JobTerminationReason]
termination_reason_message: Optional[str]
instance_id: Optional[uuid.UUID]
graceful_termination_attempts: int
volumes_detached_at: UpdateMapDateTime
registered: bool
remove_at: UpdateMapDateTime


class _InstanceUpdateMap(ItemUpdateMap, total=False):
Expand Down Expand Up @@ -580,9 +583,11 @@ async def _process_terminating_job(
instance_model: Optional[InstanceModel],
) -> _ProcessResult:
"""
Stops the job: tells shim to stop the container, detaches the job from the instance,
and detaches volumes from the instance.
Graceful stop should already be done by the run terminating path.
Terminates the job:
1. tells the runner to stop the job's command
2. tells the shim to stop the container
3. detaches the job from the instance
4. and detaches volumes from the instance.
"""
instance_update_map = None if instance_model is None else _InstanceUpdateMap()
result = _ProcessResult(instance_update_map=instance_update_map)
Expand All @@ -592,6 +597,10 @@ async def _process_terminating_job(
result.job_update_map["status"] = _get_job_termination_status(job_model)
return result

if job_model.graceful_termination_attempts == 0 and job_model.remove_at is None:
result.job_update_map = await _stop_job_gracefully(job_model, instance_model)
return result

jrd = get_job_runtime_data(job_model)
jpd = get_job_provisioning_data(job_model)
if jpd is not None:
Expand Down Expand Up @@ -642,6 +651,20 @@ async def _process_terminating_job(
return result


async def _stop_job_gracefully(
    job_model: JobModel, instance_model: InstanceModel
) -> _JobUpdateMap:
    """
    Ask the runner to stop the job's command gracefully.

    Records that the first graceful-stop attempt has been made and schedules
    `remove_at` so that `_process_terminating_job()` proceeds to stop the
    container on a later iteration.
    """
    await stop_runner(job_model=job_model, instance_model=instance_model)
    remove_deadline = get_current_datetime() + timedelta(seconds=10)
    return _JobUpdateMap(
        graceful_termination_attempts=1,
        remove_at=remove_deadline,
    )


async def _process_job_volumes_detaching(
job_model: JobModel,
instance_model: InstanceModel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
workers_num: int = 10,
queue_lower_limit_factor: float = 0.5,
queue_upper_limit_factor: float = 2.0,
min_processing_interval: timedelta = timedelta(seconds=10),
min_processing_interval: timedelta = timedelta(seconds=5),
lock_timeout: timedelta = timedelta(seconds=30),
heartbeat_trigger: timedelta = timedelta(seconds=15),
) -> None:
Expand Down Expand Up @@ -164,7 +164,17 @@ async def fetch(self, limit: int) -> list[RunPipelineItem]:
),
),
or_(
RunModel.last_processed_at <= now - self._min_processing_interval,
# Process submitted runs quicker for low-latency provisioning.
# Active run processing can be less frequent to minimize contention with `JobRunningPipeline`.
and_(
RunModel.status == RunStatus.SUBMITTED,
RunModel.last_processed_at <= now - self._min_processing_interval,
),
and_(
RunModel.status != RunStatus.SUBMITTED,
RunModel.last_processed_at
<= now - self._min_processing_interval * 2,
),
RunModel.last_processed_at == RunModel.submitted_at,
),
or_(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from datetime import datetime
from typing import Optional

import httpx
Expand All @@ -17,10 +17,9 @@
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.services import events
from dstack._internal.server.services.gateways import get_or_add_gateway_connection
from dstack._internal.server.services.jobs import stop_runner
from dstack._internal.server.services.logging import fmt
from dstack._internal.server.services.runs import _get_next_triggered_at, get_run_spec
from dstack._internal.utils.common import get_current_datetime, get_or_error
from dstack._internal.utils.common import get_or_error
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand All @@ -35,7 +34,7 @@ class TerminatingRunUpdateMap(ItemUpdateMap, total=False):
class TerminatingRunJobUpdateMap(ItemUpdateMap, total=False):
status: JobStatus
termination_reason: Optional[JobTerminationReason]
remove_at: Optional[datetime]
graceful_termination_attempts: int


@dataclass
Expand Down Expand Up @@ -77,10 +76,6 @@ async def process_terminating_run(context: TerminatingContext) -> TerminatingRes
JobTerminationReason.ABORTED_BY_USER,
JobTerminationReason.DONE_BY_RUNNER,
}:
# Send a signal to stop the job gracefully.
await stop_runner(
job_model=job_model, instance_model=get_or_error(job_model.instance)
)
delayed_job_ids.append(job_model.id)
continue
regular_job_ids.append(job_model.id)
Expand Down Expand Up @@ -123,7 +118,7 @@ def _get_job_id_to_update_map(
job_id_to_update_map[job_id] = TerminatingRunJobUpdateMap(
status=JobStatus.TERMINATING,
termination_reason=job_termination_reason,
remove_at=get_current_datetime() + timedelta(seconds=15),
graceful_termination_attempts=0,
)
return job_id_to_update_map

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Add JobModel.graceful_termination_attempts

Revision ID: e9d81c97c042
Revises: 59e328ced74c
Create Date: 2026-03-30 08:41:29.308250+00:00

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "e9d81c97c042"
down_revision = "59e328ced74c"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add the nullable `graceful_termination_attempts` column to `jobs`."""
    new_column = sa.Column("graceful_termination_attempts", sa.Integer(), nullable=True)
    with op.batch_alter_table("jobs", schema=None) as batch_op:
        batch_op.add_column(new_column)


def downgrade() -> None:
    """Drop the `graceful_termination_attempts` column from `jobs`."""
    with op.batch_alter_table("jobs", schema=None) as batch_op:
        batch_op.drop_column("graceful_termination_attempts")
10 changes: 9 additions & 1 deletion src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,16 @@ class JobModel(PipelineModelMixin, BaseModel):
runner_timestamp: Mapped[Optional[int]] = mapped_column(BigInteger)
inactivity_secs: Mapped[Optional[int]] = mapped_column(Integer)
"""`inactivity_secs` uses `0` for active jobs and `None` when inactivity is not applicable."""
graceful_termination_attempts: Mapped[Optional[int]] = mapped_column(Integer)
"""`graceful_termination_attempts` is used for terminating jobs.
* `None` means graceful termination is not needed
* `0` means it is needed but not attempted,
* `>= 1` means at least one graceful stop attempt was sent.
"""
remove_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
"""`remove_at` is used to ensure the instance is killed after the job is finished."""
"""`remove_at` is used to ensure the container/instance is killed after the job is gracefully finished.
Cannot kill the container/instance until `remove_at` is set.
"""
volumes_detached_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
instance_assigned: Mapped[bool] = mapped_column(Boolean, default=False)
"""`instance_assigned` shows whether instance assignment has already been attempted.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
JobTerminatingPipeline,
)
from dstack._internal.server.background.pipeline_tasks.runs import RunPipeline, RunWorker
from dstack._internal.server.background.pipeline_tasks.runs.terminating import (
TerminatingResult,
process_terminating_run,
)
from dstack._internal.server.testing.common import (
create_fleet,
create_instance,
Expand Down Expand Up @@ -84,32 +88,14 @@ async def test_transitions_running_jobs_to_terminating(
)
lock_run(run)
await session.commit()
item = run_to_pipeline_item(run)
observed_job_lock = {}

async def record_stop_call(**kwargs) -> None:
observed_job_lock["lock_token"] = kwargs["job_model"].lock_token
observed_job_lock["lock_owner"] = kwargs["job_model"].lock_owner

with patch(
"dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner",
new=AsyncMock(side_effect=record_stop_call),
) as stop_runner:
await worker.process(item)

assert stop_runner.await_count == 1
stop_call = stop_runner.await_args
assert stop_call is not None
assert stop_call.kwargs["job_model"].id == job.id
assert observed_job_lock["lock_token"] == item.lock_token
assert observed_job_lock["lock_owner"] == RunPipeline.__name__
assert stop_call.kwargs["instance_model"].id == instance.id
await worker.process(run_to_pipeline_item(run))

await session.refresh(job)
await session.refresh(run)
assert job.status == JobStatus.TERMINATING
assert job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER
assert job.remove_at is not None
assert job.graceful_termination_attempts == 0
assert job.remove_at is None
assert job.lock_token is None
assert job.lock_expires_at is None
assert job.lock_owner is None
Expand Down Expand Up @@ -154,19 +140,17 @@ async def test_updates_delayed_and_regular_jobs_separately(
lock_run(run)
await session.commit()

with patch(
"dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner",
new=AsyncMock(),
):
await worker.process(run_to_pipeline_item(run))
await worker.process(run_to_pipeline_item(run))

await session.refresh(delayed_job)
await session.refresh(regular_job)
assert delayed_job.status == JobStatus.TERMINATING
assert delayed_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER
assert delayed_job.remove_at is not None
assert delayed_job.graceful_termination_attempts == 0
assert delayed_job.remove_at is None
assert regular_job.status == JobStatus.TERMINATING
assert regular_job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER
assert regular_job.graceful_termination_attempts is None
assert regular_job.remove_at is None

async def test_finishes_non_scheduled_run_when_all_jobs_are_finished(
Expand Down Expand Up @@ -273,14 +257,16 @@ async def test_noops_when_run_lock_changes_after_processing(
await session.commit()
item = run_to_pipeline_item(run)
new_lock_token = uuid.uuid4()
original_process_terminating_run = process_terminating_run

async def change_run_lock(**kwargs) -> None:
async def change_run_lock(context) -> TerminatingResult:
run.lock_token = new_lock_token
run.lock_expires_at = get_current_datetime() + timedelta(minutes=1)
await session.commit()
return await original_process_terminating_run(context)

with patch(
"dstack._internal.server.background.pipeline_tasks.runs.terminating.stop_runner",
"dstack._internal.server.background.pipeline_tasks.runs.terminating.process_terminating_run",
new=AsyncMock(side_effect=change_run_lock),
):
await worker.process(item)
Expand All @@ -289,7 +275,10 @@ async def change_run_lock(**kwargs) -> None:
await session.refresh(job)
assert run.status == RunStatus.TERMINATING
assert run.lock_token == new_lock_token
assert run.lock_owner == RunPipeline.__name__
assert job.status == JobStatus.RUNNING
assert job.graceful_termination_attempts is None
assert job.remove_at is None
assert job.lock_token is None
assert job.lock_expires_at is None
assert job.lock_owner is None
Expand Down
Loading
Loading