Skip to content

Commit 9afdf2d

Browse files
committed
Implement shim auto-update
shim binary is replaced at any time, but restart is postponed until all tasks are terminated, as safe restart with running tasks requires additional work (see _get_restart_safe_task_statuses() comment). Closes: #3288
1 parent 9ea6bf7 commit 9afdf2d

File tree

22 files changed

+1043
-259
lines changed

22 files changed

+1043
-259
lines changed

runner/cmd/shim/main.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ func mainInner() int {
4040
log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel))
4141
log.DefaultEntry.Logger.SetOutput(os.Stderr)
4242

43+
shimBinaryPath, err := os.Executable()
44+
if err != nil {
45+
shimBinaryPath = consts.ShimBinaryPath
46+
}
47+
4348
cmd := &cli.Command{
4449
Name: "dstack-shim",
4550
Usage: "Starts dstack-runner or docker container.",
@@ -54,6 +59,14 @@ func mainInner() int {
5459
DefaultText: path.Join("~", consts.DstackDirPath),
5560
Sources: cli.EnvVars("DSTACK_SHIM_HOME"),
5661
},
62+
&cli.StringFlag{
63+
Name: "shim-binary-path",
64+
Usage: "Path to shim's binary",
65+
Value: shimBinaryPath,
66+
Destination: &args.Shim.BinaryPath,
67+
TakesFile: true,
68+
Sources: cli.EnvVars("DSTACK_SHIM_BINARY_PATH"),
69+
},
5770
&cli.IntFlag{
5871
Name: "shim-http-port",
5972
Usage: "Set shim's http port",
@@ -172,6 +185,7 @@ func mainInner() int {
172185

173186
func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) {
174187
log.DefaultEntry.Logger.SetLevel(logrus.Level(args.Shim.LogLevel))
188+
log.Info(ctx, "Starting dstack-shim", "version", Version)
175189

176190
shimHomeDir := args.Shim.HomeDir
177191
if shimHomeDir == "" {
@@ -211,6 +225,10 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
211225
} else if runnerErr != nil {
212226
return runnerErr
213227
}
228+
shimManager, shimErr := components.NewShimManager(ctx, args.Shim.BinaryPath)
229+
if shimErr != nil {
230+
return shimErr
231+
}
214232

215233
log.Debug(ctx, "Shim", "args", args.Shim)
216234
log.Debug(ctx, "Runner", "args", args.Runner)
@@ -259,7 +277,11 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
259277
}
260278

261279
address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort)
262-
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager)
280+
shimServer := api.NewShimServer(
281+
ctx, address, Version,
282+
dockerRunner, dcgmExporter, dcgmWrapper,
283+
runnerManager, shimManager,
284+
)
263285

