diff --git a/runner/internal/shim/api/handlers.go b/runner/internal/shim/api/handlers.go index 94eb485076..1374fbd803 100644 --- a/runner/internal/shim/api/handlers.go +++ b/runner/internal/shim/api/handlers.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net/http" - "reflect" "github.com/dstackai/dstack/runner/internal/api" "github.com/dstackai/dstack/runner/internal/log" @@ -25,7 +24,7 @@ func (s *ShimServer) HealthcheckHandler(w http.ResponseWriter, r *http.Request) func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { ctx := r.Context() response := InstanceHealthResponse{} - if !reflect.ValueOf(s.dcgmWrapper).IsNil() { + if s.dcgmWrapper != nil { if dcgmHealth, err := s.dcgmWrapper.GetHealth(); err != nil { log.Error(ctx, "failed to get health from DCGM", "err", err) } else { diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index e012a0b76c..4ba67a1f94 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/http" + "reflect" "sync" "github.com/dstackai/dstack/runner/internal/api" @@ -29,7 +30,7 @@ type ShimServer struct { runner TaskRunner dcgmExporter *dcgm.DCGMExporter - dcgmWrapper dcgm.DCGMWrapperInterface + dcgmWrapper dcgm.DCGMWrapperInterface // interface with nil value normalized to plain nil version string } @@ -38,6 +39,9 @@ func NewShimServer( ctx context.Context, address string, version string, runner TaskRunner, dcgmExporter *dcgm.DCGMExporter, dcgmWrapper dcgm.DCGMWrapperInterface, ) *ShimServer { + if dcgmWrapper != nil && reflect.ValueOf(dcgmWrapper).IsNil() { + dcgmWrapper = nil + } r := api.NewRouter() s := &ShimServer{ HttpServer: &http.Server{