Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 54 additions & 38 deletions runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,157 +5,169 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"path"
"path/filepath"
"syscall"
"time"

"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v3"

"github.com/dstackai/dstack/runner/consts"
"github.com/dstackai/dstack/runner/internal/common"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/shim"
"github.com/dstackai/dstack/runner/internal/shim/api"
"github.com/dstackai/dstack/runner/internal/shim/components"
"github.com/dstackai/dstack/runner/internal/shim/dcgm"
)

// Version is a build-time variable. The value is overridden by ldflags.
var Version string

func main() {
os.Exit(mainInner())
}

func mainInner() int {
var args shim.CLIArgs
var serviceMode bool

const defaultLogLevel = int(logrus.InfoLevel)

ctx := context.Background()

log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel))
log.DefaultEntry.Logger.SetOutput(os.Stderr)

app := &cli.App{
cmd := &cli.Command{
Name: "dstack-shim",
Usage: "Starts dstack-runner or docker container.",
Version: Version,
Flags: []cli.Flag{
/* Shim Parameters */
&cli.PathFlag{
&cli.StringFlag{
Name: "shim-home",
Usage: "Set shim's home directory",
Destination: &args.Shim.HomeDir,
TakesFile: true,
DefaultText: path.Join("~", consts.DstackDirPath),
EnvVars: []string{"DSTACK_SHIM_HOME"},
Sources: cli.EnvVars("DSTACK_SHIM_HOME"),
},
&cli.IntFlag{
Name: "shim-http-port",
Usage: "Set shim's http port",
Value: 10998,
Destination: &args.Shim.HTTPPort,
EnvVars: []string{"DSTACK_SHIM_HTTP_PORT"},
Sources: cli.EnvVars("DSTACK_SHIM_HTTP_PORT"),
},
&cli.IntFlag{
Name: "shim-log-level",
Usage: "Set shim's log level",
Value: defaultLogLevel,
Destination: &args.Shim.LogLevel,
EnvVars: []string{"DSTACK_SHIM_LOG_LEVEL"},
Sources: cli.EnvVars("DSTACK_SHIM_LOG_LEVEL"),
},
/* Runner Parameters */
&cli.StringFlag{
Name: "runner-download-url",
Usage: "Set runner's download URL",
Destination: &args.Runner.DownloadURL,
EnvVars: []string{"DSTACK_RUNNER_DOWNLOAD_URL"},
Sources: cli.EnvVars("DSTACK_RUNNER_DOWNLOAD_URL"),
},
&cli.PathFlag{
&cli.StringFlag{
Name: "runner-binary-path",
Usage: "Path to runner's binary",
Value: consts.RunnerBinaryPath,
Destination: &args.Runner.BinaryPath,
EnvVars: []string{"DSTACK_RUNNER_BINARY_PATH"},
TakesFile: true,
Sources: cli.EnvVars("DSTACK_RUNNER_BINARY_PATH"),
},
&cli.IntFlag{
Name: "runner-http-port",
Usage: "Set runner's http port",
Value: consts.RunnerHTTPPort,
Destination: &args.Runner.HTTPPort,
EnvVars: []string{"DSTACK_RUNNER_HTTP_PORT"},
Sources: cli.EnvVars("DSTACK_RUNNER_HTTP_PORT"),
},
&cli.IntFlag{
Name: "runner-ssh-port",
Usage: "Set runner's ssh port",
Value: consts.RunnerSSHPort,
Destination: &args.Runner.SSHPort,
EnvVars: []string{"DSTACK_RUNNER_SSH_PORT"},
Sources: cli.EnvVars("DSTACK_RUNNER_SSH_PORT"),
},
&cli.IntFlag{
Name: "runner-log-level",
Usage: "Set runner's log level",
Value: defaultLogLevel,
Destination: &args.Runner.LogLevel,
EnvVars: []string{"DSTACK_RUNNER_LOG_LEVEL"},
Sources: cli.EnvVars("DSTACK_RUNNER_LOG_LEVEL"),
},
/* DCGM Exporter Parameters */
&cli.IntFlag{
Name: "dcgm-exporter-http-port",
Usage: "DCGM Exporter http port",
Value: 10997,
Destination: &args.DCGMExporter.HTTPPort,
EnvVars: []string{"DSTACK_DCGM_EXPORTER_HTTP_PORT"},
Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_HTTP_PORT"),
},
&cli.IntFlag{
Name: "dcgm-exporter-interval",
Usage: "DCGM Exporter collect interval, milliseconds",
Value: 5000,
Destination: &args.DCGMExporter.Interval,
EnvVars: []string{"DSTACK_DCGM_EXPORTER_INTERVAL"},
Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_INTERVAL"),
},
/* DCGM Parameters */
&cli.StringFlag{
Name: "dcgm-address",
Usage: "nv-hostengine `hostname`, e.g., `localhost`",
DefaultText: "start libdcgm in embedded mode",
Destination: &args.DCGM.Address,
EnvVars: []string{"DSTACK_DCGM_ADDRESS"},
Sources: cli.EnvVars("DSTACK_DCGM_ADDRESS"),
},
/* Docker Parameters */
&cli.BoolFlag{
Name: "privileged",
Usage: "Give extended privileges to the container",
Destination: &args.Docker.Privileged,
EnvVars: []string{"DSTACK_DOCKER_PRIVILEGED"},
Sources: cli.EnvVars("DSTACK_DOCKER_PRIVILEGED"),
},
&cli.StringFlag{
Name: "ssh-key",
Usage: "Public SSH key",
Destination: &args.Docker.ConcatinatedPublicSSHKeys,
EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"},
Sources: cli.EnvVars("DSTACK_PUBLIC_SSH_KEY"),
},
&cli.StringFlag{
Name: "pjrt-device",
Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)",
Destination: &args.Docker.PJRTDevice,
EnvVars: []string{"PJRT_DEVICE"},
Sources: cli.EnvVars("PJRT_DEVICE"),
},
/* Misc Parameters */
&cli.BoolFlag{
Name: "service",
Usage: "Start as a service",
Destination: &serviceMode,
EnvVars: []string{"DSTACK_SERVICE_MODE"},
Sources: cli.EnvVars("DSTACK_SERVICE_MODE"),
},
},
Action: func(c *cli.Context) error {
Action: func(ctx context.Context, cmd *cli.Command) error {
return start(ctx, args, serviceMode)
},
}

if err := app.Run(os.Args); err != nil {
log.Fatal(ctx, err.Error())
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

if err := cmd.Run(ctx, os.Args); err != nil {
log.Error(ctx, err.Error())
return 1
}

return 0
}

func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) {
Expand Down Expand Up @@ -191,8 +203,13 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
}
}()

if err := args.DownloadRunner(ctx); err != nil {
return err
runnerManager, runnerErr := components.NewRunnerManager(ctx, args.Runner.BinaryPath)
if args.Runner.DownloadURL != "" {
if err := runnerManager.Install(ctx, args.Runner.DownloadURL, false); err != nil {
return err
}
} else if runnerErr != nil {
return runnerErr
}

log.Debug(ctx, "Shim", "args", args.Shim)
Expand Down Expand Up @@ -242,13 +259,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
}

address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort)
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper)

defer func() {
shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
defer cancelShutdown()
_ = shimServer.HttpServer.Shutdown(shutdownCtx)
}()
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager)

if serviceMode {
if err := shim.WriteHostInfo(shimHomeDir, dockerRunner.Resources(ctx)); err != nil {
Expand All @@ -260,9 +271,14 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
}
}

if err := shimServer.HttpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
go func() {
if err := shimServer.Serve(); err != nil {
log.Error(ctx, "serve", "err", err)
}
}()

return nil
<-ctx.Done()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If shimServer.Serve() errors, this still blocks waiting for the signal. I think shim should exit with non-zero in that case.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. I've added a communication channel for ShimServer.Serve() errors: cf9b9a5

shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
defer cancelShutdown()
return shimServer.Shutdown(shutdownCtx)
}
Loading