Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions runner/docs/shim.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,21 @@ components:
title: shim.api.TaskListResponse
type: object
properties:
ids:
tasks:
type: array
items:
$ref: "#/components/schemas/TaskID"
description: A list of all task IDs tracked by shim
type: object
properties:
id:
$ref: "#/components/schemas/TaskID"
status:
$ref: "#/components/schemas/TaskStatus"
required:
- id
- status
description: A list of all tasks tracked by shim, each with its ID and status
required:
- ids
- tasks
additionalProperties: false

TaskInfoResponse:
Expand Down
4 changes: 2 additions & 2 deletions runner/internal/shim/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func (ds *DummyRunner) Remove(context.Context, string) error {
return nil
}

func (ds *DummyRunner) TaskIDs() []string {
return []string{}
func (ds *DummyRunner) TaskList() []*shim.TaskListItem {
return []*shim.TaskListItem{}
}

func (ds *DummyRunner) TaskInfo(taskID string) shim.TaskInfo {
Expand Down
3 changes: 2 additions & 1 deletion runner/internal/shim/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Reques
}

// TaskListHandler serves GET /api/tasks, returning the ID and status of
// every task tracked by the runner.
func (s *ShimServer) TaskListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
	// Use a keyed literal: an unkeyed composite literal silently breaks
	// if fields are ever added to or reordered in TaskListResponse.
	return &TaskListResponse{Tasks: s.runner.TaskList()}, nil
}

func (s *ShimServer) TaskInfoHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
Expand Down
2 changes: 1 addition & 1 deletion runner/internal/shim/api/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type InstanceHealthResponse struct {
}

type TaskListResponse struct {
IDs []string `json:"ids"`
Tasks []*shim.TaskListItem `json:"tasks"`
}

type TaskInfoResponse struct {
Expand Down
2 changes: 1 addition & 1 deletion runner/internal/shim/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type TaskRunner interface {
Remove(ctx context.Context, taskID string) error

Resources(context.Context) shim.Resources
TaskIDs() []string
TaskList() []*shim.TaskListItem
TaskInfo(taskID string) shim.TaskInfo
}

Expand Down
9 changes: 7 additions & 2 deletions runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,13 @@ func (d *DockerRunner) Resources(ctx context.Context) Resources {
}
}