264286
if serviceMode {
265287
if err := shim.WriteHostInfo(shimHomeDir, dockerRunner.Resources(ctx)); err != nil {
@@ -278,6 +300,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
278300
if err := shimServer.Serve(); err != nil {
279301
serveErrCh <- err
280302
}
303+
close(serveErrCh)
281304
}()
282305

283306
select {
@@ -287,7 +310,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
287310

288311
shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
289312
defer cancelShutdown()
290-
shutdownErr := shimServer.Shutdown(shutdownCtx)
313+
shutdownErr := shimServer.Shutdown(shutdownCtx, false)
291314
if serveErr != nil {
292315
return serveErr
293316
}

runner/consts/consts.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ const (
1313
// 2. A default path on the host unless overridden via shim CLI
1414
const RunnerBinaryPath = "/usr/local/bin/dstack-runner"
1515

16+
// A fallback path on the host used if os.Executable() has failed
17+
const ShimBinaryPath = "/usr/local/bin/dstack-shim"
18+
1619
// Error-containing messages will be identified by this signature
1720
const ExecutorFailedSignature = "Executor failed"
1821

runner/docs/shim.openapi.yaml

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ openapi: 3.1.2
22

33
info:
44
title: dstack-shim API
5-
version: v2/0.19.41
5+
version: v2/0.20.1
66
x-logo:
77
url: https://avatars.githubusercontent.com/u/54146142?s=260
88
description: >
@@ -41,7 +41,7 @@ paths:
4141
4242
**Important**: Since this endpoint is used for negotiation, it should always stay
4343
backward/future compatible, specifically the `version` field
44-
44+
tags: [shim]
4545
responses:
4646
"200":
4747
description: ""
@@ -50,6 +50,29 @@ paths:
5050
schema:
5151
$ref: "#/components/schemas/HealthcheckResponse"
5252

53+
/shutdown:
54+
post:
55+
summary: Request shim shutdown
56+
description: |
57+
(since [0.20.1](https://github.com/dstackai/dstack/releases/tag/0.20.1)) Request shim to shut down itself.
58+
Restart must be handled by an external process supervisor, e.g., `systemd`.
59+
60+
**Note**: background jobs (e.g., component installation) are canceled regardless of the `force` option.
61+
tags: [shim]
62+
requestBody:
63+
required: true
64+
content:
65+
application/json:
66+
schema:
67+
$ref: "#/components/schemas/ShutdownRequest"
68+
responses:
69+
"200":
70+
description: Request accepted
71+
$ref: "#/components/responses/PlainTextOk"
72+
"400":
73+
description: Malformed JSON body or validation error
74+
$ref: "#/components/responses/PlainTextBadRequest"
75+
5376
/instance/health:
5477
get:
5578
summary: Get instance health
@@ -66,7 +89,7 @@ paths:
6689
/components:
6790
get:
6891
summary: Get components
69-
description: (since [0.19.41](https://github.com/dstackai/dstack/releases/tag/0.19.41)) Returns a list of software components (e.g., `dstack-runner`)
92+
description: (since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) Returns a list of software components (e.g., `dstack-runner`)
7093
tags: [Components]
7194
responses:
7295
"200":
@@ -80,7 +103,7 @@ paths:
80103
post:
81104
summary: Install component
82105
description: >
83-
(since [0.19.41](https://github.com/dstackai/dstack/releases/tag/0.19.41)) Request installing/updating the software component.
106+
(since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) Request installing/updating the software component.
84107
Components are installed asynchronously
85108
tags: [Components]
86109
requestBody:
@@ -410,6 +433,10 @@ components:
410433
type: string
411434
enum:
412435
- dstack-runner
436+
- dstack-shim
437+
description: |
438+
* (since [0.20.0](https://github.com/dstackai/dstack/releases/tag/0.20.0)) `dstack-runner`
439+
* (since [0.20.1](https://github.com/dstackai/dstack/releases/tag/0.20.1)) `dstack-shim`
413440
414441
ComponentStatus:
415442
title: shim.components.ComponentStatus
@@ -430,7 +457,7 @@ components:
430457
type: string
431458
description: An empty string if status != installed
432459
examples:
433-
- 0.19.41
460+
- 0.20.1
434461
status:
435462
allOf:
436463
- $ref: "#/components/schemas/ComponentStatus"
@@ -457,6 +484,18 @@ components:
457484
- version
458485
additionalProperties: false
459486

487+
ShutdownRequest:
488+
title: shim.api.ShutdownRequest
489+
type: object
490+
properties:
491+
force:
492+
type: boolean
493+
examples:
494+
- false
495+
description: If `true`, don't wait for background job coroutines to complete after canceling them and close HTTP server forcefully.
496+
required:
497+
- force
498+
460499
InstanceHealthResponse:
461500
title: shim.api.InstanceHealthResponse
462501
type: object
@@ -486,7 +525,7 @@ components:
486525
url:
487526
type: string
488527
examples:
489-
- https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/0.19.41/binaries/dstack-runner-linux-amd64
528+
- https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/0.20.1/binaries/dstack-runner-linux-amd64
490529
required:
491530
- name
492531
- url

runner/internal/shim/api/handlers.go

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,21 @@ func (s *ShimServer) HealthcheckHandler(w http.ResponseWriter, r *http.Request)
2222
}, nil
2323
}
2424

25+
func (s *ShimServer) ShutdownHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
26+
var req ShutdownRequest
27+
if err := api.DecodeJSONBody(w, r, &req, true); err != nil {
28+
return nil, err
29+
}
30+
31+
go func() {
32+
if err := s.Shutdown(s.ctx, req.Force); err != nil {
33+
log.Error(s.ctx, "Shutdown", "err", err)
34+
}
35+
}()
36+
37+
return nil, nil
38+
}
39+
2540
func (s *ShimServer) InstanceHealthHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
2641
ctx := r.Context()
2742
response := InstanceHealthResponse{}
@@ -159,9 +174,11 @@ func (s *ShimServer) TaskMetricsHandler(w http.ResponseWriter, r *http.Request)
159174
}
160175

161176
func (s *ShimServer) ComponentListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
162-
runnerStatus := s.runnerManager.GetInfo(r.Context())
163177
response := &ComponentListResponse{
164-
Components: []components.ComponentInfo{runnerStatus},
178+
Components: []components.ComponentInfo{
179+
s.runnerManager.GetInfo(r.Context()),
180+
s.shimManager.GetInfo(r.Context()),
181+
},
165182
}
166183
return response, nil
167184
}
@@ -176,27 +193,31 @@ func (s *ShimServer) ComponentInstallHandler(w http.ResponseWriter, r *http.Requ
176193
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty name"}
177194
}
178195

196+
var componentManager components.ComponentManager
179197
switch components.ComponentName(req.Name) {
180198
case components.ComponentNameRunner:
181-
if req.URL == "" {
182-
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty url"}
183-
}
184-
185-
// There is still a small chance of time-of-check race condition, but we ignore it.
186-
runnerInfo := s.runnerManager.GetInfo(r.Context())
187-
if runnerInfo.Status == components.ComponentStatusInstalling {
188-
return nil, &api.Error{Status: http.StatusConflict, Msg: "already installing"}
189-
}
190-
191-
s.bgJobsGroup.Go(func() {
192-
if err := s.runnerManager.Install(s.bgJobsCtx, req.URL, true); err != nil {
193-
log.Error(s.bgJobsCtx, "runner background install", "err", err)
194-
}
195-
})
196-
199+
componentManager = s.runnerManager
200+
case components.ComponentNameShim:
201+
componentManager = s.shimManager
197202
default:
198203
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "unknown component"}
199204
}
200205

206+
if req.URL == "" {
207+
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty url"}
208+
}
209+
210+
// There is still a small chance of time-of-check race condition, but we ignore it.
211+
componentInfo := componentManager.GetInfo(r.Context())
212+
if componentInfo.Status == components.ComponentStatusInstalling {
213+
return nil, &api.Error{Status: http.StatusConflict, Msg: "already installing"}
214+
}
215+
216+
s.bgJobsGroup.Go(func() {
217+
if err := componentManager.Install(s.bgJobsCtx, req.URL, true); err != nil {
218+
log.Error(s.bgJobsCtx, "component background install", "name", componentInfo.Name, "err", err)
219+
}
220+
})
221+
201222
return nil, nil
202223
}

runner/internal/shim/api/handlers_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func TestHealthcheck(t *testing.T) {
1313
request := httptest.NewRequest("GET", "/api/healthcheck", nil)
1414
responseRecorder := httptest.NewRecorder()
1515

16-
server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil)
16+
server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil)
1717

1818
f := common.JSONResponseHandler(server.HealthcheckHandler)
1919
f(responseRecorder, request)
@@ -30,7 +30,7 @@ func TestHealthcheck(t *testing.T) {
3030
}
3131

3232
func TestTaskSubmit(t *testing.T) {
33-
server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil)
33+
server := NewShimServer(context.Background(), ":12340", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil)
3434
requestBody := `{
3535
"id": "dummy-id",
3636
"name": "dummy-name",

runner/internal/shim/api/schemas.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ type HealthcheckResponse struct {
1111
Version string `json:"version"`
1212
}
1313

14+
type ShutdownRequest struct {
15+
Force bool `json:"force"`
16+
}
17+
1418
type InstanceHealthResponse struct {
1519
DCGM *dcgm.Health `json:"dcgm"`
1620
}

0 commit comments

Comments
 (0)