From 084642d41db29b69816ec26a0427e356c099bf4f Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 27 Aug 2025 00:26:35 +0200 Subject: [PATCH 1/3] Automatically remove dangling tasks from shim - Terminate and remove dangling tasks from shim when processing the instance. Tasks can become dangling because shim was unavailable when the job was in the `terminating` status. - Adjust the shim API to return task statuses to avoid redundant termination. Dangling task termination and removal works with the old API too, but may result in redundant termination requests and warnings if the task is already terminated, but not removed. This is particularly noticeable if `DSTACK_SERVER_KEEP_SHIM_TASKS` is set. --- runner/docs/shim.openapi.yaml | 17 ++- runner/internal/shim/api/api_test.go | 4 +- runner/internal/shim/api/handlers.go | 3 +- runner/internal/shim/api/schemas.go | 2 +- runner/internal/shim/api/server.go | 2 +- runner/internal/shim/docker.go | 9 +- runner/internal/shim/models.go | 5 + runner/internal/shim/task.go | 11 +- .../background/tasks/process_instances.py | 9 +- src/dstack/_internal/server/schemas/runner.py | 10 ++ .../_internal/server/services/instances.py | 42 +++++- .../server/services/jobs/__init__.py | 19 ++- .../_internal/server/services/logging.py | 12 +- .../server/services/runner/client.py | 7 + .../tasks/test_process_instances.py | 120 +++++++++++++++++- 15 files changed, 246 insertions(+), 26 deletions(-) diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml index 061da5d7ed..c3489cda44 100644 --- a/runner/docs/shim.openapi.yaml +++ b/runner/docs/shim.openapi.yaml @@ -396,13 +396,22 @@ components: title: shim.api.TaskListResponse type: object properties: - ids: + tasks: type: array items: - $ref: "#/components/schemas/TaskID" - description: A list of all task IDs tracked by shim + type: object + properties: + id: + $ref: "#/components/schemas/TaskID" + status: + allOf: + - $ref: "#/components/schemas/TaskStatus" + required: + - id + - status + description: A list of all tasks tracked by shim, each with its ID and status required: - - ids + - tasks additionalProperties: false TaskInfoResponse: diff --git a/runner/internal/shim/api/api_test.go b/runner/internal/shim/api/api_test.go index 44c30423a2..b6879187af 100644 --- a/runner/internal/shim/api/api_test.go +++ b/runner/internal/shim/api/api_test.go @@ -34,8 +34,8 @@ func (ds *DummyRunner) Remove(context.Context, string) error { return nil } -func (ds *DummyRunner) TaskIDs() []string { - return []string{} +func (ds *DummyRunner) TaskList() []*shim.TaskListItem { + return []*shim.TaskListItem{} } func (ds *DummyRunner) TaskInfo(taskID string) shim.TaskInfo { diff --git a/runner/internal/shim/api/handlers.go b/runner/internal/shim/api/handlers.go index 1374fbd803..91df9cb55f 100644 --- a/runner/internal/shim/api/handlers.go +++ b/runner/internal/shim/api/handlers.go @@ -36,7 +36,8 @@ func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Reques } func (s *ShimServer) TaskListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { - return &TaskListResponse{IDs: s.runner.TaskIDs()}, nil + tasks := s.runner.TaskList() + return &TaskListResponse{tasks}, nil } func (s *ShimServer) TaskInfoHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 7f004a4046..41d09b8ac6 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -15,7 +15,7 @@ type InstanceHealthResponse struct { } type TaskListResponse struct { - IDs []string `json:"ids"` + Tasks []*shim.TaskListItem `json:"tasks"` } type TaskInfoResponse struct { diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 4ba67a1f94..8fd7026a99 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -19,7 +19,7 @@ type TaskRunner interface { Remove(ctx context.Context, taskID string) error Resources(context.Context) shim.Resources - TaskIDs() []string + TaskList() []*shim.TaskListItem TaskInfo(taskID string) shim.TaskInfo } diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 1834188ae1..18ded881c1 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -216,8 +216,13 @@ func (d *DockerRunner) Resources(ctx context.Context) Resources { } } -func (d *DockerRunner) TaskIDs() []string { - return d.tasks.IDs() +func (d *DockerRunner) TaskList() []*TaskListItem { + tasks := d.tasks.List() + result := make([]*TaskListItem, 0, len(tasks)) + for _, task := range tasks { + result = append(result, &TaskListItem{ID: task.ID, Status: task.Status}) + } + return result } func (d *DockerRunner) TaskInfo(taskID string) TaskInfo { diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 7294c6cb9b..b8da12670d 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -104,6 +104,11 @@ type TaskConfig struct { ContainerSshKeys []string `json:"container_ssh_keys"` } +type TaskListItem struct { + ID string `json:"id"` + Status TaskStatus `json:"status"` +} + type TaskInfo struct { ID string Status TaskStatus diff --git a/runner/internal/shim/task.go b/runner/internal/shim/task.go index cd2cd92658..f1d67b785c 100644 --- a/runner/internal/shim/task.go +++ b/runner/internal/shim/task.go @@ -148,14 +148,15 @@ type TaskStorage struct { mu sync.RWMutex } -func (ts *TaskStorage) IDs() []string { +// Get a _copy_ of all tasks. To "commit" changes, use Update() +func (ts *TaskStorage) List() []Task { ts.mu.RLock() defer ts.mu.RUnlock() - ids := make([]string, 0, len(ts.tasks)) - for id := range ts.tasks { - ids = append(ids, id) + tasks := make([]Task, 0, len(ts.tasks)) + for _, task := range ts.tasks { + tasks = append(tasks, task) } - return ids + return tasks } // Get a _copy_ of the task. To "commit" changes, use Update() diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 8e2127cd78..849c7dc54c 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -85,8 +85,10 @@ get_instance_provisioning_data, get_instance_requirements, get_instance_ssh_private_keys, + remove_dangling_tasks_from_instance, ) from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.offers import is_divisible_into_blocks from dstack._internal.server.services.placement import ( get_fleet_placement_group_models, @@ -789,6 +791,7 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non ssh_private_keys, job_provisioning_data, None, + instance=instance, check_instance_health=check_instance_health, ) if instance_check is False: @@ -935,7 +938,7 @@ async def _wait_for_instance_provisioning_data( @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) def _check_instance_inner( - ports: Dict[int, int], *, check_instance_health: bool = False + ports: Dict[int, int], *, instance: InstanceModel, check_instance_health: bool = False ) -> InstanceCheck: instance_health_response: Optional[InstanceHealthResponse] = None shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) @@ -955,6 +958,10 @@ def _check_instance_inner( args = (method.__func__.__name__, e.__class__.__name__, e) logger.exception(template, *args) return InstanceCheck(reachable=False, message=template % args) + try: + remove_dangling_tasks_from_instance(shim_client, instance) + except Exception as e: + logger.exception("%s: error removing dangling tasks: %s", fmt(instance), e) return runner_client.healthcheck_response_to_instance_check( healthcheck_response, instance_health_response ) diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index f71d60055c..e954c2308b 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -159,6 +159,16 @@ class GPUDevice(CoreModel): path_in_container: str +class TaskListItem(CoreModel): + id: str + status: TaskStatus + + +class TaskListResponse(CoreModel): + ids: Optional[list[str]] = None # returned by pre-TODO shim + tasks: Optional[list[TaskListItem]] = None # returned by TODO+ shim + + class TaskInfoResponse(CoreModel): id: str status: TaskStatus diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index cf53746d8b..ac6bb6c1d7 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -39,6 +39,7 @@ from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.core.models.volumes import Volume from dstack._internal.core.services.profiles import get_termination +from dstack._internal.server import settings as server_settings from dstack._internal.server.models import ( FleetModel, InstanceHealthCheckModel, @@ -47,9 +48,11 @@ UserModel, ) from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse -from dstack._internal.server.schemas.runner import InstanceHealthResponse +from dstack._internal.server.schemas.runner import InstanceHealthResponse, TaskStatus +from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.offers import generate_shared_offer from dstack._internal.server.services.projects import list_user_project_models +from dstack._internal.server.services.runner.client import ShimClient from dstack._internal.utils import common as common_utils from dstack._internal.utils.logging import get_logger @@ -633,3 +636,40 @@ async def create_ssh_instance_model( busy_blocks=0, ) return im + + +def remove_dangling_tasks_from_instance(shim_client: ShimClient, instance: InstanceModel) -> None: + if not shim_client.is_api_v2_supported(): + return + assigned_to_instance_job_ids = {str(j.id) for j in instance.jobs} + task_list_response = shim_client.list_tasks() + tasks: list[tuple[str, Optional[TaskStatus]]] + if task_list_response.tasks is not None: + tasks = [(t.id, t.status) for t in task_list_response.tasks] + elif task_list_response.ids is not None: + # compatibility with pre-TODO shim + tasks = [(t_id, None) for t_id in task_list_response.ids] + else: + raise ValueError("Unexpected task list response, neither `tasks` nor `ids` is set") + for task_id, task_status in tasks: + if task_id in assigned_to_instance_job_ids: + continue + should_terminate = task_status != TaskStatus.TERMINATED + should_remove = not server_settings.SERVER_KEEP_SHIM_TASKS + if not (should_terminate or should_remove): + continue + logger.warning( + "%s: dangling task found, id=%s, status=%s. Terminating and/or removing", + fmt(instance), + task_id, + task_status or "", + ) + if should_terminate: + shim_client.terminate_task( + task_id=task_id, + reason=None, + message=None, + timeout=0, + ) + if should_remove: + shim_client.remove_task(task_id=task_id) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 0e379b4d99..ffea0c72ea 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -256,7 +256,16 @@ async def process_terminating_job( if jpd is not None: logger.debug("%s: stopping container", fmt(job_model)) ssh_private_keys = get_instance_ssh_private_keys(instance_model) - await stop_container(job_model, jpd, ssh_private_keys) + if not await stop_container(job_model, jpd, ssh_private_keys): + # The dangling container can be removed later during instance processing + logger.warning( + ( + "%s: could not stop container, possibly due to a communication error." + " See debug logs for details." + " Ignoring, can attempt to remove the container later" + ), + fmt(job_model), + ) if jrd is not None and jrd.volume_names is not None: volume_names = jrd.volume_names else: @@ -378,21 +387,22 @@ async def stop_container( job_model: JobModel, job_provisioning_data: JobProvisioningData, ssh_private_keys: tuple[str, Optional[str]], -): +) -> bool: if job_provisioning_data.dockerized: # send a request to the shim to terminate the docker container # SSHError and RequestException are caught in the `runner_ssh_tunner` decorator - await run_async( + return await run_async( _shim_submit_stop, ssh_private_keys, job_provisioning_data, None, job_model, ) + return True @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) -def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel): +def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel) -> bool: shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) resp = shim_client.healthcheck() @@ -418,6 +428,7 @@ def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel): shim_client.remove_task(task_id=job_model.id) else: shim_client.stop(force=True) + return True def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, List[JobModel]]]: diff --git a/src/dstack/_internal/server/services/logging.py b/src/dstack/_internal/server/services/logging.py index f2d4666b64..c738534198 100644 --- a/src/dstack/_internal/server/services/logging.py +++ b/src/dstack/_internal/server/services/logging.py @@ -1,14 +1,22 @@ from typing import Union -from dstack._internal.server.models import GatewayModel, JobModel, ProbeModel, RunModel +from dstack._internal.server.models import ( + GatewayModel, + InstanceModel, + JobModel, + ProbeModel, + RunModel, +) -def fmt(model: Union[RunModel, JobModel, GatewayModel, ProbeModel]) -> str: +def fmt(model: Union[RunModel, JobModel, InstanceModel, GatewayModel, ProbeModel]) -> str: """Consistent string representation of a model for logging.""" if isinstance(model, RunModel): return f"run({model.id.hex[:6]}){model.run_name}" if isinstance(model, JobModel): return f"job({model.id.hex[:6]}){model.job_name}" + if isinstance(model, InstanceModel): + return f"instance({model.id.hex[:6]}){model.name}" if isinstance(model, GatewayModel): return f"gateway({model.id.hex[:6]}){model.name}" if isinstance(model, ProbeModel): diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 7b7c31dd8f..60f6c5d8c9 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -26,6 +26,7 @@ ShimVolumeInfo, SubmitBody, TaskInfoResponse, + TaskListResponse, TaskSubmitRequest, TaskTerminateRequest, ) @@ -245,6 +246,12 @@ def get_instance_health(self) -> Optional[InstanceHealthResponse]: self._raise_for_status(resp) return self._response(InstanceHealthResponse, resp) + def list_tasks(self) -> TaskListResponse: + if not self.is_api_v2_supported(): + raise ShimAPIVersionError() + resp = self._request("GET", "/api/tasks", raise_for_status=True) + return self._response(TaskListResponse, resp) + def get_task(self, task_id: "_TaskID") -> TaskInfoResponse: if not self.is_api_v2_supported(): raise ShimAPIVersionError() diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/tasks/test_process_instances.py index 990146f3b7..8255007073 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/tasks/test_process_instances.py @@ -1,8 +1,9 @@ import datetime as dt from collections import defaultdict +from collections.abc import Generator from contextlib import contextmanager from typing import Optional -from unittest.mock import Mock, patch +from unittest.mock import Mock, call, patch import gpuhunt import pytest @@ -41,7 +42,12 @@ from dstack._internal.server.models import InstanceHealthCheckModel, PlacementGroupModel from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult from dstack._internal.server.schemas.instances import InstanceCheck -from dstack._internal.server.schemas.runner import InstanceHealthResponse +from dstack._internal.server.schemas.runner import ( + InstanceHealthResponse, + TaskListItem, + TaskListResponse, + TaskStatus, +) from dstack._internal.server.testing.common import ( ComputeMockSpec, create_fleet, @@ -377,6 +383,116 @@ async def test_check_shim_check_instance_health(self, test_db, session: AsyncSes assert health_check.response == health_response.json() +class TestRemoveDanglingTasks: + @pytest.fixture + def ssh_tunnel_mock(self) -> Generator[Mock, None, None]: + with patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock: + yield SSHTunnelMock + + @pytest.fixture + def shim_client_mock(self) -> Generator[Mock, None, None]: + with patch("dstack._internal.server.services.runner.client.ShimClient") as ShimClientMock: + yield ShimClientMock.return_value + + @pytest.mark.asyncio + async def test_terminates_and_removes_dangling_tasks( + self, test_db, session: AsyncSession, ssh_tunnel_mock, shim_client_mock: Mock + ): + user = await create_user(session=session) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + instance=instance, + ) + dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727" + dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d" + shim_client_mock.list_tasks.return_value = TaskListResponse( + tasks=[ + TaskListItem(id=str(job.id), status=TaskStatus.RUNNING), + TaskListItem(id=dangling_task_id_1, status=TaskStatus.RUNNING), + TaskListItem(id=dangling_task_id_2, status=TaskStatus.TERMINATED), + ] + ) + await process_instances() + + await session.refresh(instance) + assert instance.status == InstanceStatus.BUSY + + shim_client_mock.terminate_task.assert_called_once_with( + task_id=dangling_task_id_1, reason=None, message=None, timeout=0 + ) + assert shim_client_mock.remove_task.call_count == 2 + shim_client_mock.remove_task.assert_has_calls( + [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)] + ) + + @pytest.mark.asyncio + async def test_terminates_and_removes_dangling_tasks_legacy_shim( + self, test_db, session: AsyncSession, ssh_tunnel_mock, shim_client_mock: Mock + ): + user = await create_user(session=session) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + instance=instance, + ) + dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727" + dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d" + shim_client_mock.list_tasks.return_value = TaskListResponse( + ids=[str(job.id), dangling_task_id_1, dangling_task_id_2] + ) + await process_instances() + + await session.refresh(instance) + assert instance.status == InstanceStatus.BUSY + + assert shim_client_mock.terminate_task.call_count == 2 + shim_client_mock.terminate_task.assert_has_calls( + [ + call(task_id=dangling_task_id_1, reason=None, message=None, timeout=0), + call(task_id=dangling_task_id_2, reason=None, message=None, timeout=0), + ] + ) + assert shim_client_mock.remove_task.call_count == 2 + shim_client_mock.remove_task.assert_has_calls( + [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)] + ) + + class TestTerminateIdleTime: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) From 04b857195dade45bdbfddd2140ebc60c022ac2eb Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Thu, 28 Aug 2025 10:41:36 +0200 Subject: [PATCH 2/3] Simplify OpenAPI schema --- runner/docs/shim.openapi.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml index c3489cda44..d612cc0bad 100644 --- a/runner/docs/shim.openapi.yaml +++ b/runner/docs/shim.openapi.yaml @@ -404,8 +404,7 @@ components: id: $ref: "#/components/schemas/TaskID" status: - allOf: - - $ref: "#/components/schemas/TaskStatus" + $ref: "#/components/schemas/TaskStatus" required: - id - status From 3c842114252c1a1ea0c80106faba1bd9794d371c Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Thu, 28 Aug 2025 10:45:58 +0200 Subject: [PATCH 3/3] Set version in compatibility comments --- src/dstack/_internal/server/schemas/runner.py | 4 ++-- src/dstack/_internal/server/services/instances.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index e954c2308b..6de49f35a1 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -165,8 +165,8 @@ class TaskListItem(CoreModel): class TaskListResponse(CoreModel): - ids: Optional[list[str]] = None # returned by pre-TODO shim - tasks: Optional[list[TaskListItem]] = None # returned by TODO+ shim + ids: Optional[list[str]] = None # returned by pre-0.19.26 shim + tasks: Optional[list[TaskListItem]] = None # returned by 0.19.26+ shim class TaskInfoResponse(CoreModel): diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index ac6bb6c1d7..79fadc2b9f 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -647,7 +647,7 @@ def remove_dangling_tasks_from_instance(shim_client: ShimClient, instance: Insta if task_list_response.tasks is not None: tasks = [(t.id, t.status) for t in task_list_response.tasks] elif task_list_response.ids is not None: - # compatibility with pre-TODO shim + # compatibility with pre-0.19.26 shim tasks = [(t_id, None) for t_id in task_list_response.ids] else: raise ValueError("Unexpected task list response, neither `tasks` nor `ids` is set")