func (d *DockerRunner) TaskIDs() []string {
return d.tasks.IDs()
func (d *DockerRunner) TaskList() []*TaskListItem {
tasks := d.tasks.List()
result := make([]*TaskListItem, 0, len(tasks))
for _, task := range tasks {
result = append(result, &TaskListItem{ID: task.ID, Status: task.Status})
}
return result
}

func (d *DockerRunner) TaskInfo(taskID string) TaskInfo {
Expand Down
5 changes: 5 additions & 0 deletions runner/internal/shim/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ type TaskConfig struct {
ContainerSshKeys []string `json:"container_ssh_keys"`
}

// TaskListItem is one entry of the task list API response: a task's ID
// paired with its last observed status. The serialized form must match
// the task item schema in shim.openapi.yaml.
type TaskListItem struct {
	ID     string     `json:"id"`
	Status TaskStatus `json:"status"`
}

type TaskInfo struct {
ID string
Status TaskStatus
Expand Down
11 changes: 6 additions & 5 deletions runner/internal/shim/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,15 @@ type TaskStorage struct {
mu sync.RWMutex
}

func (ts *TaskStorage) IDs() []string {
// Get a _copy_ of all tasks. To "commit" changes, use Update()
func (ts *TaskStorage) List() []Task {
ts.mu.RLock()
defer ts.mu.RUnlock()
ids := make([]string, 0, len(ts.tasks))
for id := range ts.tasks {
ids = append(ids, id)
tasks := make([]Task, 0, len(ts.tasks))
for _, task := range ts.tasks {
tasks = append(tasks, task)
}
return ids
return tasks
}

// Get a _copy_ of the task. To "commit" changes, use Update()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@
get_instance_provisioning_data,
get_instance_requirements,
get_instance_ssh_private_keys,
remove_dangling_tasks_from_instance,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.logging import fmt
from dstack._internal.server.services.offers import is_divisible_into_blocks
from dstack._internal.server.services.placement import (
get_fleet_placement_group_models,
Expand Down Expand Up @@ -789,6 +791,7 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non
ssh_private_keys,
job_provisioning_data,
None,
instance=instance,
check_instance_health=check_instance_health,
)
if instance_check is False:
Expand Down Expand Up @@ -935,7 +938,7 @@ async def _wait_for_instance_provisioning_data(

@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
def _check_instance_inner(
ports: Dict[int, int], *, check_instance_health: bool = False
ports: Dict[int, int], *, instance: InstanceModel, check_instance_health: bool = False
) -> InstanceCheck:
instance_health_response: Optional[InstanceHealthResponse] = None
shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
Expand All @@ -955,6 +958,10 @@ def _check_instance_inner(
args = (method.__func__.__name__, e.__class__.__name__, e)
logger.exception(template, *args)
return InstanceCheck(reachable=False, message=template % args)
try:
remove_dangling_tasks_from_instance(shim_client, instance)
except Exception as e:
logger.exception("%s: error removing dangling tasks: %s", fmt(instance), e)
return runner_client.healthcheck_response_to_instance_check(
healthcheck_response, instance_health_response
)
Expand Down
10 changes: 10 additions & 0 deletions src/dstack/_internal/server/schemas/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ class GPUDevice(CoreModel):
path_in_container: str


class TaskListItem(CoreModel):
    """One task tracked by the shim: its ID and current status (0.19.26+ shims)."""

    id: str
    status: TaskStatus


class TaskListResponse(CoreModel):
    """Response body of the shim `GET /api/tasks` endpoint.

    Exactly one of the two fields is expected to be set, depending on the
    shim version that produced the response.
    """

    ids: Optional[list[str]] = None  # returned by pre-0.19.26 shim
    tasks: Optional[list[TaskListItem]] = None  # returned by 0.19.26+ shim


class TaskInfoResponse(CoreModel):
id: str
status: TaskStatus
Expand Down
42 changes: 41 additions & 1 deletion src/dstack/_internal/server/services/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.core.models.volumes import Volume
from dstack._internal.core.services.profiles import get_termination
from dstack._internal.server import settings as server_settings
from dstack._internal.server.models import (
FleetModel,
InstanceHealthCheckModel,
Expand All @@ -47,9 +48,11 @@
UserModel,
)
from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse
from dstack._internal.server.schemas.runner import InstanceHealthResponse
from dstack._internal.server.schemas.runner import InstanceHealthResponse, TaskStatus
from dstack._internal.server.services.logging import fmt
from dstack._internal.server.services.offers import generate_shared_offer
from dstack._internal.server.services.projects import list_user_project_models
from dstack._internal.server.services.runner.client import ShimClient
from dstack._internal.utils import common as common_utils
from dstack._internal.utils.logging import get_logger

Expand Down Expand Up @@ -633,3 +636,40 @@ async def create_ssh_instance_model(
busy_blocks=0,
)
return im


def remove_dangling_tasks_from_instance(shim_client: ShimClient, instance: InstanceModel) -> None:
    """Terminate and/or remove shim tasks that belong to none of the instance's jobs.

    No-op for shims that do not support API v2. Whether dangling tasks are
    removed in addition to being terminated is controlled by the
    SERVER_KEEP_SHIM_TASKS server setting.
    """
    if not shim_client.is_api_v2_supported():
        return
    known_job_ids = {str(job.id) for job in instance.jobs}
    response = shim_client.list_tasks()
    tasks: list[tuple[str, Optional[TaskStatus]]]
    if response.tasks is not None:
        tasks = [(task.id, task.status) for task in response.tasks]
    elif response.ids is not None:
        # compatibility with pre-0.19.26 shim
        tasks = [(task_id, None) for task_id in response.ids]
    else:
        raise ValueError("Unexpected task list response, neither `tasks` nor `ids` is set")
    # Removal policy is a server-wide setting, constant across the loop.
    remove_allowed = not server_settings.SERVER_KEEP_SHIM_TASKS
    for task_id, status in tasks:
        if task_id in known_job_ids:
            continue
        # An unknown status (pre-0.19.26 shim) is treated as not terminated.
        needs_terminate = status != TaskStatus.TERMINATED
        if not (needs_terminate or remove_allowed):
            continue
        logger.warning(
            "%s: dangling task found, id=%s, status=%s. Terminating and/or removing",
            fmt(instance),
            task_id,
            status or "<unknown>",
        )
        if needs_terminate:
            shim_client.terminate_task(
                task_id=task_id,
                reason=None,
                message=None,
                timeout=0,
            )
        if remove_allowed:
            shim_client.remove_task(task_id=task_id)
19 changes: 15 additions & 4 deletions src/dstack/_internal/server/services/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,16 @@ async def process_terminating_job(
if jpd is not None:
logger.debug("%s: stopping container", fmt(job_model))
ssh_private_keys = get_instance_ssh_private_keys(instance_model)
await stop_container(job_model, jpd, ssh_private_keys)
if not await stop_container(job_model, jpd, ssh_private_keys):
# The dangling container can be removed later during instance processing
logger.warning(
(
"%s: could not stop container, possibly due to a communication error."
" See debug logs for details."
" Ignoring, can attempt to remove the container later"
),
fmt(job_model),
)
if jrd is not None and jrd.volume_names is not None:
volume_names = jrd.volume_names
else:
Expand Down Expand Up @@ -378,21 +387,22 @@ async def stop_container(
job_model: JobModel,
job_provisioning_data: JobProvisioningData,
ssh_private_keys: tuple[str, Optional[str]],
):
) -> bool:
if job_provisioning_data.dockerized:
# send a request to the shim to terminate the docker container
# SSHError and RequestException are caught in the `runner_ssh_tunner` decorator
await run_async(
return await run_async(
_shim_submit_stop,
ssh_private_keys,
job_provisioning_data,
None,
job_model,
)
return True


@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT])
def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel):
def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel) -> bool:
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])

resp = shim_client.healthcheck()
Expand All @@ -418,6 +428,7 @@ def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel):
shim_client.remove_task(task_id=job_model.id)
else:
shim_client.stop(force=True)
return True


def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, List[JobModel]]]:
Expand Down
12 changes: 10 additions & 2 deletions src/dstack/_internal/server/services/logging.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from typing import Union

