From f7ed2f663adb91f068bb6546929c339a18cebc45 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Wed, 6 Aug 2025 16:59:27 +0000 Subject: [PATCH 1/5] Add NVIDIA GPU passive health checks * shim: Integrate libdcgm, add a new endpoint returning overall GPU health with list of incidents. * Periodically pull instance health from shim, store the raw response in a new DB table. Infer overall instance health and store it in a new column of the "instances" table. * Don't consider failed instances for submitted jobs. Note: instances with warnings are still considered for jobs. * API: add a new method returning a list of instance health checks with unified structure. * CLI: display "warning" and "failure" health statuses in the same way as "unreachable", below the instance status. Closes: https://github.com/dstackai/dstack/issues/2930 --- docs/docs/reference/environment-variables.md | 4 + runner/cmd/shim/main.go | 31 +++- runner/docs/shim.openapi.yaml | 97 +++++++++++- runner/go.mod | 2 + runner/go.sum | 4 + runner/internal/shim/api/handlers.go | 14 ++ runner/internal/shim/api/handlers_test.go | 4 +- runner/internal/shim/api/schemas.go | 9 +- runner/internal/shim/api/server.go | 8 +- runner/internal/shim/dcgm/wrapper.go | 117 ++++++++++++++ runner/internal/shim/dcgm/wrapper_test.go | 80 ++++++++++ runner/internal/shim/models.go | 4 + src/dstack/_internal/cli/utils/fleet.py | 10 +- .../core/backends/remote/provisioning.py | 45 +++--- src/dstack/_internal/core/models/health.py | 28 ++++ src/dstack/_internal/core/models/instances.py | 2 + src/dstack/_internal/server/app.py | 1 + .../_internal/server/background/__init__.py | 2 + .../background/tasks/process_instances.py | 137 ++++++++++------ .../728b1488b1b4_add_instance_health.py | 50 ++++++ src/dstack/_internal/server/models.py | 20 +++ .../_internal/server/routers/instances.py | 38 ++++- .../server/schemas/health/__init__.py | 0 .../_internal/server/schemas/health/dcgm.py | 56 +++++++ .../_internal/server/schemas/instances.py 
| 32 ++++ src/dstack/_internal/server/schemas/runner.py | 5 + .../_internal/server/services/instances.py | 104 ++++++++++++- .../server/services/runner/client.py | 72 ++++++--- .../_internal/server/services/runner/ssh.py | 8 +- src/dstack/_internal/server/settings.py | 7 + src/dstack/_internal/server/testing/common.py | 22 +++ .../tasks/test_process_instances.py | 147 ++++++++++++++---- .../tasks/test_process_submitted_jobs.py | 51 ++++++ .../_internal/server/routers/test_fleets.py | 4 + .../server/routers/test_instances.py | 107 +++++++++++++ .../server/services/test_instances.py | 4 + 36 files changed, 1192 insertions(+), 134 deletions(-) create mode 100644 runner/internal/shim/dcgm/wrapper.go create mode 100644 runner/internal/shim/dcgm/wrapper_test.go create mode 100644 src/dstack/_internal/core/models/health.py create mode 100644 src/dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py create mode 100644 src/dstack/_internal/server/schemas/health/__init__.py create mode 100644 src/dstack/_internal/server/schemas/health/dcgm.py diff --git a/docs/docs/reference/environment-variables.md b/docs/docs/reference/environment-variables.md index e98bf83617..fcb7777087 100644 --- a/docs/docs/reference/environment-variables.md +++ b/docs/docs/reference/environment-variables.md @@ -126,6 +126,10 @@ For more details on the options below, refer to the [server deployment](../guide - `DSTACK_SERVER_BACKGROUND_PROCESSING_DISABLED`{ #DSTACK_SERVER_BACKGROUND_PROCESSING_DISABLED } - Disables background processing if set to any value. Useful to run only web frontend and API server. - `DSTACK_SERVER_MAX_PROBES_PER_JOB`{ #DSTACK_SERVER_MAX_PROBES_PER_JOB } - Maximum number of probes allowed in a run configuration. Validated at apply time. - `DSTACK_SERVER_MAX_PROBE_TIMEOUT`{ #DSTACK_SERVER_MAX_PROBE_TIMEOUT } - Maximum allowed timeout for a probe. Validated at apply time. 
+- `DSTACK_SERVER_METRICS_RUNNING_TTL_SECONDS`{ #DSTACK_SERVER_METRICS_RUNNING_TTL_SECONDS } – Maximum age of metrics samples for running jobs. +- `DSTACK_SERVER_METRICS_FINISHED_TTL_SECONDS`{ #DSTACK_SERVER_METRICS_FINISHED_TTL_SECONDS } – Maximum age of metrics samples for finished jobs. +- `DSTACK_SERVER_INSTANCE_HEALTH_TTL_SECONDS`{ #DSTACK_SERVER_INSTANCE_HEALTH_TTL_SECONDS } – Maximum age of instance health checks. +- `DSTACK_SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS`{ #DSTACK_SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS } – Minimum time interval between consecutive health checks of the same instance. ??? info "Internal environment variables" The following environment variables are intended for development purposes: diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 47cc8e9332..ba6249490d 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -112,6 +112,14 @@ func main() { Destination: &args.DCGMExporter.Interval, EnvVars: []string{"DSTACK_DCGM_EXPORTER_INTERVAL"}, }, + /* DCGM Parameters */ + &cli.StringFlag{ + Name: "dcgm-address", + Usage: "nv-hostengine `hostname`, e.g., `localhost`", + DefaultText: "start libdcgm in embedded mode", + Destination: &args.DCGM.Address, + EnvVars: []string{"DSTACK_DCGM_ADDRESS"}, + }, /* Docker Parameters */ &cli.BoolFlag{ Name: "privileged", @@ -196,6 +204,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) } var dcgmExporter *dcgm.DCGMExporter + var dcgmWrapper *dcgm.DCGMWrapper if common.GetGpuVendor() == common.GpuVendorNvidia { dcgmExporterPath, err := dcgm.GetDCGMExporterExecPath(ctx) @@ -207,16 +216,32 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) if err == nil { log.Info(ctx, "using DCGM Exporter") defer func() { - _ = dcgmExporter.Stop(ctx) + if err := dcgmExporter.Stop(ctx); err != nil { + log.Error(ctx, "failed to stop DCGM Exporter", "err", err) + } }() } else { log.Warning(ctx, "not using 
DCGM Exporter", "err", err) - dcgmExporter = nil + } + + dcgmWrapper, err = dcgm.NewDCGMWrapper(args.DCGM.Address) + if err == nil { + log.Info(ctx, "using libdcgm") + defer func() { + if err := dcgmWrapper.Shutdown(); err != nil { + log.Error(ctx, "failed to shut down libdcgm", "err", err) + } + }() + if err := dcgmWrapper.EnableHealthChecks(); err != nil { + log.Error(ctx, "failed to enable libdcgm health checks", "err", err) + } + } else { + log.Warning(ctx, "not using libdcgm", "err", err) } } address := fmt.Sprintf(":%d", args.Shim.HTTPPort) - shimServer := api.NewShimServer(ctx, address, dockerRunner, dcgmExporter, Version) + shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper) defer func() { shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second) diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml index bc199181be..061da5d7ed 100644 --- a/runner/docs/shim.openapi.yaml +++ b/runner/docs/shim.openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.1 info: title: dstack-shim API - version: v2/0.18.34 + version: v2/0.19.22 x-logo: url: https://avatars.githubusercontent.com/u/54146142?s=260 description: > @@ -50,10 +50,25 @@ paths: schema: $ref: "#/components/schemas/HealthcheckResponse" + /instance/health: + get: + summary: Get instance health + + description: (since [0.19.22](https://github.com/dstackai/dstack/releases/tag/0.19.22)) Returns an object of optional passive system checks + tags: [Instance] + responses: + "200": + description: "" + content: + application/json: + schema: + $ref: "#/components/schemas/InstanceHealthResponse" + /tasks: get: summary: Get task list description: Returns a list of all tasks known to shim, including terminated ones + tags: [Tasks] responses: "200": description: "" @@ -63,6 +78,7 @@ paths: $ref: "#/components/schemas/TaskListResponse" post: summary: Submit and run new task + tags: [Tasks] requestBody: required: true content: @@ -86,6 +102,7 @@ paths: 
/tasks/{id}: get: summary: Get task info + tags: [Tasks] parameters: - $ref: "#/parameters/taskId" responses: @@ -102,6 +119,7 @@ paths: Stops the task, that is, cancels image pulling if in progress, stops the container if running, and sets the status to `terminated`. No-op if the task is already terminated + tags: [Tasks] parameters: - in: path name: id @@ -131,6 +149,7 @@ paths: description: > Removes the task from in-memory storage and destroys its associated resources: a container, logs, etc. + tags: [Tasks] parameters: - $ref: "#/parameters/taskId" responses: @@ -270,7 +289,7 @@ components: type: string default: "" description: Mount point inside container - + GPUDevice: title: shim.GPUDevice type: object @@ -284,6 +303,72 @@ components: default: "" description: Path inside container + DCGMHealth: + title: shim.dcgm.Health + type: object + properties: + overall_health: + type: integer + description: > + [dcgmHealthWatchResult_enum](https://docs.nvidia.com/datacenter/dcgm/latest/dcgm-api/dcgm-api-data-structs.html#_CPPv426dcgmHealthWatchResult_enum) + examples: + - 10 + incidents: + type: array + items: + $ref: "#/components/schemas/DCGMHealthIncident" + required: + - overall_health + - incidents + additionalProperties: false + + DCGMHealthIncident: + title: shim.dcgm.HealthIncident + type: object + properties: + system: + type: integer + description: > + [dcgmHealthSystems_enum](https://docs.nvidia.com/datacenter/dcgm/latest/dcgm-api/dcgm-api-data-structs.html#_CPPv422dcgmHealthSystems_enum) + examples: + - 1 + health: + type: integer + description: > + [dcgmHealthWatchResult_enum](https://docs.nvidia.com/datacenter/dcgm/latest/dcgm-api/dcgm-api-data-structs.html#_CPPv426dcgmHealthWatchResult_enum) + examples: + - 10 + error_message: + type: string + examples: + - > + Detected more than 16 PCIe replays per minute for GPU 0 : 99 Reconnect PCIe card. + Run system side PCIE diagnostic utilities to verify hops off the GPU board. 
If issue is on the board, run the field diagnostic. + error_code: + type: integer + description: > + [dcgmError_enum](https://github.com/NVIDIA/DCGM/blob/master/dcgmlib/dcgm_errors.h) + examples: + - 3 + entity_group_id: + type: integer + description: > + [dcgm_field_entity_group_t](https://docs.nvidia.com/datacenter/dcgm/latest/dcgm-api/dcgm-api-field-entity.html#_CPPv425dcgm_field_entity_group_t) + examples: + - 1 + entity_id: + type: integer + examples: + - 0 + required: + - system + - health + - error_message + - error_code + - entity_group_id + - entity_id + additionalProperties: false + HealthcheckResponse: title: shim.api.HealthcheckResponse type: object @@ -299,6 +384,14 @@ components: - version additionalProperties: false + InstanceHealthResponse: + title: shim.api.InstanceHealthResponse + type: object + properties: + dcgm: + $ref: "#/components/schemas/DCGMHealth" + additionalProperties: false + TaskListResponse: title: shim.api.TaskListResponse type: object diff --git a/runner/go.mod b/runner/go.mod index 850ea82530..8af11d34dc 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -3,6 +3,7 @@ module github.com/dstackai/dstack/runner go 1.23.8 require ( + github.com/NVIDIA/go-dcgm v0.0.0-20250707210631-823394f2bd9b github.com/alexellis/go-execute/v2 v2.2.1 github.com/bluekeyes/go-gitdiff v0.7.2 github.com/codeclysm/extract/v4 v4.0.0 @@ -29,6 +30,7 @@ require ( dario.cat/mergo v1.0.0 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/ProtonMail/go-crypto v1.0.0 // indirect + github.com/bits-and-blooms/bitset v1.22.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect diff --git a/runner/go.sum b/runner/go.sum index 1222fcac83..adb3d96a07 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -7,6 +7,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/Microsoft/go-winio v0.5.2/go.mod 
h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= +github.com/NVIDIA/go-dcgm v0.0.0-20250707210631-823394f2bd9b h1:FL0NJYUNMX1ezl2Dv0azgedHPBXDuqHnqGDtqj6aqZM= +github.com/NVIDIA/go-dcgm v0.0.0-20250707210631-823394f2bd9b/go.mod h1:cA0Bv7+JtAd8sqCCZizhAQjj4+Z47x/d8KD60iYBT+g= github.com/ProtonMail/go-crypto v1.0.0 h1:LRuvITjQWX+WIfr930YHG2HNfjR1uOfyf5vE0kC2U78= github.com/ProtonMail/go-crypto v1.0.0/go.mod h1:EjAoLdwvbIOoOQr3ihjnSoLZRtE8azugULFRteWMNc0= github.com/alexellis/go-execute/v2 v2.2.1 h1:4Ye3jiCKQarstODOEmqDSRCqxMHLkC92Bhse743RdOI= @@ -17,6 +19,8 @@ github.com/arduino/go-paths-helper v1.12.1 h1:WkxiVUxBjKWlLMiMuYy8DcmVrkxdP7aKxQ github.com/arduino/go-paths-helper v1.12.1/go.mod h1:jcpW4wr0u69GlXhTYydsdsqAjLaYK5n7oWHfKqOG6LM= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/bits-and-blooms/bitset v1.22.0 h1:Tquv9S8+SGaS3EhyA+up3FXzmkhxPGjQQCkcs2uw7w4= +github.com/bits-and-blooms/bitset v1.22.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bluekeyes/go-gitdiff v0.7.2 h1:42jrcVZdjjxXtVsFNYTo/I6T1ZvIiQL+iDDLiH904hw= github.com/bluekeyes/go-gitdiff v0.7.2/go.mod h1:QpfYYO1E0fTVHVZAZKiRjtSGY9823iCdvGXBcEzHGbM= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= diff --git a/runner/internal/shim/api/handlers.go b/runner/internal/shim/api/handlers.go index 0a0af9b39b..1374fbd803 100644 --- a/runner/internal/shim/api/handlers.go +++ b/runner/internal/shim/api/handlers.go @@ -21,6 +21,20 @@ func (s *ShimServer) HealthcheckHandler(w http.ResponseWriter, r *http.Request) }, nil } +func (s *ShimServer) 
InstanceHealthHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { + ctx := r.Context() + response := InstanceHealthResponse{} + if s.dcgmWrapper != nil { + if dcgmHealth, err := s.dcgmWrapper.GetHealth(); err != nil { + log.Error(ctx, "failed to get health from DCGM", "err", err) + } else { + response.DCGM = &dcgmHealth + } + } + + return &response, nil +} + func (s *ShimServer) TaskListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { return &TaskListResponse{IDs: s.runner.TaskIDs()}, nil } diff --git a/runner/internal/shim/api/handlers_test.go b/runner/internal/shim/api/handlers_test.go index 98e129fc1a..c640fdb731 100644 --- a/runner/internal/shim/api/handlers_test.go +++ b/runner/internal/shim/api/handlers_test.go @@ -13,7 +13,7 @@ func TestHealthcheck(t *testing.T) { request := httptest.NewRequest("GET", "/api/healthcheck", nil) responseRecorder := httptest.NewRecorder() - server := NewShimServer(context.Background(), ":12345", NewDummyRunner(), nil, "0.0.1.dev2") + server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil) f := common.JSONResponseHandler(server.HealthcheckHandler) f(responseRecorder, request) @@ -30,7 +30,7 @@ func TestHealthcheck(t *testing.T) { } func TestTaskSubmit(t *testing.T) { - server := NewShimServer(context.Background(), ":12340", NewDummyRunner(), nil, "0.0.1.dev2") + server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil) requestBody := `{ "id": "dummy-id", "name": "dummy-name", diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 327db2f443..7f004a4046 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -1,12 +1,19 @@ package api -import "github.com/dstackai/dstack/runner/internal/shim" +import ( + "github.com/dstackai/dstack/runner/internal/shim" + "github.com/dstackai/dstack/runner/internal/shim/dcgm" +) type HealthcheckResponse 
struct { Service string `json:"service"` Version string `json:"version"` } +type InstanceHealthResponse struct { + DCGM *dcgm.Health `json:"dcgm"` +} + type TaskListResponse struct { IDs []string `json:"ids"` } diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 94a693b4f0..62052e6b8f 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -29,11 +29,15 @@ type ShimServer struct { runner TaskRunner dcgmExporter *dcgm.DCGMExporter + dcgmWrapper *dcgm.DCGMWrapper version string } -func NewShimServer(ctx context.Context, address string, runner TaskRunner, dcgmExporter *dcgm.DCGMExporter, version string) *ShimServer { +func NewShimServer( + ctx context.Context, address string, version string, + runner TaskRunner, dcgmExporter *dcgm.DCGMExporter, dcgmWrapper *dcgm.DCGMWrapper, +) *ShimServer { r := api.NewRouter() s := &ShimServer{ HttpServer: &http.Server{ @@ -45,12 +49,14 @@ func NewShimServer(ctx context.Context, address string, runner TaskRunner, dcgmE runner: runner, dcgmExporter: dcgmExporter, + dcgmWrapper: dcgmWrapper, version: version, } // The healthcheck endpoint should stay backward compatible, as it is used for negotiation r.AddHandler("GET", "/api/healthcheck", s.HealthcheckHandler) + r.AddHandler("GET", "/api/instance/health", s.InstanceHealthHandler) r.AddHandler("GET", "/api/tasks", s.TaskListHandler) r.AddHandler("GET", "/api/tasks/{id}", s.TaskInfoHandler) r.AddHandler("POST", "/api/tasks", s.TaskSubmitHandler) diff --git a/runner/internal/shim/dcgm/wrapper.go b/runner/internal/shim/dcgm/wrapper.go new file mode 100644 index 0000000000..749060128c --- /dev/null +++ b/runner/internal/shim/dcgm/wrapper.go @@ -0,0 +1,117 @@ +package dcgm + +import ( + "errors" + "fmt" + "sync" + + godcgm "github.com/NVIDIA/go-dcgm/pkg/dcgm" +) + +type HealthStatus string + +const ( + HealthStatusHealthy HealthStatus = "healthy" + HealthStatusWarning HealthStatus = "warning" + HealthStatusFailure 
HealthStatus = "failure" +) + +type HealthIncident struct { + System int `json:"system"` + Health int `json:"health"` + ErrorMessage string `json:"error_message"` + ErrorCode int `json:"error_code"` + EntityGroupID int `json:"entity_group_id"` + EntityID int `json:"entity_id"` +} + +type Health struct { + OverallHealth int `json:"overall_health"` + Incidents []HealthIncident `json:"incidents"` +} + +// DCGMWrapper is a wrapper around go-dcgm (which, in turn, is a wrapper around libdcgm.so) +type DCGMWrapper struct { + group godcgm.GroupHandle + healthCheckEnabled bool + + mu *sync.Mutex +} + +// NewDCGMWrapper initializes and starts DCGM in the specific mode: +// - If address is empty, then libdcgm starts embedded hostengine within the current process. +// This is the main mode. +// - If address is not empty, then libdcgm connects to already running nv-hostengine service via TCP. +// This mode is useful for debugging, e.g., one can start nv-hostengine via systemd and inject +// errors via dcgmi: +// - systemctl start nvidia-dcgm.service +// - dcgmi test --inject --gpuid 0 -f 202 -v 99999 +// +// Note: embedded hostengine is started in AUTO operation mode, which means that +// the library handles periodic tasks by itself executing them in additional threads. 
+func NewDCGMWrapper(address string) (*DCGMWrapper, error) { + var err error + if address == "" { + _, err = godcgm.Init(godcgm.Embedded) + } else { + // "address is a unix socket filename (1) or a TCP/IP address (0)" + _, err = godcgm.Init(godcgm.Standalone, address, "0") + } + if err != nil { + return nil, fmt.Errorf("failed to initialize or start DCGM: %w", err) + } + return &DCGMWrapper{ + group: godcgm.GroupAllGPUs(), + mu: new(sync.Mutex), + }, nil +} + +func (w *DCGMWrapper) Shutdown() error { + if err := godcgm.Shutdown(); err != nil { + return fmt.Errorf("failed to shut down DCGM: %w", err) + } + return nil +} + +func (w *DCGMWrapper) EnableHealthChecks() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.healthCheckEnabled { + return errors.New("health check system already enabled") + } + if err := godcgm.HealthSet(w.group, godcgm.DCGM_HEALTH_WATCH_ALL); err != nil { + return fmt.Errorf("failed to configure health watches: %w", err) + } + // "On the first call, stateful information about all of the enabled watches within a group + // is created but no error results are provided. On subsequent calls, any error information + // will be returned." 
+ if _, err := godcgm.HealthCheck(w.group); err != nil { + return fmt.Errorf("failed to initialize health watches state: %w", err) + } + w.healthCheckEnabled = true + return nil +} + +func (w *DCGMWrapper) GetHealth() (Health, error) { + health := Health{} + if !w.healthCheckEnabled { + return health, errors.New("health check system is not enabled") + } + response, err := godcgm.HealthCheck(w.group) + if err != nil { + return health, fmt.Errorf("failed to fetch health status: %w", err) + } + health.OverallHealth = int(response.OverallHealth) + health.Incidents = make([]HealthIncident, 0, len(response.Incidents)) + for _, incident := range response.Incidents { + health.Incidents = append(health.Incidents, HealthIncident{ + System: int(incident.System), + Health: int(incident.Health), + ErrorMessage: incident.Error.Message, + ErrorCode: int(incident.Error.Code), + EntityGroupID: int(incident.EntityInfo.EntityGroupId), + EntityID: int(incident.EntityInfo.EntityId), + }) + } + return health, nil +} diff --git a/runner/internal/shim/dcgm/wrapper_test.go b/runner/internal/shim/dcgm/wrapper_test.go new file mode 100644 index 0000000000..93c5b12ef6 --- /dev/null +++ b/runner/internal/shim/dcgm/wrapper_test.go @@ -0,0 +1,80 @@ +package dcgm + +import ( + "strings" + "testing" + "time" + + godcgm "github.com/NVIDIA/go-dcgm/pkg/dcgm" + "github.com/stretchr/testify/require" +) + +func TestDCGMWrapperGetHealth(t *testing.T) { + dcgmw := getDCGMWrapper(t) + defer dcgmw.Shutdown() + + gpuID := getGpuID(t) + + err := dcgmw.EnableHealthChecks() + require.NoError(t, err) + + health, err := dcgmw.GetHealth() + require.NoError(t, err) + require.Equal(t, health.OverallHealth, 0) // DCGM_HEALTH_RESULT_PASS + require.Len(t, health.Incidents, 0) + + injectError(t, gpuID, godcgm.DCGM_FI_DEV_ECC_DBE_VOL_TOTAL, godcgm.DCGM_FT_INT64, int64(888)) + injectError(t, gpuID, godcgm.DCGM_FI_DEV_PCIE_REPLAY_COUNTER, godcgm.DCGM_FT_INT64, int64(999)) + + health, err = dcgmw.GetHealth() + 
require.NoError(t, err) + require.Equal(t, health.OverallHealth, 20) // DCGM_HEALTH_RESULT_FAIL + require.Len(t, health.Incidents, 2) + for _, incident := range health.Incidents { + switch incident.System { + case 0x1: // DCGM_HEALTH_WATCH_PCIE + require.Equal(t, incident.Health, 10) // DCGM_HEALTH_RESULT_WARN + require.Contains(t, incident.ErrorMessage, "PCIe replay") + case 0x10: // DCGM_HEALTH_WATCH_MEM + require.Equal(t, incident.Health, 20) // DCGM_HEALTH_RESULT_FAIL + require.Contains(t, incident.ErrorMessage, "volatile double-bit ECC error") + default: + t.Logf("unexpected HealthSystem: 0x%x", incident.System) + t.FailNow() + } + require.Equal(t, incident.EntityGroupID, 1) // FE_GPU + require.Equal(t, incident.EntityID, int(gpuID)) + } +} + +// Utils. Must be called after NewDCGMWrapper(), as it indirectly calls dlopen("libdcgm.so.4") + +func getDCGMWrapper(t *testing.T) *DCGMWrapper { + dcgmw, err := NewDCGMWrapper("") + if err != nil && strings.Contains(err.Error(), "libdcgm.so") { + t.Skip("Skipping test that requires libdcgm.so") + } + require.NoError(t, err) + gpuIDs, err := godcgm.GetSupportedDevices() + require.NoError(t, err) + if len(gpuIDs) < 1 { + t.Skip("Skipping test that requires live GPUs. 
None were found") + } + return gpuIDs[0] +} + +func injectError(t *testing.T, gpuID uint, fieldID godcgm.Short, fieldType uint, value interface{}) { + t.Helper() + err := godcgm.InjectFieldValue(gpuID, fieldID, fieldType, 0, time.Now().UnixMicro(), value) + require.NoError(t, err) +} diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 78c7a4a3e0..7294c6cb9b 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -33,6 +33,10 @@ type CLIArgs struct { Interval int // milliseconds } + DCGM struct { + Address string + } + Docker struct { ConcatinatedPublicSSHKeys string Privileged bool diff --git a/src/dstack/_internal/cli/utils/fleet.py b/src/dstack/_internal/cli/utils/fleet.py index 1c96e2b41e..3d04100c8a 100644 --- a/src/dstack/_internal/cli/utils/fleet.py +++ b/src/dstack/_internal/cli/utils/fleet.py @@ -51,11 +51,11 @@ def get_fleets_table( and total_blocks > 1 ): status = f"{busy_blocks}/{total_blocks} {InstanceStatus.BUSY.value}" - if ( - instance.status in [InstanceStatus.IDLE, InstanceStatus.BUSY] - and instance.unreachable - ): - status += "\n(unreachable)" + if instance.status in [InstanceStatus.IDLE, InstanceStatus.BUSY]: + if instance.unreachable: + status += "\n(unreachable)" + elif not instance.health_status.is_healthy(): + status += f"\n({instance.health_status.value})" backend = instance.backend or "" if backend == "remote": diff --git a/src/dstack/_internal/core/backends/remote/provisioning.py b/src/dstack/_internal/core/backends/remote/provisioning.py index 9c65ec66ac..c4731c8a4e 100644 --- a/src/dstack/_internal/core/backends/remote/provisioning.py +++ b/src/dstack/_internal/core/backends/remote/provisioning.py @@ -20,6 +20,7 @@ Resources, SSHConnectionParams, ) +from dstack._internal.server.schemas.runner import HealthcheckResponse from dstack._internal.utils.gpu import ( convert_amd_gpu_name, convert_intel_accelerator_name, @@ -220,27 +221,35 @@ def get_host_info(client: 
paramiko.SSHClient, working_dir: str) -> Dict[str, Any raise ProvisioningError("Cannot get host_info") -def get_shim_healthcheck(client: paramiko.SSHClient) -> str: +def get_shim_healthcheck(client: paramiko.SSHClient) -> HealthcheckResponse: retries = 20 iter_delay = 3 for _ in range(retries): - try: - _, stdout, stderr = client.exec_command( - f"curl -s http://localhost:{DSTACK_SHIM_HTTP_PORT}/api/healthcheck", timeout=15 - ) - out = stdout.read().strip().decode() - err = stderr.read().strip().decode() - if err: - raise ProvisioningError( - f"The command 'get_shim_healthcheck' didn't work. stdout: {out}, stderr: {err}" - ) - if not out: - logger.debug("healthcheck is empty. retry") - time.sleep(iter_delay) - continue - return out - except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"get_shim_healthcheck failed: {e}") from e + healthcheck = _get_shim_healthcheck(client) + if healthcheck is not None: + return healthcheck + logger.debug("healthcheck is empty. retry") + time.sleep(iter_delay) + raise ProvisioningError("Cannot get HealthcheckResponse") + + +def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[HealthcheckResponse]: + try: + _, stdout, stderr = client.exec_command( + f"curl -s http://localhost:{DSTACK_SHIM_HTTP_PORT}/api/healthcheck", timeout=15 + ) + out = stdout.read().strip().decode() + err = stderr.read().strip().decode() + except (paramiko.SSHException, OSError) as e: + raise ProvisioningError(f"get_shim_healthcheck failed: {e}") from e + if err: + raise ProvisioningError(f"get_shim_healthcheck didn't work. 
stdout: {out}, stderr: {err}") + if not out: + return None + try: + return HealthcheckResponse.__response__.parse_raw(out) + except ValueError as e: + raise ProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e def host_info_to_instance_type(host_info: Dict[str, Any], cpu_arch: GoArchType) -> InstanceType: diff --git a/src/dstack/_internal/core/models/health.py b/src/dstack/_internal/core/models/health.py new file mode 100644 index 0000000000..06c8a23fb5 --- /dev/null +++ b/src/dstack/_internal/core/models/health.py @@ -0,0 +1,28 @@ +from datetime import datetime +from enum import Enum + +from dstack._internal.core.models.common import CoreModel + + +class HealthStatus(str, Enum): + HEALTHY = "healthy" + WARNING = "warning" + FAILURE = "failure" + + def is_healthy(self) -> bool: + return self == self.HEALTHY + + def is_failure(self) -> bool: + return self == self.FAILURE + + +class HealthEvent(CoreModel): + timestamp: datetime + status: HealthStatus + message: str + + +class HealthCheck(CoreModel): + collected_at: datetime + status: HealthStatus + events: list[HealthEvent] diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index 0771e2a2dc..40a86e5d5f 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -9,6 +9,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.envs import Env +from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import pretty_resources @@ -225,6 +226,7 @@ class Instance(CoreModel): hostname: Optional[str] = None status: InstanceStatus unreachable: bool = False + health_status: HealthStatus = HealthStatus.HEALTHY termination_reason: Optional[str] = None created: datetime.datetime region: Optional[str] = None diff 
--git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index b49cd3a2aa..8e65897710 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -200,6 +200,7 @@ def register_routes(app: FastAPI, ui: bool = True): app.include_router(fleets.root_router) app.include_router(fleets.project_router) app.include_router(instances.root_router) + app.include_router(instances.project_router) app.include_router(repos.router) app.include_router(runs.root_router) app.include_router(runs.project_router) diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index e4d2cece53..099f8ce51c 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -9,6 +9,7 @@ ) from dstack._internal.server.background.tasks.process_idle_volumes import process_idle_volumes from dstack._internal.server.background.tasks.process_instances import ( + delete_instance_health_checks, process_instances, ) from dstack._internal.server.background.tasks.process_metrics import ( @@ -86,6 +87,7 @@ def start_background_tasks() -> AsyncIOScheduler: IntervalTrigger(seconds=10, jitter=2), max_instances=1, ) + _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR): # Add multiple copies of tasks if requested. # max_instances=1 for additional copies to avoid running too many tasks. 
diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 2ad6ca4837..de3ab7d0e1 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -1,13 +1,14 @@ import asyncio import datetime +import logging from datetime import timedelta -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, cast import requests from paramiko.pkey import PKey from paramiko.ssh_exception import PasswordRequiredException from pydantic import ValidationError -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -77,12 +78,14 @@ from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( FleetModel, + InstanceHealthCheckModel, InstanceModel, JobModel, PlacementGroupModel, ProjectModel, ) -from dstack._internal.server.schemas.runner import HealthcheckResponse +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import InstanceHealthResponse from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.fleets import ( fleet_model_to_fleet, @@ -103,7 +106,6 @@ schedule_fleet_placement_groups_deletion, ) from dstack._internal.server.services.runner import client as runner_client -from dstack._internal.server.services.runner.client import HealthStatus from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.server.utils import sentry_utils from dstack._internal.utils.common import ( @@ -137,6 +139,17 @@ async def process_instances(batch_size: int = 1): await asyncio.gather(*tasks) +@sentry_utils.instrument_background_task +async def 
delete_instance_health_checks(): + now = get_current_datetime() + cutoff = now - timedelta(seconds=server_settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS) + async with get_session_ctx() as session: + await session.execute( + delete(InstanceHealthCheckModel).where(InstanceHealthCheckModel.collected_at < cutoff) + ) + await session.commit() + + @sentry_utils.instrument_background_task async def _process_next_instance(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) @@ -415,10 +428,10 @@ async def _add_remote(instance: InstanceModel) -> None: def _deploy_instance( remote_details: RemoteConnectionInfo, - pkeys: List[PKey], + pkeys: list[PKey], ssh_proxy_pkeys: Optional[list[PKey]], - authorized_keys: List[str], -) -> Tuple[HealthStatus, Dict[str, Any], GoArchType]: + authorized_keys: list[str], +) -> tuple[InstanceCheck, dict[str, Any], GoArchType]: with get_paramiko_connection( remote_details.ssh_user, remote_details.host, @@ -466,14 +479,10 @@ def _deploy_instance( host_info = get_host_info(client, dstack_working_dir) logger.debug("Received a host_info %s", host_info) - raw_health = get_shim_healthcheck(client) - try: - health_response = HealthcheckResponse.__response__.parse_raw(raw_health) - except ValueError as e: - raise ProvisioningError("Cannot read HealthcheckResponse") from e - health = runner_client.health_response_to_health_status(health_response) + healthcheck = get_shim_healthcheck(client) + instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck) - return health, host_info, arch + return instance_check, host_info, arch async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None: @@ -758,29 +767,63 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non ssh_private_keys = get_instance_ssh_private_keys(instance) + health_check_cutoff = get_current_datetime() - timedelta( + 
seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS + ) + res = await session.execute( + select(func.count(1)).where( + InstanceHealthCheckModel.instance_id == instance.id, + InstanceHealthCheckModel.collected_at > health_check_cutoff, + ) + ) + check_instance_health = res.scalar_one() == 0 + # May return False if fails to establish ssh connection - health_status_response = await run_async( - _instance_healthcheck, + instance_check = await run_async( + _check_instance_inner, ssh_private_keys, job_provisioning_data, None, + check_instance_health=check_instance_health, ) - if isinstance(health_status_response, bool) or health_status_response is None: - health_status = HealthStatus(healthy=False, reason="SSH or tunnel error") - else: - health_status = health_status_response + if instance_check is False: + instance_check = InstanceCheck(reachable=False, message="SSH or tunnel error") - logger.debug( - "Check instance %s status. shim health: %s", + if instance_check.reachable and check_instance_health: + health_status = instance_check.get_health_status() + else: + # Keep previous health status + health_status = instance.health + + loglevel = logging.DEBUG + if not instance_check.reachable or (check_instance_health and not health_status.is_healthy()): + loglevel = logging.WARNING + logger.log( + loglevel, + "Instance %s check: reachable=%s health_status=%s message=%r", instance.name, - health_status, - extra={"instance_name": instance.name, "shim_health": health_status}, + instance_check.reachable, + health_status.name, + instance_check.message, + extra={"instance_name": instance.name, "health_status": health_status}, ) - if health_status.healthy: + if instance_check.has_health_checks(): + # ensured by has_health_checks() + assert instance_check.health_response is not None + health_check_model = InstanceHealthCheckModel( + instance_id=instance.id, + collected_at=get_current_datetime(), + status=health_status, + 
response=instance_check.health_response.json(), + ) + session.add(health_check_model) + + instance.health = health_status + instance.unreachable = not instance_check.reachable + + if instance_check.reachable: instance.termination_deadline = None - instance.health_status = None - instance.unreachable = False if instance.status == InstanceStatus.PROVISIONING: instance.status = InstanceStatus.IDLE if not instance.jobs else InstanceStatus.BUSY @@ -798,9 +841,6 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non if instance.termination_deadline is None: instance.termination_deadline = get_current_datetime() + TERMINATION_DEADLINE_OFFSET - instance.health_status = health_status.reason - instance.unreachable = True - if instance.status == InstanceStatus.PROVISIONING and instance.started_at is not None: provisioning_deadline = _get_provisioning_deadline( instance=instance, @@ -816,12 +856,7 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non "instance_status": InstanceStatus.TERMINATING.value, }, ) - elif instance.status in (InstanceStatus.IDLE, InstanceStatus.BUSY): - logger.warning( - "Instance %s shim is not available", - instance.name, - extra={"instance_name": instance.name}, - ) + elif instance.status.is_available(): deadline = instance.termination_deadline if get_current_datetime() > deadline: instance.status = InstanceStatus.TERMINATING @@ -892,20 +927,30 @@ async def _wait_for_instance_provisioning_data( @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) -def _instance_healthcheck(ports: Dict[int, int]) -> HealthStatus: +def _check_instance_inner( + ports: Dict[int, int], *, check_instance_health: bool = False +) -> InstanceCheck: + instance_health_response: Optional[InstanceHealthResponse] = None shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + method = shim_client.healthcheck try: - resp = shim_client.healthcheck(unmask_exeptions=True) - if resp is None: - return 
HealthStatus(healthy=False, reason="Unknown reason") - return runner_client.health_response_to_health_status(resp) + healthcheck_response = method(unmask_exceptions=True) + if check_instance_health: + method = shim_client.get_instance_health + instance_health_response = method() except requests.RequestException as e: - return HealthStatus(healthy=False, reason=f"Can't request shim: {e}") + template = "shim.%s(): request error: %s" + args = (method.__func__.__name__, e) + logger.warning(template, *args) + return InstanceCheck(reachable=False, message=template % args) except Exception as e: - logger.exception("Unknown exception from shim.healthcheck: %s", e) - return HealthStatus( - healthy=False, reason=f"Unknown exception ({e.__class__.__name__}): {e}" - ) + template = "shim.%s(): unexpected exception %s: %s" + args = (method.__func__.__name__, e.__class__.__name__, e) + logger.exception(template, *args) + return InstanceCheck(reachable=False, message=template % args) + return runner_client.healthcheck_response_to_instance_check( + healthcheck_response, instance_health_response + ) async def _terminate(instance: InstanceModel) -> None: diff --git a/src/dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py b/src/dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py new file mode 100644 index 0000000000..79065fccb1 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/728b1488b1b4_add_instance_health.py @@ -0,0 +1,50 @@ +"""Add instance health + +Revision ID: 728b1488b1b4 +Revises: 25479f540245 +Create Date: 2025-08-01 14:56:20.466990 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. 
+revision = "728b1488b1b4" +down_revision = "25479f540245" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "instance_health_checks", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column( + "instance_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("collected_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.Column("status", sa.VARCHAR(length=100), nullable=False), + sa.Column("response", sa.Text(), nullable=False), + sa.ForeignKeyConstraint( + ["instance_id"], + ["instances.id"], + name=op.f("fk_instance_health_checks_instance_id_instances"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_instance_health_checks")), + ) + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.add_column(sa.Column("health", sa.VARCHAR(length=100), nullable=True)) + op.execute("UPDATE instances SET health = 'HEALTHY'") + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.alter_column("health", existing_type=sa.VARCHAR(length=100), nullable=False) + + +def downgrade() -> None: + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.drop_column("health") + + op.drop_table("instance_health_checks") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index d08208a399..1b7f003da2 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -28,6 +28,7 @@ from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.fleets import FleetStatus from dstack._internal.core.models.gateways import GatewayStatus +from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.profiles import ( DEFAULT_FLEET_TERMINATION_IDLE_TIME, @@ -599,7 +600,11 @@ class 
InstanceModel(BaseModel): # instance termination handling termination_deadline: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) termination_reason: Mapped[Optional[str]] = mapped_column(String(4000)) + # Deprecated since 0.19.22, not used health_status: Mapped[Optional[str]] = mapped_column(String(4000)) + health: Mapped[HealthStatus] = mapped_column( + EnumAsString(HealthStatus, 100), default=HealthStatus.HEALTHY + ) first_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) last_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) @@ -630,6 +635,21 @@ class InstanceModel(BaseModel): ) +class InstanceHealthCheckModel(BaseModel): + __tablename__ = "instance_health_checks" + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + + instance_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("instances.id")) + instance: Mapped["InstanceModel"] = relationship() + + collected_at: Mapped[datetime] = mapped_column(NaiveDateTime) + status: Mapped[HealthStatus] = mapped_column(EnumAsString(HealthStatus, 100)) + response: Mapped[str] = mapped_column(Text) + + class VolumeModel(BaseModel): __tablename__ = "volumes" diff --git a/src/dstack/_internal/server/routers/instances.py b/src/dstack/_internal/server/routers/instances.py index 740c51fd6c..67f22d9b65 100644 --- a/src/dstack/_internal/server/routers/instances.py +++ b/src/dstack/_internal/server/routers/instances.py @@ -3,12 +3,16 @@ from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession -import dstack._internal.server.services.instances as instances +import dstack._internal.server.services.instances as instances_services from dstack._internal.core.models.instances import Instance from dstack._internal.server.db import get_session -from dstack._internal.server.models import UserModel -from dstack._internal.server.schemas.instances import ListInstancesRequest -from 
dstack._internal.server.security.permissions import Authenticated +from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.server.schemas.instances import ( + GetInstanceHealthChecksRequest, + GetInstanceHealthChecksResponse, + ListInstancesRequest, +) +from dstack._internal.server.security.permissions import Authenticated, ProjectMember from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, @@ -19,6 +23,11 @@ tags=["instances"], responses=get_base_api_additional_responses(), ) +project_router = APIRouter( + prefix="/api/project/{project_name}/instances", + tags=["instances"], + responses=get_base_api_additional_responses(), +) @root_router.post("/list", response_model=List[Instance]) @@ -35,7 +44,7 @@ async def list_instances( the last instance from the previous page as `prev_created_at` and `prev_id`. """ return CustomORJSONResponse( - await instances.list_user_instances( + await instances_services.list_user_instances( session=session, user=user, project_names=body.project_names, @@ -47,3 +56,22 @@ async def list_instances( ascending=body.ascending, ) ) + + +@project_router.post("/get_instance_health_checks", response_model=GetInstanceHealthChecksResponse) +async def get_instance_health_checks( + body: GetInstanceHealthChecksRequest, + session: AsyncSession = Depends(get_session), + user_project: tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +): + _, project = user_project + health_checks = await instances_services.get_instance_health_checks( + session=session, + project=project, + fleet_name=body.fleet_name, + instance_num=body.instance_num, + after=body.after, + before=body.before, + limit=body.limit, + ) + return CustomORJSONResponse(GetInstanceHealthChecksResponse(health_checks=health_checks)) diff --git a/src/dstack/_internal/server/schemas/health/__init__.py b/src/dstack/_internal/server/schemas/health/__init__.py new file mode 100644 index 
0000000000..e69de29bb2 diff --git a/src/dstack/_internal/server/schemas/health/dcgm.py b/src/dstack/_internal/server/schemas/health/dcgm.py new file mode 100644 index 0000000000..f6aeaa40e5 --- /dev/null +++ b/src/dstack/_internal/server/schemas/health/dcgm.py @@ -0,0 +1,56 @@ +from enum import IntEnum + +from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.health import HealthStatus + + +class DCGMHealthResult(IntEnum): + """ + `dcgmHealthWatchResult_enum` + + See: https://github.com/NVIDIA/go-dcgm/blob/85ceb31/pkg/dcgm/const.go#L1020-L1026 + """ + + DCGM_HEALTH_RESULT_PASS = 0 + DCGM_HEALTH_RESULT_WARN = 10 + DCGM_HEALTH_RESULT_FAIL = 20 + + def to_health_status(self) -> HealthStatus: + if self == self.DCGM_HEALTH_RESULT_PASS: + return HealthStatus.HEALTHY + if self == self.DCGM_HEALTH_RESULT_WARN: + return HealthStatus.WARNING + if self == self.DCGM_HEALTH_RESULT_FAIL: + return HealthStatus.FAILURE + raise AssertionError("should not reach here") + + +class DCGMHealthIncident(CoreModel): + """ + Flattened `dcgmIncidentInfo_t` + + See: https://github.com/NVIDIA/go-dcgm/blob/85ceb31/pkg/dcgm/health.go#L68-L73 + """ + + # dcgmIncidentInfo_t + system: int + health: DCGMHealthResult + + # dcgmDiagErrorDetail_t + error_message: str + error_code: int + + # dcgmGroupEntityPair_t + entity_group_id: int # dcgmGroupEntityPair_t + entity_id: int + + +class DCGMHealthResponse(CoreModel): + """ + `dcgmHealthResponse_v5` + + See: https://github.com/NVIDIA/go-dcgm/blob/85ceb31/pkg/dcgm/health.go#L75-L78 + """ + + overall_health: DCGMHealthResult + incidents: list[DCGMHealthIncident] diff --git a/src/dstack/_internal/server/schemas/instances.py b/src/dstack/_internal/server/schemas/instances.py index 0b1d6ccaad..60843c6293 100644 --- a/src/dstack/_internal/server/schemas/instances.py +++ b/src/dstack/_internal/server/schemas/instances.py @@ -3,6 +3,8 @@ from uuid import UUID from dstack._internal.core.models.common import CoreModel +from 
dstack._internal.core.models.health import HealthCheck, HealthStatus +from dstack._internal.server.schemas.runner import InstanceHealthResponse class ListInstancesRequest(CoreModel): @@ -13,3 +15,33 @@ class ListInstancesRequest(CoreModel): prev_id: Optional[UUID] = None limit: int = 1000 ascending: bool = False + + +class InstanceCheck(CoreModel): + reachable: bool + message: Optional[str] = None + health_response: Optional[InstanceHealthResponse] = None + + def get_health_status(self) -> HealthStatus: + if self.health_response is None: + return HealthStatus.HEALTHY + if self.health_response.dcgm is None: + return HealthStatus.HEALTHY + return self.health_response.dcgm.overall_health.to_health_status() + + def has_health_checks(self) -> bool: + if self.health_response is None: + return False + return self.health_response.dcgm is not None + + +class GetInstanceHealthChecksRequest(CoreModel): + fleet_name: str + instance_num: int + after: Optional[datetime] = None + before: Optional[datetime] = None + limit: Optional[int] = None + + +class GetInstanceHealthChecksResponse(CoreModel): + health_checks: list[HealthCheck] diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index f8d59343d7..f71d60055c 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -16,6 +16,7 @@ RunSpec, ) from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint +from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse class JobStateEvent(CoreModel): @@ -114,6 +115,10 @@ class HealthcheckResponse(CoreModel): version: str +class InstanceHealthResponse(CoreModel): + dcgm: Optional[DCGMHealthResponse] = None + + class GPUMetrics(CoreModel): gpu_memory_usage_bytes: int gpu_util_percent: int diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 4636f4ab0e..cf53746d8b 100644 --- 
a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -1,3 +1,4 @@ +import operator import uuid from collections.abc import Container, Iterable from datetime import datetime @@ -6,15 +7,17 @@ import gpuhunt from sqlalchemy import and_, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, load_only from dstack._internal.core.backends.base.offers import ( offer_to_catalog_item, requirements_to_query_filter, ) from dstack._internal.core.backends.features import BACKENDS_WITH_MULTINODE_SUPPORT +from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.envs import Env +from dstack._internal.core.models.health import HealthCheck, HealthEvent, HealthStatus from dstack._internal.core.models.instances import ( Instance, InstanceAvailability, @@ -38,10 +41,13 @@ from dstack._internal.core.services.profiles import get_termination from dstack._internal.server.models import ( FleetModel, + InstanceHealthCheckModel, InstanceModel, ProjectModel, UserModel, ) +from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse +from dstack._internal.server.schemas.runner import InstanceHealthResponse from dstack._internal.server.services.offers import generate_shared_offer from dstack._internal.server.services.projects import list_user_project_models from dstack._internal.utils import common as common_utils @@ -50,6 +56,57 @@ logger = get_logger(__name__) +async def get_instance_health_checks( + session: AsyncSession, + project: ProjectModel, + fleet_name: str, + instance_num: int, + after: Optional[datetime] = None, + before: Optional[datetime] = None, + limit: Optional[int] = None, +) -> list[HealthCheck]: + """ + Returns instance health checks ordered from the latest to the earliest. 
+ + Expected usage: + * limit=100 — get the latest 100 checks + * after= — get checks for the last hour + * before=, limit=100 ­— paginate back in history + """ + res = await session.execute( + select(InstanceModel) + .join(FleetModel) + .where( + ~InstanceModel.deleted, + InstanceModel.project_id == project.id, + InstanceModel.instance_num == instance_num, + FleetModel.name == fleet_name, + ) + .options(load_only(InstanceModel.id)) + ) + instance = res.scalar_one_or_none() + if instance is None: + raise ResourceNotExistsError() + + stmt = ( + select(InstanceHealthCheckModel) + .where(InstanceHealthCheckModel.instance_id == instance.id) + .order_by(InstanceHealthCheckModel.collected_at.desc()) + ) + if after is not None: + stmt = stmt.where(InstanceHealthCheckModel.collected_at > after) + if before is not None: + stmt = stmt.where(InstanceHealthCheckModel.collected_at < before) + if limit is not None: + stmt = stmt.limit(limit) + health_checks: list[HealthCheck] = [] + res = await session.execute(stmt) + for health_check_model in res.scalars(): + health_check = instance_health_check_model_to_health_check(health_check_model) + health_checks.append(health_check) + return health_checks + + def instance_model_to_instance(instance_model: InstanceModel) -> Instance: instance = Instance( id=instance_model.id, @@ -60,6 +117,7 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: instance_num=instance_model.instance_num, status=instance_model.status, unreachable=instance_model.unreachable, + health_status=instance_model.health, termination_reason=instance_model.termination_reason, created=instance_model.created_at, total_blocks=instance_model.total_blocks, @@ -81,6 +139,48 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: return instance +def instance_health_check_model_to_health_check(model: InstanceHealthCheckModel) -> HealthCheck: + collected_at = model.collected_at + status = HealthStatus.HEALTHY + events: 
list[HealthEvent] = [] + instance_health_response = get_instance_health_response(model) + if (dcgm := instance_health_response.dcgm) is not None: + dcgm_health_check = dcgm_health_response_to_health_check(dcgm, collected_at) + status = dcgm_health_check.status + events.extend(dcgm_health_check.events) + events.sort(key=operator.attrgetter("timestamp"), reverse=True) + return HealthCheck( + collected_at=collected_at, + status=status, + events=events, + ) + + +def dcgm_health_response_to_health_check( + response: DCGMHealthResponse, collected_at: datetime +) -> HealthCheck: + events: list[HealthEvent] = [] + for incident in response.incidents: + events.append( + HealthEvent( + timestamp=collected_at, + status=incident.health.to_health_status(), + message=incident.error_message, + ) + ) + return HealthCheck( + collected_at=collected_at, + status=response.overall_health.to_health_status(), + events=events, + ) + + +def get_instance_health_response( + instance_health_check_model: InstanceHealthCheckModel, +) -> InstanceHealthResponse: + return InstanceHealthResponse.__response__.parse_raw(instance_health_check_model.response) + + def get_instance_provisioning_data(instance_model: InstanceModel) -> Optional[JobProvisioningData]: if instance_model.job_provisioning_data is None: return None @@ -194,6 +294,8 @@ def filter_pool_instances( continue if instance.unreachable: continue + if instance.health.is_failure(): + continue fleet = instance.fleet if profile.fleets is not None and (fleet is None or fleet.name not in profile.fleets): continue diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index be6a11be83..7b7c31dd8f 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -1,7 +1,6 @@ import uuid -from dataclasses import dataclass from http import HTTPStatus -from typing import BinaryIO, Dict, List, Optional, TypeVar, Union 
+from typing import BinaryIO, Dict, List, Literal, Optional, TypeVar, Union, overload import packaging.version import requests @@ -14,9 +13,11 @@ from dstack._internal.core.models.resources import Memory from dstack._internal.core.models.runs import ClusterInfo, Job, Run from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint +from dstack._internal.server.schemas.instances import InstanceCheck from dstack._internal.server.schemas.runner import ( GPUDevice, HealthcheckResponse, + InstanceHealthResponse, LegacyPullResponse, LegacyStopBody, LegacySubmitBody, @@ -37,15 +38,6 @@ logger = get_logger(__name__) -@dataclass -class HealthStatus: - healthy: bool - reason: str - - def __str__(self) -> str: - return self.reason - - class RunnerClient: def __init__( self, @@ -193,6 +185,9 @@ class ShimClient: # API v1 (a.k.a. Legacy API) — `/api/{submit,pull,stop}` _API_V2_MIN_SHIM_VERSION = (0, 18, 34) + # `/api/instance/health` + _INSTANCE_HEALTH_MIN_SHIM_VERSION = (0, 19, 22) + _shim_version: Optional["_Version"] _api_version: int _negotiated: bool = False @@ -212,11 +207,25 @@ def is_api_v2_supported(self) -> bool: self._negotiate() return self._api_version == 2 - def healthcheck(self, unmask_exeptions: bool = False) -> Optional[HealthcheckResponse]: + def is_instance_health_supported(self) -> bool: + if not self._negotiated: + self._negotiate() + return ( + self._shim_version is None + or self._shim_version >= self._INSTANCE_HEALTH_MIN_SHIM_VERSION + ) + + @overload + def healthcheck(self) -> Optional[HealthcheckResponse]: ... + + @overload + def healthcheck(self, unmask_exceptions: Literal[True]) -> HealthcheckResponse: ... 
+ + def healthcheck(self, unmask_exceptions: bool = False) -> Optional[HealthcheckResponse]: try: resp = self._request("GET", "/api/healthcheck", raise_for_status=True) except requests.exceptions.RequestException: - if unmask_exeptions: + if unmask_exceptions: raise return None if not self._negotiated: @@ -225,6 +234,17 @@ def healthcheck(self, unmask_exeptions: bool = False) -> Optional[HealthcheckRes # API v2 methods + def get_instance_health(self) -> Optional[InstanceHealthResponse]: + if not self.is_instance_health_supported(): + logger.debug("instance health is not supported: %s", self._shim_version) + return None + resp = self._request("GET", "/api/instance/health") + if resp.status_code == HTTPStatus.NOT_FOUND: + logger.warning("instance health: %s", resp.text) + return None + self._raise_for_status(resp) + return self._response(InstanceHealthResponse, resp) + def get_task(self, task_id: "_TaskID") -> TaskInfoResponse: if not self.is_api_v2_supported(): raise ShimAPIVersionError() @@ -418,14 +438,26 @@ def _negotiate(self, healthcheck_response: Optional[requests.Response] = None) - self._negotiated = True -def health_response_to_health_status(data: HealthcheckResponse) -> HealthStatus: - if data.service == "dstack-shim": - return HealthStatus(healthy=True, reason="Service is OK") - else: - return HealthStatus( - healthy=False, - reason=f"Service name is {data.service}, service version: {data.version}", +def healthcheck_response_to_instance_check( + response: HealthcheckResponse, + instance_health_response: Optional[InstanceHealthResponse] = None, +) -> InstanceCheck: + if response.service == "dstack-shim": + message: Optional[str] = None + if ( + instance_health_response is not None + and instance_health_response.dcgm is not None + and instance_health_response.dcgm.incidents + ): + message = instance_health_response.dcgm.incidents[0].error_message + return InstanceCheck( + reachable=True, health_response=instance_health_response, message=message ) + return 
InstanceCheck( + reachable=False, + message=f"unexpected service: {response.service} version: {response.version}", + health_response=instance_health_response, + ) def _volume_to_shim_volume_info(volume: Volume, instance_id: str) -> ShimVolumeInfo: diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 95d5f1ab53..669ea2181d 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -2,7 +2,7 @@ import socket import time from collections.abc import Iterable -from typing import Callable, Dict, List, Optional, TypeVar, Union +from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union import requests from typing_extensions import Concatenate, ParamSpec @@ -27,7 +27,7 @@ def runner_ssh_tunnel( [Callable[Concatenate[Dict[int, int], P], R]], Callable[ Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], - Union[bool, R], + Union[Literal[False], R], ], ]: """ @@ -42,7 +42,7 @@ def decorator( func: Callable[Concatenate[Dict[int, int], P], R], ) -> Callable[ Concatenate[PrivateKeyOrPair, JobProvisioningData, Optional[JobRuntimeData], P], - Union[bool, R], + Union[Literal[False], R], ]: @functools.wraps(func) def wrapper( @@ -51,7 +51,7 @@ def wrapper( job_runtime_data: Optional[JobRuntimeData], *args: P.args, **kwargs: P.kwargs, - ) -> Union[bool, R]: + ) -> Union[Literal[False], R]: """ Returns: is successful diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 8763d15f78..ac4924d9c4 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -93,6 +93,13 @@ os.getenv("DSTACK_SERVER_METRICS_FINISHED_TTL_SECONDS", 7 * 24 * 3600) ) +SERVER_INSTANCE_HEALTH_TTL_SECONDS = int( + os.getenv("DSTACK_SERVER_INSTANCE_HEALTH_TTL_SECONDS", 7 * 24 * 3600) +) +SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS = int( + 
os.getenv("DSTACK_SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS", 60) +) + SERVER_KEEP_SHIM_TASKS = os.getenv("DSTACK_SERVER_KEEP_SHIM_TASKS") is not None DEFAULT_PROJECT_NAME = "main" diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 7b3630955f..646a03e88b 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -35,6 +35,7 @@ SSHParams, ) from dstack._internal.core.models.gateways import GatewayComputeConfiguration, GatewayStatus +from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( Disk, Gpu, @@ -85,6 +86,7 @@ FleetModel, GatewayComputeModel, GatewayModel, + InstanceHealthCheckModel, InstanceModel, JobMetricsPoint, JobModel, @@ -614,6 +616,7 @@ async def create_instance( fleet: Optional[FleetModel] = None, status: InstanceStatus = InstanceStatus.IDLE, unreachable: bool = False, + health_status: HealthStatus = HealthStatus.HEALTHY, created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), finished_at: Optional[datetime] = None, spot: bool = False, @@ -676,6 +679,7 @@ async def create_instance( status=status, last_processed_at=last_processed_at, unreachable=unreachable, + health=health_status, created_at=created_at, started_at=created_at, finished_at=finished_at, @@ -796,6 +800,24 @@ def get_ssh_key() -> SSHKey: ) +async def create_instance_health_check( + session: AsyncSession, + instance: InstanceModel, + collected_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + status: HealthStatus = HealthStatus.HEALTHY, + response: str = "{}", +) -> InstanceHealthCheckModel: + health_check = InstanceHealthCheckModel( + instance_id=instance.id, + collected_at=collected_at, + status=status, + response=response, + ) + session.add(health_check) + await session.commit() + return health_check + + async def create_volume( session: AsyncSession, project: ProjectModel, 
diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/tasks/test_process_instances.py index 73aea904b3..990146f3b7 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/tasks/test_process_instances.py @@ -18,6 +18,7 @@ ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.fleets import InstanceGroupPlacement +from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( Gpu, InstanceAvailability, @@ -34,14 +35,18 @@ JobStatus, ) from dstack._internal.server.background.tasks.process_instances import ( - HealthStatus, + delete_instance_health_checks, process_instances, ) -from dstack._internal.server.models import PlacementGroupModel +from dstack._internal.server.models import InstanceHealthCheckModel, PlacementGroupModel +from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import InstanceHealthResponse from dstack._internal.server.testing.common import ( ComputeMockSpec, create_fleet, create_instance, + create_instance_health_check, create_job, create_project, create_repo, @@ -72,14 +77,13 @@ async def test_check_shim_transitions_provisioning_on_ready( status=InstanceStatus.PROVISIONING, ) instance.termination_deadline = get_current_datetime() + dt.timedelta(days=1) - instance.health_status = "ssh connect problem" await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._instance_healthcheck" + "dstack._internal.server.background.tasks.process_instances._check_instance_inner" ) as healthcheck: - healthcheck.return_value = HealthStatus(healthy=True, reason="OK") + healthcheck.return_value = InstanceCheck(reachable=True) await 
process_instances() await session.refresh(instance) @@ -87,7 +91,6 @@ async def test_check_shim_transitions_provisioning_on_ready( assert instance is not None assert instance.status == InstanceStatus.IDLE assert instance.termination_deadline is None - assert instance.health_status is None @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -101,16 +104,15 @@ async def test_check_shim_transitions_provisioning_on_terminating( status=InstanceStatus.PROVISIONING, ) instance.started_at = get_current_datetime() + dt.timedelta(minutes=-20) - instance.health_status = "ssh connect problem" await session.commit() health_reason = "Shim problem" with patch( - "dstack._internal.server.background.tasks.process_instances._instance_healthcheck" + "dstack._internal.server.background.tasks.process_instances._check_instance_inner" ) as healthcheck: - healthcheck.return_value = HealthStatus(healthy=False, reason=health_reason) + healthcheck.return_value = InstanceCheck(reachable=False, message=health_reason) await process_instances() await session.refresh(instance) @@ -118,7 +120,6 @@ async def test_check_shim_transitions_provisioning_on_terminating( assert instance is not None assert instance.status == InstanceStatus.TERMINATING assert instance.termination_deadline is not None - assert instance.health_status == health_reason @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -145,7 +146,6 @@ async def test_check_shim_transitions_provisioning_on_busy( instance.termination_deadline = get_current_datetime().replace( tzinfo=dt.timezone.utc ) + dt.timedelta(days=1) - instance.health_status = "ssh connect problem" job = await create_job( session=session, @@ -157,9 +157,9 @@ async def test_check_shim_transitions_provisioning_on_busy( await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._instance_healthcheck" + 
"dstack._internal.server.background.tasks.process_instances._check_instance_inner" ) as healthcheck: - healthcheck.return_value = HealthStatus(healthy=True, reason="OK") + healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() await session.refresh(instance) @@ -168,7 +168,6 @@ async def test_check_shim_transitions_provisioning_on_busy( assert instance is not None assert instance.status == InstanceStatus.BUSY assert instance.termination_deadline is None - assert instance.health_status is None assert job.instance == instance @pytest.mark.asyncio @@ -182,9 +181,9 @@ async def test_check_shim_start_termination_deadline(self, test_db, session: Asy ) health_status = "SSH connection fail" with patch( - "dstack._internal.server.background.tasks.process_instances._instance_healthcheck" + "dstack._internal.server.background.tasks.process_instances._check_instance_inner" ) as healthcheck: - healthcheck.return_value = HealthStatus(healthy=False, reason=health_status) + healthcheck.return_value = InstanceCheck(reachable=False, message=health_status) await process_instances() await session.refresh(instance) @@ -195,7 +194,6 @@ async def test_check_shim_start_termination_deadline(self, test_db, session: Asy assert instance.termination_deadline.replace( tzinfo=dt.timezone.utc ) > get_current_datetime() + dt.timedelta(minutes=19) - assert instance.health_status == health_status @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -210,9 +208,9 @@ async def test_check_shim_stop_termination_deadline(self, test_db, session: Asyn await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._instance_healthcheck" + "dstack._internal.server.background.tasks.process_instances._check_instance_inner" ) as healthcheck: - healthcheck.return_value = HealthStatus(healthy=True, reason="OK") + healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() await 
session.refresh(instance) @@ -220,7 +218,6 @@ async def test_check_shim_stop_termination_deadline(self, test_db, session: Asyn assert instance is not None assert instance.status == InstanceStatus.IDLE assert instance.termination_deadline is None - assert instance.health_status is None @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -237,9 +234,9 @@ async def test_check_shim_terminate_instance_by_deadline(self, test_db, session: health_status = "Not ok" with patch( - "dstack._internal.server.background.tasks.process_instances._instance_healthcheck" + "dstack._internal.server.background.tasks.process_instances._check_instance_inner" ) as healthcheck: - healthcheck.return_value = HealthStatus(healthy=False, reason=health_status) + healthcheck.return_value = InstanceCheck(reachable=False, message=health_status) await process_instances() await session.refresh(instance) @@ -248,7 +245,6 @@ async def test_check_shim_terminate_instance_by_deadline(self, test_db, session: assert instance.status == InstanceStatus.TERMINATING assert instance.termination_deadline == termination_deadline_time assert instance.termination_reason == "Termination deadline" - assert instance.health_status == health_status @pytest.mark.asyncio @pytest.mark.parametrize( @@ -302,9 +298,9 @@ async def test_check_shim_process_ureachable_state( await session.commit() with patch( - "dstack._internal.server.background.tasks.process_instances._instance_healthcheck" + "dstack._internal.server.background.tasks.process_instances._check_instance_inner" ) as healthcheck: - healthcheck.return_value = HealthStatus(healthy=True, reason="OK") + healthcheck.return_value = InstanceCheck(reachable=True) await process_instances() healthcheck.assert_called() @@ -314,6 +310,72 @@ async def test_check_shim_process_ureachable_state( assert instance.status == InstanceStatus.IDLE assert not instance.unreachable + @pytest.mark.asyncio + @pytest.mark.parametrize("health_status", 
[HealthStatus.HEALTHY, HealthStatus.FAILURE]) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_check_shim_switch_to_unreachable_state( + self, test_db, session: AsyncSession, health_status: HealthStatus + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + health_status=health_status, + ) + + with patch( + "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + ) as healthcheck: + healthcheck.return_value = InstanceCheck(reachable=False) + await process_instances() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable + # Should keep the previous status + assert instance.health == health_status + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_check_shim_check_instance_health(self, test_db, session: AsyncSession): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + health_status=HealthStatus.HEALTHY, + ) + health_response = InstanceHealthResponse( + dcgm=DCGMHealthResponse( + overall_health=DCGMHealthResult.DCGM_HEALTH_RESULT_WARN, incidents=[] + ) + ) + + with patch( + "dstack._internal.server.background.tasks.process_instances._check_instance_inner" + ) as healthcheck: + healthcheck.return_value = InstanceCheck( + reachable=True, health_response=health_response + ) + await process_instances() + + await session.refresh(instance) + + assert instance is not None + assert instance.status == InstanceStatus.IDLE + assert not instance.unreachable + assert instance.health == HealthStatus.WARNING + + res = await session.execute(select(InstanceHealthCheckModel)) + health_check = 
res.scalars().one() + assert health_check.status == HealthStatus.WARNING + assert health_check.response == health_response.json() + class TestTerminateIdleTime: @pytest.mark.asyncio @@ -879,7 +941,7 @@ def host_info(self) -> dict: @pytest.fixture def deploy_instance_mock(self, monkeypatch: pytest.MonkeyPatch, host_info: dict): - mock = Mock(return_value=(HealthStatus(healthy=True, reason="OK"), host_info, "amd64")) + mock = Mock(return_value=(InstanceCheck(reachable=True), host_info, "amd64")) monkeypatch.setattr( "dstack._internal.server.background.tasks.process_instances._deploy_instance", mock ) @@ -933,3 +995,36 @@ async def test_adds_ssh_instance( assert instance.status == InstanceStatus.IDLE assert instance.total_blocks == expected_blocks assert instance.busy_blocks == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db", "image_config_mock") +class TestDeleteInstanceHealthChecks: + async def test_deletes_instance_health_checks( + self, monkeypatch: pytest.MonkeyPatch, session: AsyncSession + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.IDLE + ) + # 30 minutes + monkeypatch.setattr( + "dstack._internal.server.settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS", 1800 + ) + now = get_current_datetime() + # old check + await create_instance_health_check( + session=session, instance=instance, collected_at=now - dt.timedelta(minutes=40) + ) + # recent check + check = await create_instance_health_check( + session=session, instance=instance, collected_at=now - dt.timedelta(minutes=20) + ) + + await delete_instance_health_checks() + + res = await session.execute(select(InstanceHealthCheckModel)) + all_checks = res.scalars().all() + assert len(all_checks) == 1 + assert all_checks[0] == check diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py 
b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 491ebb4b03..901a91be4d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -8,6 +8,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import TaskConfiguration +from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceOfferWithAvailability, @@ -425,6 +426,56 @@ async def test_assignes_job_to_instance(self, test_db, session: AsyncSession): job.instance_assigned and job.instance is not None and job.instance.id == instance.id ) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_does_no_reuse_unavailable_instances(self, test_db, session: AsyncSession): + project = await create_project(session) + user = await create_user(session) + repo = await create_repo( + session=session, + project_id=project.id, + ) + # busy + await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + # unreachable + await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=True, + ) + # fatal health issue + await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + health_status=HealthStatus.FAILURE, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + instance_assigned=False, + ) + + await process_submitted_jobs() + + await session.refresh(job) + res = await session.execute(select(JobModel).options(joinedload(JobModel.instance))) + job = res.unique().scalar_one() + assert job.status == JobStatus.SUBMITTED + assert job.instance_assigned + assert job.instance is 
None + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_assigns_job_to_instance_with_volumes(self, test_db, session: AsyncSession): diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index b52ebc543b..33fc73e019 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -400,6 +400,7 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async "hostname": None, "status": "pending", "unreachable": False, + "health_status": "healthy", "termination_reason": None, "created": "2023-01-02T03:04:00+00:00", "backend": None, @@ -534,6 +535,7 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A "hostname": "1.1.1.1", "status": "pending", "unreachable": False, + "health_status": "healthy", "termination_reason": None, "created": "2023-01-02T03:04:00+00:00", "region": "remote", @@ -704,6 +706,7 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A "hostname": "10.0.0.100", "status": "terminating", "unreachable": False, + "health_status": "healthy", "termination_reason": None, "created": "2023-01-02T03:04:00+00:00", "region": "remote", @@ -736,6 +739,7 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A "hostname": "10.0.0.101", "status": "pending", "unreachable": False, + "health_status": "healthy", "termination_reason": None, "created": "2023-01-02T03:04:00+00:00", "region": "remote", diff --git a/src/tests/_internal/server/routers/test_instances.py b/src/tests/_internal/server/routers/test_instances.py index c44b7ff291..72198331d0 100644 --- a/src/tests/_internal/server/routers/test_instances.py +++ b/src/tests/_internal/server/routers/test_instances.py @@ -15,6 +15,7 @@ from dstack._internal.server.testing.common import ( create_fleet, create_instance, + 
create_instance_health_check, create_project, create_user, get_auth_headers, @@ -265,3 +266,109 @@ async def test_pagination( async def test_not_authenticated(self, client: AsyncClient, data) -> None: resp = await client.post("/api/instances/list", json={}) assert resp.status_code == 403 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db") +class TestGetInstanceHealthChecks: + async def test_returns_403_if_not_project_member( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + response = await client.post( + f"/api/project/{project.name}/instances/get_instance_health_checks", + headers=get_auth_headers(user.token), + json={ + "fleet_name": "test", + "instance_num": 0, + }, + ) + assert response.status_code == 403 + + async def test_returns_400_if_instance_not_found( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session, project=project, user=user, project_role=ProjectRole.USER + ) + + response = await client.post( + f"/api/project/{project.name}/instances/get_instance_health_checks", + headers=get_auth_headers(user.token), + json={ + "fleet_name": "test", + "instance_num": 0, + }, + ) + assert response.status_code == 400 + + async def test_returns_health_checks(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + 
project=project, + fleet=fleet, + ) + await create_instance_health_check( + session=session, + instance=instance, + collected_at=dt.datetime(2025, 1, 1, 12, 0, tzinfo=dt.timezone.utc), + response="{}", + ) + health_response_with_dcgm = """ + { + "dcgm": { + "overall_health": 20, + "incidents": [{ + "system": 16, + "health": 20, + "error_message": "Detected 333 volatile double-bit ECC error(s) in GPU 0.", + "error_code": 4, + "entity_group_id": 1, + "entity_id": 0 + }] + } + } + """ + await create_instance_health_check( + session=session, + instance=instance, + collected_at=dt.datetime(2025, 1, 1, 12, 1, tzinfo=dt.timezone.utc), + response=health_response_with_dcgm, + ) + + response = await client.post( + f"/api/project/{project.name}/instances/get_instance_health_checks", + headers=get_auth_headers(user.token), + json={ + "fleet_name": fleet.name, + "instance_num": instance.instance_num, + }, + ) + + assert response.status_code == 200 + assert response.json() == { + "health_checks": [ + { + "collected_at": "2025-01-01T12:01:00+00:00", + "status": "failure", + "events": [ + { + "timestamp": "2025-01-01T12:01:00+00:00", + "status": "failure", + "message": "Detected 333 volatile double-bit ECC error(s) in GPU 0.", + } + ], + }, + {"collected_at": "2025-01-01T12:00:00+00:00", "status": "healthy", "events": []}, + ] + } diff --git a/src/tests/_internal/server/services/test_instances.py b/src/tests/_internal/server/services/test_instances.py index 055645852b..414a360c21 100644 --- a/src/tests/_internal/server/services/test_instances.py +++ b/src/tests/_internal/server/services/test_instances.py @@ -5,6 +5,7 @@ import dstack._internal.server.services.instances as instances_services from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( Instance, InstanceStatus, @@ -130,6 +131,8 @@ async def test_converts_instance(self, test_db, session: 
AsyncSession): instance_num=0, hostname="hostname_test", status=InstanceStatus.PENDING, + unreachable=False, + health_status=HealthStatus.WARNING, created=created, region="eu-west-1", price=1.0, @@ -143,6 +146,7 @@ async def test_converts_instance(self, test_db, session: AsyncSession): instance_num=0, status=InstanceStatus.PENDING, unreachable=False, + health=HealthStatus.WARNING, project=project, job_provisioning_data='{"ssh_proxy":null, "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', From 50e19c488ce4f4e76e5e011ba92a981c1984baee Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Wed, 6 Aug 2025 17:05:41 +0000 Subject: [PATCH 2/5] CI: Fix shim build --- .github/workflows/build.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0726a65efe..b24b74cc25 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -181,11 +181,10 @@ jobs: env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} - CGO_ENABLED: 0 run: | VERSION=$((${{ github.run_number }} + ${{ env.BUILD_INCREMENT }})) - go build -ldflags "-X 'main.Version=$VERSION' -extldflags '-static'" -o dstack-runner-$GOOS-$GOARCH $REPO_NAME/runner/cmd/runner - go build -ldflags "-X 'main.Version=$VERSION' -extldflags '-static'" -o dstack-shim-$GOOS-$GOARCH $REPO_NAME/runner/cmd/shim + CGO_ENABLED=0 go build -ldflags "-X 'main.Version=$VERSION' -extldflags '-static'" -o 
dstack-runner-$GOOS-$GOARCH $REPO_NAME/runner/cmd/runner + CGO_ENABLED=1 go build -ldflags "-X 'main.Version=$VERSION'" -o dstack-shim-$GOOS-$GOARCH $REPO_NAME/runner/cmd/shim echo $VERSION - uses: actions/upload-artifact@v4 with: From d0fbba7408608373c0c0af5aa6654bfe328ea11c Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Wed, 6 Aug 2025 17:11:48 +0000 Subject: [PATCH 3/5] CI: disable shim/runner tests on macOS cannot use _Ctype_long(ts) (value of type _Ctype_long) as _Ctype_int64_t value in struct literal cannot use _Ctype_ulong(0) (constant 0 of type _Ctype_ulong) as _Ctype_uint64_t value in argument to (_Cfunc_dcgmPolicyRegister_v2)) --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b24b74cc25..e2ecabe916 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -126,7 +126,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest] steps: - uses: actions/checkout@v4 - name: Set up Go From bbdaac3edac25576d27227a22d107a656c42ece0 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Wed, 6 Aug 2025 17:35:46 +0000 Subject: [PATCH 4/5] CI: compile ARM binaries on ARM runners --- .github/workflows/build.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e2ecabe916..70b02fc6a9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -167,9 +167,9 @@ jobs: strategy: matrix: include: - - { goos: "linux", goarch: "amd64" } - - { goos: "linux", goarch: "arm64" } - runs-on: ubuntu-latest + - { runs-on: "ubuntu-24.04", goos: "linux", goarch: "amd64" } + - { runs-on: "ubuntu-24.04-arm", goos: "linux", goarch: "arm64" } + runs-on: ${{ matrix.runs-on }} steps: - uses: actions/checkout@v4 - name: Set up Go From 6d5c3398ee9ec695109396e2281adcfa7f5167a2 Mon Sep 17 00:00:00 2001 From: 
Dmitry Meyer Date: Thu, 7 Aug 2025 08:08:45 +0000 Subject: [PATCH 5/5] Move provisioning.py from core to server --- .../core/backends/remote/__init__.py | 0 .../background/tasks/process_instances.py | 32 +++++++++++-------- .../remote => server/utils}/provisioning.py | 10 ++---- 3 files changed, 21 insertions(+), 21 deletions(-) delete mode 100644 src/dstack/_internal/core/backends/remote/__init__.py rename src/dstack/_internal/{core/backends/remote => server/utils}/provisioning.py (97%) diff --git a/src/dstack/_internal/core/backends/remote/__init__.py b/src/dstack/_internal/core/backends/remote/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index de3ab7d0e1..ec3d4d387f 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -28,18 +28,6 @@ BACKENDS_WITH_CREATE_INSTANCE_SUPPORT, BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, ) -from dstack._internal.core.backends.remote.provisioning import ( - detect_cpu_arch, - get_host_info, - get_paramiko_connection, - get_shim_healthcheck, - host_info_to_instance_type, - remove_dstack_runner_if_exists, - remove_host_info_if_exists, - run_pre_start_commands, - run_shim_as_systemd_service, - upload_envs, -) from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT # FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute @@ -85,7 +73,7 @@ ProjectModel, ) from dstack._internal.server.schemas.instances import InstanceCheck -from dstack._internal.server.schemas.runner import InstanceHealthResponse +from dstack._internal.server.schemas.runner import HealthcheckResponse, InstanceHealthResponse from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.fleets import ( 
fleet_model_to_fleet, @@ -108,6 +96,18 @@ from dstack._internal.server.services.runner import client as runner_client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.server.utils import sentry_utils +from dstack._internal.server.utils.provisioning import ( + detect_cpu_arch, + get_host_info, + get_paramiko_connection, + get_shim_healthcheck, + host_info_to_instance_type, + remove_dstack_runner_if_exists, + remove_host_info_if_exists, + run_pre_start_commands, + run_shim_as_systemd_service, + upload_envs, +) from dstack._internal.utils.common import ( get_current_datetime, get_or_error, @@ -479,7 +479,11 @@ def _deploy_instance( host_info = get_host_info(client, dstack_working_dir) logger.debug("Received a host_info %s", host_info) - healthcheck = get_shim_healthcheck(client) + healthcheck_out = get_shim_healthcheck(client) + try: + healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out) + except ValueError as e: + raise ProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck) return instance_check, host_info, arch diff --git a/src/dstack/_internal/core/backends/remote/provisioning.py b/src/dstack/_internal/server/utils/provisioning.py similarity index 97% rename from src/dstack/_internal/core/backends/remote/provisioning.py rename to src/dstack/_internal/server/utils/provisioning.py index c4731c8a4e..94a5347343 100644 --- a/src/dstack/_internal/core/backends/remote/provisioning.py +++ b/src/dstack/_internal/server/utils/provisioning.py @@ -20,7 +20,6 @@ Resources, SSHConnectionParams, ) -from dstack._internal.server.schemas.runner import HealthcheckResponse from dstack._internal.utils.gpu import ( convert_amd_gpu_name, convert_intel_accelerator_name, @@ -221,7 +220,7 @@ def get_host_info(client: paramiko.SSHClient, working_dir: str) -> Dict[str, Any raise ProvisioningError("Cannot get host_info") -def 
get_shim_healthcheck(client: paramiko.SSHClient) -> HealthcheckResponse: +def get_shim_healthcheck(client: paramiko.SSHClient) -> str: retries = 20 iter_delay = 3 for _ in range(retries): @@ -233,7 +232,7 @@ def get_shim_healthcheck(client: paramiko.SSHClient) -> HealthcheckResponse: raise ProvisioningError("Cannot get HealthcheckResponse") -def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[HealthcheckResponse]: +def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[str]: try: _, stdout, stderr = client.exec_command( f"curl -s http://localhost:{DSTACK_SHIM_HTTP_PORT}/api/healthcheck", timeout=15 @@ -246,10 +245,7 @@ def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[HealthcheckRes raise ProvisioningError(f"get_shim_healthcheck didn't work. stdout: {out}, stderr: {err}") if not out: return None - try: - return HealthcheckResponse.__response__.parse_raw(out) - except ValueError as e: - raise ProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e + return out def host_info_to_instance_type(host_info: Dict[str, Any], cpu_arch: GoArchType) -> InstanceType: