diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml
index 061da5d7ed..d612cc0bad 100644
--- a/runner/docs/shim.openapi.yaml
+++ b/runner/docs/shim.openapi.yaml
@@ -396,13 +396,21 @@ components:
       title: shim.api.TaskListResponse
       type: object
      properties:
-        ids:
+        tasks:
           type: array
           items:
-            $ref: "#/components/schemas/TaskID"
-          description: A list of all task IDs tracked by shim
+            type: object
+            properties:
+              id:
+                $ref: "#/components/schemas/TaskID"
+              status:
+                $ref: "#/components/schemas/TaskStatus"
+            required:
+              - id
+              - status
+          description: A list of all tasks tracked by shim, each with its ID and status
       required:
-        - ids
+        - tasks
       additionalProperties: false

     TaskInfoResponse:
diff --git a/runner/internal/shim/api/api_test.go b/runner/internal/shim/api/api_test.go
index 44c30423a2..b6879187af 100644
--- a/runner/internal/shim/api/api_test.go
+++ b/runner/internal/shim/api/api_test.go
@@ -34,8 +34,8 @@ func (ds *DummyRunner) Remove(context.Context, string) error {
     return nil
 }

-func (ds *DummyRunner) TaskIDs() []string {
-    return []string{}
+func (ds *DummyRunner) TaskList() []*shim.TaskListItem {
+    return []*shim.TaskListItem{}
 }

 func (ds *DummyRunner) TaskInfo(taskID string) shim.TaskInfo {
diff --git a/runner/internal/shim/api/handlers.go b/runner/internal/shim/api/handlers.go
index 1374fbd803..91df9cb55f 100644
--- a/runner/internal/shim/api/handlers.go
+++ b/runner/internal/shim/api/handlers.go
@@ -36,7 +36,8 @@ func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Reques
 }

 func (s *ShimServer) TaskListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
-    return &TaskListResponse{IDs: s.runner.TaskIDs()}, nil
+    tasks := s.runner.TaskList()
+    return &TaskListResponse{tasks}, nil
 }

 func (s *ShimServer) TaskInfoHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go
index 7f004a4046..41d09b8ac6 100644
--- a/runner/internal/shim/api/schemas.go
+++ b/runner/internal/shim/api/schemas.go
@@ -15,7 +15,7 @@ type InstanceHealthResponse struct {
 }

 type TaskListResponse struct {
-    IDs []string `json:"ids"`
+    Tasks []*shim.TaskListItem `json:"tasks"`
 }

 type TaskInfoResponse struct {
diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go
index 4ba67a1f94..8fd7026a99 100644
--- a/runner/internal/shim/api/server.go
+++ b/runner/internal/shim/api/server.go
@@ -19,7 +19,7 @@ type TaskRunner interface {
     Remove(ctx context.Context, taskID string) error
     Resources(context.Context) shim.Resources

-    TaskIDs() []string
+    TaskList() []*shim.TaskListItem
     TaskInfo(taskID string) shim.TaskInfo
 }

diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go
index 1834188ae1..18ded881c1 100644
--- a/runner/internal/shim/docker.go
+++ b/runner/internal/shim/docker.go
@@ -216,8 +216,13 @@ func (d *DockerRunner) Resources(ctx context.Context) Resources {
     }
 }

-func (d *DockerRunner) TaskIDs() []string {
-    return d.tasks.IDs()
+func (d *DockerRunner) TaskList() []*TaskListItem {
+    tasks := d.tasks.List()
+    result := make([]*TaskListItem, 0, len(tasks))
+    for _, task := range tasks {
+        result = append(result, &TaskListItem{ID: task.ID, Status: task.Status})
+    }
+    return result
 }

 func (d *DockerRunner) TaskInfo(taskID string) TaskInfo {
diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go
index 7294c6cb9b..b8da12670d 100644
--- a/runner/internal/shim/models.go
+++ b/runner/internal/shim/models.go
@@ -104,6 +104,11 @@ type TaskConfig struct {
     ContainerSshKeys []string `json:"container_ssh_keys"`
 }

+type TaskListItem struct {
+    ID     string     `json:"id"`
+    Status TaskStatus `json:"status"`
+}
+
 type TaskInfo struct {
     ID     string
     Status TaskStatus
diff --git a/runner/internal/shim/task.go b/runner/internal/shim/task.go
index cd2cd92658..f1d67b785c 100644
--- a/runner/internal/shim/task.go
+++ b/runner/internal/shim/task.go
@@ -148,14 +148,15 @@ type TaskStorage struct {
     mu sync.RWMutex
 }

-func (ts *TaskStorage) IDs() []string {
+// Get a _copy_ of all tasks. To "commit" changes, use Update()
+func (ts *TaskStorage) List() []Task {
     ts.mu.RLock()
     defer ts.mu.RUnlock()
-    ids := make([]string, 0, len(ts.tasks))
-    for id := range ts.tasks {
-        ids = append(ids, id)
+    tasks := make([]Task, 0, len(ts.tasks))
+    for _, task := range ts.tasks {
+        tasks = append(tasks, task)
     }
-    return ids
+    return tasks
 }

 // Get a _copy_ of the task. To "commit" changes, use Update()
diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py
index 8e2127cd78..849c7dc54c 100644
--- a/src/dstack/_internal/server/background/tasks/process_instances.py
+++ b/src/dstack/_internal/server/background/tasks/process_instances.py
@@ -85,8 +85,10 @@
     get_instance_provisioning_data,
     get_instance_requirements,
     get_instance_ssh_private_keys,
+    remove_dangling_tasks_from_instance,
 )
 from dstack._internal.server.services.locking import get_locker
+from dstack._internal.server.services.logging import fmt
 from dstack._internal.server.services.offers import is_divisible_into_blocks
 from dstack._internal.server.services.placement import (
     get_fleet_placement_group_models,
@@ -789,6 +791,7 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non
         ssh_private_keys,
         job_provisioning_data,
         None,
+        instance=instance,
         check_instance_health=check_instance_health,
     )
     if instance_check is False:
@@ -935,7 +938,7 @@ async def _wait_for_instance_provisioning_data(

 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
 def _check_instance_inner(
-    ports: Dict[int, int], *, check_instance_health: bool = False
+    ports: Dict[int, int], *, instance: InstanceModel, check_instance_health: bool = False
 ) -> InstanceCheck:
     instance_health_response: Optional[InstanceHealthResponse] = None
     shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
@@ -955,6 +958,10 @@
             args = (method.__func__.__name__, e.__class__.__name__, e)
             logger.exception(template, *args)
             return InstanceCheck(reachable=False, message=template % args)
+    try:
+        remove_dangling_tasks_from_instance(shim_client, instance)
+    except Exception as e:
+        logger.exception("%s: error removing dangling tasks: %s", fmt(instance), e)
     return runner_client.healthcheck_response_to_instance_check(
         healthcheck_response, instance_health_response
     )
diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py
index f71d60055c..6de49f35a1 100644
--- a/src/dstack/_internal/server/schemas/runner.py
+++ b/src/dstack/_internal/server/schemas/runner.py
@@ -159,6 +159,16 @@ class GPUDevice(CoreModel):
     path_in_container: str


+class TaskListItem(CoreModel):
+    id: str
+    status: TaskStatus
+
+
+class TaskListResponse(CoreModel):
+    ids: Optional[list[str]] = None  # returned by pre-0.19.26 shim
+    tasks: Optional[list[TaskListItem]] = None  # returned by 0.19.26+ shim
+
+
 class TaskInfoResponse(CoreModel):
     id: str
     status: TaskStatus
diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py
index cf53746d8b..79fadc2b9f 100644
--- a/src/dstack/_internal/server/services/instances.py
+++ b/src/dstack/_internal/server/services/instances.py
@@ -39,6 +39,7 @@
 from dstack._internal.core.models.runs import JobProvisioningData, Requirements
 from dstack._internal.core.models.volumes import Volume
 from dstack._internal.core.services.profiles import get_termination
+from dstack._internal.server import settings as server_settings
 from dstack._internal.server.models import (
     FleetModel,
     InstanceHealthCheckModel,
@@ -47,9 +48,11 @@
     UserModel,
 )
 from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse
-from dstack._internal.server.schemas.runner import InstanceHealthResponse
+from dstack._internal.server.schemas.runner import InstanceHealthResponse, TaskStatus
+from dstack._internal.server.services.logging import fmt
 from dstack._internal.server.services.offers import generate_shared_offer
 from dstack._internal.server.services.projects import list_user_project_models
+from dstack._internal.server.services.runner.client import ShimClient
 from dstack._internal.utils import common as common_utils
 from dstack._internal.utils.logging import get_logger

@@ -633,3 +636,40 @@ async def create_ssh_instance_model(
         busy_blocks=0,
     )
     return im
+
+
+def remove_dangling_tasks_from_instance(shim_client: ShimClient, instance: InstanceModel) -> None:
+    if not shim_client.is_api_v2_supported():
+        return
+    assigned_to_instance_job_ids = {str(j.id) for j in instance.jobs}
+    task_list_response = shim_client.list_tasks()
+    tasks: list[tuple[str, Optional[TaskStatus]]]
+    if task_list_response.tasks is not None:
+        tasks = [(t.id, t.status) for t in task_list_response.tasks]
+    elif task_list_response.ids is not None:
+        # compatibility with pre-0.19.26 shim
+        tasks = [(t_id, None) for t_id in task_list_response.ids]
+    else:
+        raise ValueError("Unexpected task list response, neither `tasks` nor `ids` is set")
+    for task_id, task_status in tasks:
+        if task_id in assigned_to_instance_job_ids:
+            continue
+        should_terminate = task_status != TaskStatus.TERMINATED
+        should_remove = not server_settings.SERVER_KEEP_SHIM_TASKS
+        if not (should_terminate or should_remove):
+            continue
+        logger.warning(
+            "%s: dangling task found, id=%s, status=%s. Terminating and/or removing",
+            fmt(instance),
+            task_id,
+            task_status or "",
+        )
+        if should_terminate:
+            shim_client.terminate_task(
+                task_id=task_id,
+                reason=None,
+                message=None,
+                timeout=0,
+            )
+        if should_remove:
+            shim_client.remove_task(task_id=task_id)
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py
index 0e379b4d99..ffea0c72ea 100644
--- a/src/dstack/_internal/server/services/jobs/__init__.py
+++ b/src/dstack/_internal/server/services/jobs/__init__.py
@@ -256,7 +256,16 @@ async def process_terminating_job(
     if jpd is not None:
         logger.debug("%s: stopping container", fmt(job_model))
         ssh_private_keys = get_instance_ssh_private_keys(instance_model)
-        await stop_container(job_model, jpd, ssh_private_keys)
+        if not await stop_container(job_model, jpd, ssh_private_keys):
+            # The dangling container can be removed later during instance processing
+            logger.warning(
+                (
+                    "%s: could not stop container, possibly due to a communication error."
+                    " See debug logs for details."
+                    " Ignoring, can attempt to remove the container later"
+                ),
+                fmt(job_model),
+            )
     if jrd is not None and jrd.volume_names is not None:
         volume_names = jrd.volume_names
     else:
@@ -378,21 +387,22 @@ async def stop_container(
     job_model: JobModel,
     job_provisioning_data: JobProvisioningData,
     ssh_private_keys: tuple[str, Optional[str]],
-):
+) -> bool:
     if job_provisioning_data.dockerized:
         # send a request to the shim to terminate the docker container
         # SSHError and RequestException are caught in the `runner_ssh_tunner` decorator
-        await run_async(
+        return await run_async(
             _shim_submit_stop,
             ssh_private_keys,
             job_provisioning_data,
             None,
             job_model,
         )
+    return True


 @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT])
-def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel):
+def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel) -> bool:
     shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])

     resp = shim_client.healthcheck()
@@ -418,6 +428,7 @@
         shim_client.remove_task(task_id=job_model.id)
     else:
         shim_client.stop(force=True)
+    return True


 def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, List[JobModel]]]:
diff --git a/src/dstack/_internal/server/services/logging.py b/src/dstack/_internal/server/services/logging.py
index f2d4666b64..c738534198 100644
--- a/src/dstack/_internal/server/services/logging.py
+++ b/src/dstack/_internal/server/services/logging.py
@@ -1,14 +1,22 @@
 from typing import Union

-from dstack._internal.server.models import GatewayModel, JobModel, ProbeModel, RunModel
+from dstack._internal.server.models import (
+    GatewayModel,
+    InstanceModel,
+    JobModel,
+    ProbeModel,
+    RunModel,
+)


-def fmt(model: Union[RunModel, JobModel, GatewayModel, ProbeModel]) -> str:
+def fmt(model: Union[RunModel, JobModel, InstanceModel, GatewayModel, ProbeModel]) -> str:
     """Consistent string representation of a model for logging."""
     if isinstance(model, RunModel):
         return f"run({model.id.hex[:6]}){model.run_name}"
     if isinstance(model, JobModel):
         return f"job({model.id.hex[:6]}){model.job_name}"
+    if isinstance(model, InstanceModel):
+        return f"instance({model.id.hex[:6]}){model.name}"
     if isinstance(model, GatewayModel):
         return f"gateway({model.id.hex[:6]}){model.name}"
     if isinstance(model, ProbeModel):
diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py
index 7b7c31dd8f..60f6c5d8c9 100644
--- a/src/dstack/_internal/server/services/runner/client.py
+++ b/src/dstack/_internal/server/services/runner/client.py
@@ -26,6 +26,7 @@
     ShimVolumeInfo,
     SubmitBody,
     TaskInfoResponse,
+    TaskListResponse,
     TaskSubmitRequest,
     TaskTerminateRequest,
 )
@@ -245,6 +246,12 @@ def get_instance_health(self) -> Optional[InstanceHealthResponse]:
         self._raise_for_status(resp)
         return self._response(InstanceHealthResponse, resp)

+    def list_tasks(self) -> TaskListResponse:
+        if not self.is_api_v2_supported():
+            raise ShimAPIVersionError()
+        resp = self._request("GET", "/api/tasks", raise_for_status=True)
+        return self._response(TaskListResponse, resp)
+
     def get_task(self, task_id: "_TaskID") -> TaskInfoResponse:
         if not self.is_api_v2_supported():
             raise ShimAPIVersionError()
diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/tasks/test_process_instances.py
index 990146f3b7..8255007073 100644
--- a/src/tests/_internal/server/background/tasks/test_process_instances.py
+++ b/src/tests/_internal/server/background/tasks/test_process_instances.py
@@ -1,8 +1,9 @@
 import datetime as dt
 from collections import defaultdict
+from collections.abc import Generator
 from contextlib import contextmanager
 from typing import Optional
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, call, patch

 import gpuhunt
 import pytest
@@ -41,7 +42,12 @@
 from dstack._internal.server.models import InstanceHealthCheckModel, PlacementGroupModel
 from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult
 from dstack._internal.server.schemas.instances import InstanceCheck
-from dstack._internal.server.schemas.runner import InstanceHealthResponse
+from dstack._internal.server.schemas.runner import (
+    InstanceHealthResponse,
+    TaskListItem,
+    TaskListResponse,
+    TaskStatus,
+)
 from dstack._internal.server.testing.common import (
     ComputeMockSpec,
     create_fleet,
@@ -377,6 +383,116 @@ async def test_check_shim_check_instance_health(self, test_db, session: AsyncSes
         assert health_check.response == health_response.json()


+class TestRemoveDanglingTasks:
+    @pytest.fixture
+    def ssh_tunnel_mock(self) -> Generator[Mock, None, None]:
+        with patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock:
+            yield SSHTunnelMock
+
+    @pytest.fixture
+    def shim_client_mock(self) -> Generator[Mock, None, None]:
+        with patch("dstack._internal.server.services.runner.client.ShimClient") as ShimClientMock:
+            yield ShimClientMock.return_value
+
+    @pytest.mark.asyncio
+    async def test_terminates_and_removes_dangling_tasks(
+        self, test_db, session: AsyncSession, ssh_tunnel_mock, shim_client_mock: Mock
+    ):
+        user = await create_user(session=session)
+        project = await create_project(session=session)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.BUSY,
+        )
+        repo = await create_repo(
+            session=session,
+            project_id=project.id,
+        )
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.RUNNING,
+            instance=instance,
+        )
+        dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727"
+        dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d"
+        shim_client_mock.list_tasks.return_value = TaskListResponse(
+            tasks=[
+                TaskListItem(id=str(job.id), status=TaskStatus.RUNNING),
+                TaskListItem(id=dangling_task_id_1, status=TaskStatus.RUNNING),
+                TaskListItem(id=dangling_task_id_2, status=TaskStatus.TERMINATED),
+            ]
+        )
+        await process_instances()
+
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.BUSY
+
+        shim_client_mock.terminate_task.assert_called_once_with(
+            task_id=dangling_task_id_1, reason=None, message=None, timeout=0
+        )
+        assert shim_client_mock.remove_task.call_count == 2
+        shim_client_mock.remove_task.assert_has_calls(
+            [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)]
+        )
+
+    @pytest.mark.asyncio
+    async def test_terminates_and_removes_dangling_tasks_legacy_shim(
+        self, test_db, session: AsyncSession, ssh_tunnel_mock, shim_client_mock: Mock
+    ):
+        user = await create_user(session=session)
+        project = await create_project(session=session)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.BUSY,
+        )
+        repo = await create_repo(
+            session=session,
+            project_id=project.id,
+        )
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.RUNNING,
+            instance=instance,
+        )
+        dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727"
+        dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d"
+        shim_client_mock.list_tasks.return_value = TaskListResponse(
+            ids=[str(job.id), dangling_task_id_1, dangling_task_id_2]
+        )
+        await process_instances()
+
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.BUSY
+
+        assert shim_client_mock.terminate_task.call_count == 2
+        shim_client_mock.terminate_task.assert_has_calls(
+            [
+                call(task_id=dangling_task_id_1, reason=None, message=None, timeout=0),
+                call(task_id=dangling_task_id_2, reason=None, message=None, timeout=0),
+            ]
+        )
+        assert shim_client_mock.remove_task.call_count == 2
+        shim_client_mock.remove_task.assert_has_calls(
+            [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)]
+        )
+
+
 class TestTerminateIdleTime:
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
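
Reviewer note, not part of the patch: the compatibility contract in remove_dangling_tasks_from_instance is easiest to see side by side. The sketch below, assuming the pydantic models behave as defined in this diff (TaskListResponse, TaskListItem, TaskStatus from src/dstack/_internal/server/schemas/runner.py), builds both response shapes the server may receive from GET /api/tasks and reduces them to the (task_id, status) pairs the cleanup loop iterates over; the UUID is the one used in the tests above.

# Sketch only: both shim generations reduced to (task_id, status) pairs.
from typing import Optional

from dstack._internal.server.schemas.runner import (
    TaskListItem,
    TaskListResponse,
    TaskStatus,
)

# 0.19.26+ shim: tasks carry statuses, so an already-terminated dangling
# task is only removed, not terminated again
new_style = TaskListResponse(
    tasks=[
        TaskListItem(
            id="fe138b77-d0b1-49d3-8c9f-2dfe78ece727",
            status=TaskStatus.TERMINATED,
        )
    ]
)
# pre-0.19.26 shim: bare IDs, statuses unknown
legacy = TaskListResponse(ids=["fe138b77-d0b1-49d3-8c9f-2dfe78ece727"])

for resp in (new_style, legacy):
    pairs: list[tuple[str, Optional[TaskStatus]]]
    if resp.tasks is not None:
        pairs = [(t.id, t.status) for t in resp.tasks]
    elif resp.ids is not None:
        # a None status never equals TaskStatus.TERMINATED, so every dangling
        # task from a legacy shim is terminated before removal, which is
        # exactly what the legacy-shim test asserts
        pairs = [(t_id, None) for t_id in resp.ids]
    print(pairs)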