from dstack._internal.server.models import GatewayModel, JobModel, ProbeModel, RunModel
from dstack._internal.server.models import (
GatewayModel,
InstanceModel,
JobModel,
ProbeModel,
RunModel,
)


def fmt(model: Union[RunModel, JobModel, GatewayModel, ProbeModel]) -> str:
def fmt(model: Union[RunModel, JobModel, InstanceModel, GatewayModel, ProbeModel]) -> str:
"""Consistent string representation of a model for logging."""
if isinstance(model, RunModel):
return f"run({model.id.hex[:6]}){model.run_name}"
if isinstance(model, JobModel):
return f"job({model.id.hex[:6]}){model.job_name}"
if isinstance(model, InstanceModel):
return f"instance({model.id.hex[:6]}){model.name}"
if isinstance(model, GatewayModel):
return f"gateway({model.id.hex[:6]}){model.name}"
if isinstance(model, ProbeModel):
Expand Down
7 changes: 7 additions & 0 deletions src/dstack/_internal/server/services/runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ShimVolumeInfo,
SubmitBody,
TaskInfoResponse,
TaskListResponse,
TaskSubmitRequest,
TaskTerminateRequest,
)
Expand Down Expand Up @@ -245,6 +246,12 @@ def get_instance_health(self) -> Optional[InstanceHealthResponse]:
self._raise_for_status(resp)
return self._response(InstanceHealthResponse, resp)

def list_tasks(self) -> TaskListResponse:
    """Fetch all tasks tracked by the shim via `GET /api/tasks`.

    Raises ShimAPIVersionError for shims without API v2 support.
    """
    if not self.is_api_v2_supported():
        raise ShimAPIVersionError()
    raw = self._request("GET", "/api/tasks", raise_for_status=True)
    return self._response(TaskListResponse, raw)

def get_task(self, task_id: "_TaskID") -> TaskInfoResponse:
if not self.is_api_v2_supported():
raise ShimAPIVersionError()
Expand Down
Loading
Loading