From e26837fa85b96c86345bcf1d2d1d88916071fdee Mon Sep 17 00:00:00 2001
From: Dmitry Meyer
Date: Thu, 26 Jun 2025 16:55:01 +0000
Subject: [PATCH] Add `files` property to run configurations

Each item in the `files` property maps a local path (a file or a dir) to a container path.

Each local path is packed into a tar archive and uploaded to the server, similarly to the code repo archive/diff. On the server, each archive is stored in the DB as `FileArchiveModel` linked to the user. Archive blobs are optionally uploaded to the storage (again, similar to the code blob).

When the job is submitted to the runner, files (if any) are uploaded after `/api/submit` but before `/api/upload_code`.

The runner unpacks archives as follows:

* If the path already exists, it's removed
* If any parent dirs of the path are missing, they are created and owned by the run user. The owner of the existing dirs is not changed.
* The owner of the path (and all subpaths in the case of a directory) is set to the run user. The permissions from the archive (and thus from the user's machine) are preserved.

Part-of: https://github.com/dstackai/dstack/issues/2738 --- .../reference/dstack.yml/dev-environment.md | 19 ++- docs/docs/reference/dstack.yml/service.md | 17 ++ docs/docs/reference/dstack.yml/task.md | 17 ++ runner/go.mod | 11 +- runner/go.sum | 8 +- runner/internal/api/common.go | 10 +- runner/internal/executor/base.go | 2 + runner/internal/executor/executor.go | 38 ++++- runner/internal/executor/files.go | 132 ++++++++++++++++ runner/internal/executor/repo.go | 2 +- runner/internal/runner/api/http.go | 58 +++++++ runner/internal/runner/api/server.go | 1 + runner/internal/schemas/schemas.go | 6 + .../_internal/core/compatibility/runs.py | 4 + .../_internal/core/models/configurations.py | 11 ++ src/dstack/_internal/core/models/files.py | 67 ++++++++ src/dstack/_internal/core/models/runs.py | 5 + src/dstack/_internal/server/app.py | 2 + .../background/tasks/process_running_jobs.py | 64 ++++++++ .../5f1707c525d2_add_filearchivemodel.py | 39 +++++ src/dstack/_internal/server/models.py | 15 ++ src/dstack/_internal/server/routers/files.py | 67 ++++++++ src/dstack/_internal/server/schemas/files.py | 5 + src/dstack/_internal/server/schemas/runner.py | 1 + src/dstack/_internal/server/services/files.py | 91 +++++++++++ .../server/services/jobs/__init__.py | 7 +- .../server/services/runner/client.py | 8 + .../_internal/server/services/storage/base.py | 21 +++ .../_internal/server/services/storage/gcs.py | 34 +++- .../_internal/server/services/storage/s3.py | 36 +++-- src/dstack/_internal/server/settings.py | 2 + src/dstack/_internal/server/testing/common.py | 17 ++ src/dstack/_internal/utils/files.py | 69 ++++++++ src/dstack/_internal/utils/path.py | 16 +- src/dstack/api/_public/runs.py | 48 +++++- src/dstack/api/server/__init__.py | 6 + src/dstack/api/server/_files.py | 18 +++ src/tests/_internal/core/models/test_files.py | 29 ++++ .../_internal/server/routers/test_files.py | 148 ++++++++++++++++++ .../_internal/server/routers/test_runs.py | 5 + 
src/tests/_internal/utils/test_path.py | 14 +- 41 files changed, 1120 insertions(+), 50 deletions(-) create mode 100644 runner/internal/executor/files.go create mode 100644 src/dstack/_internal/core/models/files.py create mode 100644 src/dstack/_internal/server/migrations/versions/5f1707c525d2_add_filearchivemodel.py create mode 100644 src/dstack/_internal/server/routers/files.py create mode 100644 src/dstack/_internal/server/schemas/files.py create mode 100644 src/dstack/_internal/server/services/files.py create mode 100644 src/dstack/_internal/utils/files.py create mode 100644 src/dstack/api/server/_files.py create mode 100644 src/tests/_internal/core/models/test_files.py create mode 100644 src/tests/_internal/server/routers/test_files.py diff --git a/docs/docs/reference/dstack.yml/dev-environment.md b/docs/docs/reference/dstack.yml/dev-environment.md index 0c872e1d18..9fb52b4893 100644 --- a/docs/docs/reference/dstack.yml/dev-environment.md +++ b/docs/docs/reference/dstack.yml/dev-environment.md @@ -90,4 +90,21 @@ The `dev-environment` configuration type allows running [dev environments](../.. The short syntax for volumes is a colon-separated string in the form of `source:destination` * `volume-name:/container/path` for network volumes - * `/instance/path:/container/path` for instance volumes + * `/instance/path:/container/path` for instance volumes + +### `files[n]` { #_files data-toc-label="files" } + +#SCHEMA# dstack._internal.core.models.files.FilePathMapping + overrides: + show_root_heading: false + type: + required: true + +??? info "Short syntax" + + The short syntax for files is a colon-separated string in the form of `local_path[:path]` where + `path` is optional and can be omitted if it's equal to `local_path`. 
+ + * `~/.bashrc`, same as `~/.bashrc:~/.bashrc` + * `/opt/myorg`, same as `/opt/myorg/` and `/opt/myorg:/opt/myorg` + * `libs/patched_libibverbs.so.1:/lib/x86_64-linux-gnu/libibverbs.so.1` diff --git a/docs/docs/reference/dstack.yml/service.md b/docs/docs/reference/dstack.yml/service.md index d7cb13acd5..041c9d43fe 100644 --- a/docs/docs/reference/dstack.yml/service.md +++ b/docs/docs/reference/dstack.yml/service.md @@ -185,3 +185,20 @@ The `service` configuration type allows running [services](../../concepts/servic * `volume-name:/container/path` for network volumes * `/instance/path:/container/path` for instance volumes + +### `files[n]` { #_files data-toc-label="files" } + +#SCHEMA# dstack._internal.core.models.files.FilePathMapping + overrides: + show_root_heading: false + type: + required: true + +??? info "Short syntax" + + The short syntax for files is a colon-separated string in the form of `local_path[:path]` where + `path` is optional and can be omitted if it's equal to `local_path`. + + * `~/.bashrc`, same as `~/.bashrc:~/.bashrc` + * `/opt/myorg`, same as `/opt/myorg/` and `/opt/myorg:/opt/myorg` + * `libs/patched_libibverbs.so.1:/lib/x86_64-linux-gnu/libibverbs.so.1` diff --git a/docs/docs/reference/dstack.yml/task.md b/docs/docs/reference/dstack.yml/task.md index 0565dbf6f9..7a2cf791a2 100644 --- a/docs/docs/reference/dstack.yml/task.md +++ b/docs/docs/reference/dstack.yml/task.md @@ -91,3 +91,20 @@ The `task` configuration type allows running [tasks](../../concepts/tasks.md). * `volume-name:/container/path` for network volumes * `/instance/path:/container/path` for instance volumes + +### `files[n]` { #_files data-toc-label="files" } + +#SCHEMA# dstack._internal.core.models.files.FilePathMapping + overrides: + show_root_heading: false + type: + required: true + +??? 
info "Short syntax" + + The short syntax for files is a colon-separated string in the form of `local_path[:path]` where + `path` is optional and can be omitted if it's equal to `local_path`. + + * `~/.bashrc`, same as `~/.bashrc:~/.bashrc` + * `/opt/myorg`, same as `/opt/myorg/` and `/opt/myorg:/opt/myorg` + * `libs/patched_libibverbs.so.1:/lib/x86_64-linux-gnu/libibverbs.so.1` diff --git a/runner/go.mod b/runner/go.mod index 2eeee85393..22dad6466e 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -5,13 +5,16 @@ go 1.23 require ( github.com/alexellis/go-execute/v2 v2.2.1 github.com/bluekeyes/go-gitdiff v0.7.2 + github.com/codeclysm/extract/v4 v4.0.0 github.com/creack/pty v1.1.24 github.com/docker/docker v26.0.0+incompatible github.com/docker/go-connections v0.5.0 github.com/docker/go-units v0.5.0 github.com/go-git/go-git/v5 v5.12.0 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f + github.com/gorilla/websocket v1.5.1 github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf + github.com/prometheus/procfs v0.15.1 github.com/shirou/gopsutil/v4 v4.24.11 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.10.0 @@ -78,12 +81,6 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gotest.tools/v3 v3.5.0 // indirect -) - -require ( - github.com/codeclysm/extract/v3 v3.1.1 - github.com/gorilla/websocket v1.5.1 - github.com/prometheus/procfs v0.15.1 gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools/v3 v3.5.0 // indirect ) diff --git a/runner/go.sum b/runner/go.sum index c3cdaf43f8..41e133c465 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -13,8 +13,8 @@ github.com/alexellis/go-execute/v2 v2.2.1 h1:4Ye3jiCKQarstODOEmqDSRCqxMHLkC92Bhs github.com/alexellis/go-execute/v2 v2.2.1/go.mod h1:FMdRnUTiFAmYXcv23txrp3VYZfLo24nMpiIneWgKHTQ= github.com/anmitsu/go-shlex 
v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= -github.com/arduino/go-paths-helper v1.2.0 h1:qDW93PR5IZUN/jzO4rCtexiwF8P4OIcOmcSgAYLZfY4= -github.com/arduino/go-paths-helper v1.2.0/go.mod h1:HpxtKph+g238EJHq4geEPv9p+gl3v5YYu35Yb+w31Ck= +github.com/arduino/go-paths-helper v1.12.1 h1:WkxiVUxBjKWlLMiMuYy8DcmVrkxdP7aKxQOAq7r2lVM= +github.com/arduino/go-paths-helper v1.12.1/go.mod h1:jcpW4wr0u69GlXhTYydsdsqAjLaYK5n7oWHfKqOG6LM= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/bluekeyes/go-gitdiff v0.7.2 h1:42jrcVZdjjxXtVsFNYTo/I6T1ZvIiQL+iDDLiH904hw= @@ -26,8 +26,8 @@ github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyY github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= -github.com/codeclysm/extract/v3 v3.1.1 h1:iHZtdEAwSTqPrd+1n4jfhr1qBhUWtHlMTjT90+fJVXg= -github.com/codeclysm/extract/v3 v3.1.1/go.mod h1:ZJi80UG2JtfHqJI+lgJSCACttZi++dHxfWuPaMhlOfQ= +github.com/codeclysm/extract/v4 v4.0.0 h1:H87LFsUNaJTu2e/8p/oiuiUsOK/TaPQ5wxsjPnwPEIY= +github.com/codeclysm/extract/v4 v4.0.0/go.mod h1:SFju1lj6as7FvUgalpSct7torJE0zttbJUWtryPRG6s= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= diff --git a/runner/internal/api/common.go b/runner/internal/api/common.go index fcb5f40e6d..ae15369d73 100644 --- 
a/runner/internal/api/common.go +++ b/runner/internal/api/common.go @@ -105,18 +105,18 @@ func DecodeJSONBody(w http.ResponseWriter, r *http.Request, dst interface{}, all func JSONResponseHandler(handler func(http.ResponseWriter, *http.Request) (interface{}, error)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { status := 200 - msg := "" + errMsg := "" var apiErr *Error body, err := handler(w, r) if err != nil { if errors.As(err, &apiErr) { status = apiErr.Status - msg = apiErr.Error() - log.Warning(r.Context(), "API error", "err", apiErr.Err) + errMsg = apiErr.Error() + log.Warning(r.Context(), "API error", "err", errMsg, "status", status) } else { status = http.StatusInternalServerError - log.Error(r.Context(), "Unexpected API error", "err", err) + log.Error(r.Context(), "Unexpected API error", "err", err, "status", status) } } @@ -125,7 +125,7 @@ func JSONResponseHandler(handler func(http.ResponseWriter, *http.Request) (inter w.WriteHeader(status) _ = json.NewEncoder(w).Encode(body) } else { - http.Error(w, msg, status) + http.Error(w, errMsg, status) } log.Debug(r.Context(), "", "method", r.Method, "endpoint", r.URL.Path, "status", status) diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go index a8e1505520..2163ca9204 100644 --- a/runner/internal/executor/base.go +++ b/runner/internal/executor/base.go @@ -2,6 +2,7 @@ package executor import ( "context" + "io" "github.com/dstackai/dstack/runner/internal/schemas" "github.com/dstackai/dstack/runner/internal/types" @@ -22,6 +23,7 @@ type Executor interface { termination_message string, ) SetRunnerState(state string) + AddFileArchive(id string, src io.Reader) error Lock() RLock() RUnlock() diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 9fad1f8bf8..14f46d035c 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -31,8 +31,9 @@ type RunExecutor struct 
{ tempDir string homeDir string workingDir string + archiveDir string sshPort int - uid uint32 + currentUid uint32 run schemas.Run jobSpec schemas.JobSpec @@ -77,8 +78,9 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort i tempDir: tempDir, homeDir: homeDir, workingDir: workingDir, + archiveDir: filepath.Join(tempDir, "file_archives"), sshPort: sshPort, - uid: uid, + currentUid: uid, mu: mu, state: WaitSubmit, @@ -131,6 +133,28 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { ctx = log.WithLogger(ctx, log.NewEntry(logger, int(log.DefaultEntry.Logger.Level))) // todo loglevel log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String()) + if ex.jobSpec.User != nil { + if err := fillUser(ex.jobSpec.User); err != nil { + ex.SetJobStateWithTerminationReason( + ctx, + types.JobStateFailed, + types.TerminationReasonExecutorError, + fmt.Sprintf("Failed to fill in the job user fields (%s)", err), + ) + return gerrors.Wrap(err) + } + } + + if err := ex.setupFiles(ctx); err != nil { + ex.SetJobStateWithTerminationReason( + ctx, + types.JobStateFailed, + types.TerminationReasonExecutorError, + fmt.Sprintf("Failed to set up files (%s)", err), + ) + return gerrors.Wrap(err) + } + if err := ex.setupRepo(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, @@ -140,6 +164,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { ) return gerrors.Wrap(err) } + cleanupCredentials, err := ex.setupCredentials(ctx) if err != nil { ex.SetJobState(ctx, types.JobStateFailed) @@ -300,16 +325,13 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error user := ex.jobSpec.User if user != nil { - if err := fillUser(user); err != nil { - return gerrors.Wrap(err) - } log.Trace( ctx, "Using credentials", "uid", *user.Uid, "gid", *user.Gid, "groups", user.GroupIds, "username", user.GetUsername(), "groupname", user.GetGroupname(), "home", user.HomeDir, ) - log.Trace(ctx, "Current user", 
"uid", ex.uid) + log.Trace(ctx, "Current user", "uid", ex.currentUid) // 1. Ideally, We should check uid, gid, and supplementary groups mismatches, // but, for the sake of simplicity, we only check uid. Unprivileged runner @@ -318,8 +340,8 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error // 2. Strictly speaking, we need CAP_SETUID and CAP_GUID (for Cmd.Start()-> // Cmd.SysProcAttr.Credential) and CAP_CHOWN (for startCommand()->os.Chown()), // but for the sake of simplicity we instead check if we are root or not - if *user.Uid != ex.uid && ex.uid != 0 { - return gerrors.Newf("cannot start job as %d, current uid is %d", *user.Uid, ex.uid) + if *user.Uid != ex.currentUid && ex.currentUid != 0 { + return gerrors.Newf("cannot start job as %d, current uid is %d", *user.Uid, ex.currentUid) } if cmd.SysProcAttr == nil { diff --git a/runner/internal/executor/files.go b/runner/internal/executor/files.go new file mode 100644 index 0000000000..89de206338 --- /dev/null +++ b/runner/internal/executor/files.go @@ -0,0 +1,132 @@ +package executor + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path" + "regexp" + "slices" + "strings" + + "github.com/codeclysm/extract/v4" + "github.com/dstackai/dstack/runner/internal/gerrors" + "github.com/dstackai/dstack/runner/internal/log" +) + +var renameRegex = regexp.MustCompile(`^([^/]*)(/|$)`) + +func (ex *RunExecutor) AddFileArchive(id string, src io.Reader) error { + if err := os.MkdirAll(ex.archiveDir, 0o755); err != nil { + return gerrors.Wrap(err) + } + archivePath := path.Join(ex.archiveDir, id) + archive, err := os.Create(archivePath) + if err != nil { + return gerrors.Wrap(err) + } + defer func() { _ = archive.Close() }() + if _, err = io.Copy(archive, src); err != nil { + return gerrors.Wrap(err) + } + return nil +} + +// setupFiles must be called from Run +func (ex *RunExecutor) setupFiles(ctx context.Context) error { + homeDir := ex.workingDir + uid := -1 + gid := -1 + if 
ex.jobSpec.User != nil { + if ex.jobSpec.User.HomeDir != "" { + homeDir = ex.jobSpec.User.HomeDir + } + if ex.jobSpec.User.Uid != nil { + uid = int(*ex.jobSpec.User.Uid) + } + if ex.jobSpec.User.Gid != nil { + gid = int(*ex.jobSpec.User.Gid) + } + } + + for _, fa := range ex.run.RunSpec.FileArchives { + log.Trace(ctx, "Extracting file archive", "id", fa.Id, "path", fa.Path) + + p := path.Clean(fa.Path) + // `~username[/path/to]` is not supported + if p == "~" { + p = homeDir + } else if rest, found := strings.CutPrefix(p, "~/"); found { + p = path.Join(homeDir, rest) + } else if !path.IsAbs(p) { + p = path.Join(ex.workingDir, p) + } + dir, root := path.Split(p) + if err := mkdirAll(ctx, dir, uid, gid); err != nil { + return gerrors.Wrap(err) + } + + if err := os.RemoveAll(p); err != nil { + log.Warning(ctx, "Failed to remove", "path", p, "err", err) + } + + archivePath := path.Join(ex.archiveDir, fa.Id) + archive, err := os.Open(archivePath) + if err != nil { + return gerrors.Wrap(err) + } + defer func() { + _ = archive.Close() + if err := os.Remove(archivePath); err != nil { + log.Warning(ctx, "Failed to remove archive", "path", archivePath, "err", err) + } + }() + + var paths []string + repl := fmt.Sprintf("%s$2", root) + renameAndRemember := func(s string) string { + s = renameRegex.ReplaceAllString(s, repl) + paths = append(paths, s) + return s + } + if err := extract.Tar(ctx, archive, dir, renameAndRemember); err != nil { + return gerrors.Wrap(err) + } + + if uid != -1 || gid != -1 { + for _, p := range paths { + if err := os.Chown(path.Join(dir, p), uid, gid); err != nil { + log.Warning(ctx, "Failed to chown", "path", p, "err", err) + } + } + } + } + + return nil +} + +func mkdirAll(ctx context.Context, p string, uid int, gid int) error { + var paths []string + for { + p = path.Dir(p) + if p == "/" { + break + } + paths = append(paths, p) + } + for _, p := range slices.Backward(paths) { + if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) { + if err := 
os.Mkdir(p, 0o755); err != nil { + return err + } + if err := os.Chown(p, uid, gid); err != nil { + log.Warning(ctx, "Failed to chown", "path", p, "err", err) + } + } else if err != nil { + return err + } + } + return nil +} diff --git a/runner/internal/executor/repo.go b/runner/internal/executor/repo.go index bb4121d726..5afbb2515a 100644 --- a/runner/internal/executor/repo.go +++ b/runner/internal/executor/repo.go @@ -7,7 +7,7 @@ import ( "os/exec" "path/filepath" - "github.com/codeclysm/extract/v3" + "github.com/codeclysm/extract/v4" "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/repo" diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 36c63527fb..f323de86ed 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -2,8 +2,12 @@ package api import ( "context" + "errors" + "fmt" "io" "math" + "mime" + "mime/multipart" "net/http" "os" "path/filepath" @@ -58,6 +62,60 @@ func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (inte return nil, nil } +// uploadArchivePostHandler may be called 0 or more times, and must be called after submitPostHandler +// and before uploadCodePostHandler +func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { + s.executor.Lock() + defer s.executor.Unlock() + if s.executor.GetRunnerState() != executor.WaitCode { + return nil, &api.Error{Status: http.StatusConflict} + } + + contentType := r.Header.Get("Content-Type") + if contentType == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing content-type header"} + } + mediaType, params, err := mime.ParseMediaType(contentType) + if err != nil { + return nil, gerrors.Wrap(err) + } + if mediaType != "multipart/form-data" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: fmt.Sprintf("multipart/form-data expected, got 
%s", mediaType)} + } + boundary := params["boundary"] + if boundary == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing boundary"} + } + + formReader := multipart.NewReader(r.Body, boundary) + part, err := formReader.NextPart() + if errors.Is(err, io.EOF) { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty form"} + } + if err != nil { + return nil, gerrors.Wrap(err) + } + fieldName := part.FormName() + if fieldName == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing field name"} + } + if fieldName != "archive" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: fmt.Sprintf("unexpected field %s", fieldName)} + } + archiveId := part.FileName() + if archiveId == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing file name"} + } + if err := s.executor.AddFileArchive(archiveId, part); err != nil { + return nil, gerrors.Wrap(err) + } + if _, err := formReader.NextPart(); !errors.Is(err, io.EOF) { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "extra form field(s)"} + } + + return nil, nil +} + func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.executor.Lock() defer s.executor.Unlock() diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 732d5ed50f..a204c1c468 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -61,6 +61,7 @@ func NewServer(tempDir string, homeDir string, workingDir string, address string r.AddHandler("GET", "/api/healthcheck", s.healthcheckGetHandler) r.AddHandler("GET", "/api/metrics", s.metricsGetHandler) r.AddHandler("POST", "/api/submit", s.submitPostHandler) + r.AddHandler("POST", "/api/upload_archive", s.uploadArchivePostHandler) r.AddHandler("POST", "/api/upload_code", s.uploadCodePostHandler) r.AddHandler("POST", "/api/run", s.runPostHandler) r.AddHandler("GET", "/api/pull", s.pullGetHandler) 
diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index 4391874007..8d4cef2f1c 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -47,6 +47,7 @@ type RunSpec struct { RunName string `json:"run_name"` RepoId string `json:"repo_id"` RepoData RepoData `json:"repo_data"` + FileArchives []FileArchive `json:"file_archives"` Configuration Configuration `json:"configuration"` ConfigurationPath string `json:"configuration_path"` } @@ -96,6 +97,11 @@ type RepoData struct { RepoConfigEmail string `json:"repo_config_email"` } +type FileArchive struct { + Id string `json:"id"` + Path string `json:"path"` +} + type Configuration struct { Type string `json:"type"` } diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 97f90c8d2e..696cd7f025 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -109,6 +109,10 @@ def get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]: configuration_excludes["stop_criteria"] = True if profile is not None and profile.stop_criteria is None: profile_excludes.add("stop_criteria") + if not configuration.files: + configuration_excludes["files"] = True + if not run_spec.file_archives: + spec_excludes["file_archives"] = True if configuration_excludes: spec_excludes["configuration"] = configuration_excludes diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 92ae999ba0..97be403ca1 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -10,6 +10,7 @@ from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth from dstack._internal.core.models.envs import Env +from dstack._internal.core.models.files import FilePathMapping from 
dstack._internal.core.models.fleets import FleetConfiguration from dstack._internal.core.models.gateways import GatewayConfiguration from dstack._internal.core.models.profiles import ProfileParams, parse_off_duration @@ -252,6 +253,10 @@ class BaseRunConfiguration(CoreModel): description="Use Docker inside the container. Mutually exclusive with `image`, `python`, and `nvcc`. Overrides `privileged`" ), ] = None + files: Annotated[ + list[Union[FilePathMapping, str]], + Field(description="The local to container file path mappings"), + ] = [] # deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init` setup: CommandsList = [] @@ -285,6 +290,12 @@ def convert_volumes(cls, v) -> MountPoint: return parse_mount_point(v) return v + @validator("files", each_item=True) + def convert_files(cls, v) -> FilePathMapping: + if isinstance(v, str): + return FilePathMapping.parse(v) + return v + @validator("user") def validate_user(cls, v) -> Optional[str]: if v is None: diff --git a/src/dstack/_internal/core/models/files.py b/src/dstack/_internal/core/models/files.py new file mode 100644 index 0000000000..f2e4f6826d --- /dev/null +++ b/src/dstack/_internal/core/models/files.py @@ -0,0 +1,67 @@ +import pathlib +import string +from uuid import UUID + +from pydantic import Field, validator +from typing_extensions import Annotated, Self + +from dstack._internal.core.models.common import CoreModel + + +class FileArchive(CoreModel): + id: UUID + hash: str + + +class FilePathMapping(CoreModel): + local_path: Annotated[ + str, + Field( + description=( + "The path on the user's machine. Relative paths are resolved relative to" + " the parent directory of the the configuration file" + ) + ), + ] + path: Annotated[ + str, + Field( + description=( + "The path in the container. 
Relative paths are resolved relative to" + " the repo directory (`/workflow`)" + ) + ), + ] + + @classmethod + def parse(cls, v: str) -> Self: + local_path: str + path: str + parts = v.split(":") + # A special case for Windows paths, e.g., `C:\path\to`, 'c:/path/to' + if ( + len(parts) > 1 + and len(parts[0]) == 1 + and parts[0] in string.ascii_letters + and parts[1][:1] in ["\\", "/"] + ): + parts = [f"{parts[0]}:{parts[1]}", *parts[2:]] + if len(parts) == 1: + local_path = path = parts[0] + elif len(parts) == 2: + local_path, path = parts + else: + raise ValueError(f"invalid file path mapping: {v}") + return cls(local_path=local_path, path=path) + + @validator("path") + def validate_path(cls, v) -> str: + # True for `C:/.*`, False otherwise, including `/abs/unix/path`, `rel\windows\path`, etc. + if pathlib.PureWindowsPath(v).is_absolute(): + raise ValueError(f"path must be a Unix file path: {v}") + return v + + +class FileArchiveMapping(CoreModel): + id: Annotated[UUID, Field(description="The File archive ID")] + path: Annotated[str, Field(description="The path in the container")] diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 860aa5238e..df53247338 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -12,6 +12,7 @@ AnyRunConfiguration, RunConfiguration, ) +from dstack._internal.core.models.files import FileArchiveMapping from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, InstanceType, @@ -413,6 +414,10 @@ class RunSpec(CoreModel): Optional[str], Field(description="The hash of the repo diff. 
Can be omitted if there is no repo diff."), ] = None + file_archives: Annotated[ + list[FileArchiveMapping], + Field(description="The list of file archive ID to container path mappings"), + ] = [] working_dir: Annotated[ Optional[str], Field( diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 6377e4c261..6a76dadf48 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -23,6 +23,7 @@ from dstack._internal.server.db import get_db, get_session_ctx, migrate from dstack._internal.server.routers import ( backends, + files, fleets, gateways, instances, @@ -197,6 +198,7 @@ def register_routes(app: FastAPI, ui: bool = True): app.include_router(service_proxy.router, prefix="/proxy/services", tags=["service-proxy"]) app.include_router(model_proxy.router, prefix="/proxy/models", tags=["model-proxy"]) app.include_router(prometheus.router) + app.include_router(files.router) @app.exception_handler(ForbiddenError) async def forbidden_error_handler(request: Request, exc: ForbiddenError): diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index f287d11128..ef07becdfd 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -1,5 +1,6 @@ import asyncio import re +import uuid from collections.abc import Iterable from datetime import timedelta, timezone from typing import Dict, List, Optional @@ -14,6 +15,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode, RegistryAuth from dstack._internal.core.models.configurations import DevEnvironmentConfiguration +from dstack._internal.core.models.files import FileArchiveMapping from dstack._internal.core.models.instances import ( InstanceStatus, RemoteConnectionInfo, @@ -42,8 +44,10 @@ 
ProjectModel, RepoModel, RunModel, + UserModel, ) from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus +from dstack._internal.server.services import files as files_services from dstack._internal.server.services import logs as logs_services from dstack._internal.server.services import services from dstack._internal.server.services.instances import get_instance_ssh_private_keys @@ -226,12 +230,20 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): fmt(job_model), job_submission.age, ) + # FIXME: downloading file archives and code here is a waste of time if + # the runner is not ready yet + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=run.run_spec.file_archives, + user=run_model.user, + ) code = await _get_job_code( session=session, project=project, repo=repo_model, code_hash=run.run_spec.repo_code_hash, ) + success = await common_utils.run_async( _submit_job_to_runner, server_ssh_private_keys, @@ -242,6 +254,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job, cluster_info, code, + file_archives, secrets, repo_creds, success_if_not_available=False, @@ -269,6 +282,13 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): logger.debug( "%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age ) + # FIXME: downloading file archives and code here is a waste of time if + # the runner is not ready yet + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=run.run_spec.file_archives, + user=run_model.user, + ) code = await _get_job_code( session=session, project=project, @@ -285,6 +305,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job, cluster_info, code, + file_archives, secrets, repo_creds, server_ssh_private_keys, @@ -588,6 +609,7 @@ def _process_pulling_with_shim( job: Job, cluster_info: ClusterInfo, code: bytes, + file_archives: 
async def _get_job_file_archive(
    session: AsyncSession, archive_id: uuid.UUID, user: UserModel
) -> bytes:
    """
    Loads the blob of a single file archive that belongs to the user.

    The blob is taken from the database row if it is stored inline; otherwise
    it is fetched from the default storage by the owner ID and the blob hash.

    Args:
        session: The database session.
        archive_id: The ID of the archive to load.
        user: The user the archive must belong to.

    Returns:
        The archive blob, or `b""` if the archive cannot be found or loaded.
        Every such failure is logged, since an empty blob is otherwise
        indistinguishable from a successfully loaded one for the caller.
    """
    archive_model = await files_services.get_archive_model(session, id=archive_id, user=user)
    if archive_model is None:
        logger.error("File archive %s not found for user %s", archive_id, user.id)
        return b""
    if archive_model.blob is not None:
        return archive_model.blob
    # No inline blob means the archive was uploaded to the storage
    storage = get_default_storage()
    if storage is None:
        logger.error(
            "File archive %s has no inline blob and no default storage is configured",
            archive_id,
        )
        return b""
    blob = await common_utils.run_async(
        storage.get_archive,
        str(archive_model.user_id),
        archive_model.blob_hash,
    )
    if blob is None:
        logger.error("Failed to get file archive %s from storage", archive_id)
        return b""
    return blob
def upgrade() -> None:
    """Create the `file_archives` table."""
    op.create_table(
        "file_archives",
        sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False),
        sa.Column("user_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False),
        sa.Column("blob_hash", sa.Text(), nullable=False),
        # The only nullable column: a row may exist without an inline blob
        sa.Column("blob", sa.LargeBinary(), nullable=True),
        # Rows are deleted together with the referenced `users` row
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["users.id"],
            name=op.f("fk_file_archives_user_id_users"),
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_file_archives")),
        # At most one row per (user, blob hash) pair
        sa.UniqueConstraint("user_id", "blob_hash", name="uq_file_archives_user_id_blob_hash"),
    )
class FileArchiveModel(BaseModel):
    """
    An archive of user files referenced by run configurations.

    Archives are linked to a user; the (user, blob hash) pair is unique.
    """

    __tablename__ = "file_archives"
    __table_args__ = (
        UniqueConstraint("user_id", "blob_hash", name="uq_file_archives_user_id_blob_hash"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUIDType(binary=False), primary_key=True, default=uuid.uuid4
    )
    # The column holds the owner's ID, not the related object (see `user` for that).
    # The column type is spelled out to match the `file_archives` migration.
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUIDType(binary=False), ForeignKey("users.id", ondelete="CASCADE")
    )
    user: Mapped["UserModel"] = relationship()
    blob_hash: Mapped[str] = mapped_column(Text)
    blob: Mapped[Optional[bytes]] = mapped_column(LargeBinary)  # None means blob is stored on s3
@router.post("/upload_archive")
async def upload_archive(
    request: Request,
    file: UploadFile,
    session: Annotated[AsyncSession, Depends(get_session)],
    user: Annotated[UserModel, Depends(Authenticated())],
) -> FileArchive:
    """
    Accepts a file archive from the authenticated user and stores it.

    The request is rejected with `ServerClientError` if its size exceeds
    `SERVER_CODE_UPLOAD_LIMIT`; a non-positive limit disables the check.
    """
    request_size = get_request_size(request)
    within_limit = SERVER_CODE_UPLOAD_LIMIT <= 0 or request_size <= SERVER_CODE_UPLOAD_LIMIT
    if within_limit:
        return await files.upload_archive(
            session=session,
            user=user,
            file=file,
        )

    diff_size_fmt = sizeof_fmt(request_size)
    limit_fmt = sizeof_fmt(SERVER_CODE_UPLOAD_LIMIT)
    if diff_size_fmt == limit_fmt:
        # Both values round to the same human-readable string,
        # fall back to exact byte counts so that the message makes sense
        diff_size_fmt = f"{request_size}B"
        limit_fmt = f"{SERVER_CODE_UPLOAD_LIMIT}B"
    raise ServerClientError(
        f"Archive size is {diff_size_fmt}, which exceeds the limit of {limit_fmt}."
        " Use .gitignore/.dstackignore to exclude large files."
        " This limit can be modified by setting the DSTACK_SERVER_CODE_UPLOAD_LIMIT environment variable."
    )
async def upload_archive(
    session: AsyncSession,
    user: UserModel,
    file: UploadFile,
) -> FileArchive:
    """
    Stores the uploaded archive for the user, using the file name as the archive hash.

    If the user already has an archive with this hash, it is returned as is and
    the uploaded content is not read. Otherwise the content goes to the default
    storage if one is configured, or inline into the database row if not.

    Raises:
        ServerClientError: If the uploaded file has no file name.
    """
    archive_hash = file.filename
    if archive_hash is None:
        raise ServerClientError("filename not specified")

    existing_model = await get_archive_model_by_hash(
        session=session,
        user=user,
        hash=archive_hash,
    )
    if existing_model is not None:
        logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash)
        return archive_model_to_archive(existing_model)

    content = await file.read()
    storage = get_default_storage()
    if storage is None:
        inline_blob = content
    else:
        await run_async(storage.upload_archive, str(user.id), archive_hash, content)
        inline_blob = None

    new_model = FileArchiveModel(
        user_id=user.id,
        blob_hash=archive_hash,
        blob=inline_blob,
    )
    session.add(new_model)
    await session.commit()
    logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash)
    return archive_model_to_archive(new_model)
def upload_archive(self, id: uuid.UUID, file: Union[BinaryIO, bytes]):
    """
    Uploads a file archive to the runner.

    The archive is sent as the `archive` multipart field with the archive ID
    as the file name. Raises if the runner responds with an error status.
    """
    url = self._url("/api/upload_archive")
    multipart = {"archive": (str(id), file)}
    resp = requests.post(url, files=multipart, timeout=UPLOAD_CODE_REQUEST_TIMEOUT)
    resp.raise_for_status()
def _get(self, key: str) -> Optional[bytes]:
    """
    Downloads the object stored under the key.

    Args:
        key: The object name in the bucket.

    Returns:
        The object content, or `None` if there is no such object.
    """
    blob = self._bucket.blob(key)
    try:
        # The download is the call that can fail with `NotFound`, so it must be
        # inside the `try` for a missing object to be reported as `None`
        return blob.download_as_bytes()
    except NotFound:
        return None
def get_archive(
    self,
    user_id: str,
    archive_hash: str,
) -> Optional[bytes]:
    """Returns whatever `_get` yields for the key built from the user ID and the archive hash."""
    return self._get(self._get_archive_key(user_id, archive_hash))
def create_file_archive(root: PathLike, fp: BinaryIO) -> str:
    """
    Packs the directory or file to a tar archive and writes it to the file-like object.

    Archives can be used to transfer file(s) (e.g., over the network) preserving
    file properties such as permissions, timestamps, etc.

    NOTE: `.gitignore` and `.dstackignore` are respected.

    Args:
        root: The absolute path to the directory or file.
        fp: The binary file-like object.

    Returns:
        The SHA-256 hash of the archive as a hex string.

    Raises:
        ValueError: If the path is not absolute.
        OSError: Underlying errors from the tarfile module
    """
    # Imported here to keep the block self-contained
    import os

    root = Path(root)
    if not root.is_absolute():
        raise ValueError(f"path must be absolute: {root}")
    walk = (
        ignore.WalkBuilder(root)
        .overrides(ignore.overrides.OverrideBuilder(root).add("!/.git/").build())
        .hidden(False)  # do not ignore files that start with a dot
        .require_git(False)  # respect git ignore rules even if not a git repo
        .add_custom_ignore_filename(".dstackignore")
        .build()
    )
    # sort paths to ensure archive reproducibility
    paths = sorted(entry.path() for entry in walk)
    with tarfile.TarFile(mode="w", fileobj=fp) as t:
        for path in paths:
            # Member names are relative to the parent of the root,
            # so the name of the root itself is the first component
            arcname = str(path.relative_to(root.parent))
            info = t.gettarinfo(path, arcname)
            if info.issym():
                # Symlinks are handled as follows: each symlink in the chain is checked, and
                # * if the target is inside the root: keep relative links as is, replace absolute
                #   links with relative ones;
                # * if the target is outside the root: replace the link with the actual file.
                target = Path(info.linkname)
                if not target.is_absolute():
                    target = path.parent / target
                target = normalize_path(target)
                try:
                    target.relative_to(root)
                except ValueError:
                    # Adding as a file
                    t.add(path.resolve(), arcname, recursive=False)
                else:
                    # Adding as a relative symlink.
                    # `PurePath.relative_to(..., walk_up=True)` is only available since
                    # Python 3.12 and fails with TypeError on older interpreters;
                    # `os.path.relpath()` computes the same `..`-based relative path
                    # and works on any version.
                    info.linkname = os.path.relpath(target, path.parent)
                    t.addfile(info)
            else:
                t.add(path, arcname, recursive=False)
    return get_sha256(fp)
dstack._internal.core.models.profiles import ( CreationPolicy, Profile, @@ -42,6 +43,7 @@ from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.utils.common import get_or_error, make_proxy_url +from dstack._internal.utils.files import create_file_archive from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike, path_in_dir from dstack.api.server import APIClient @@ -476,20 +478,38 @@ def apply_plan( # TODO handle multiple jobs ports_lock = _reserve_ports(run_plan.job_plans[0].job_spec) + run_spec = run_plan.run_spec + configuration = run_spec.configuration + + self._validate_configuration_files(configuration, run_spec.configuration_path) + for file_mapping in configuration.files: + assert isinstance(file_mapping, FilePathMapping) + with tempfile.TemporaryFile("w+b") as fp: + try: + archive_hash = create_file_archive(file_mapping.local_path, fp) + except OSError as e: + raise ClientError(f"failed to archive '{file_mapping.local_path}': {e}") from e + fp.seek(0) + archive = self._api_client.files.upload_archive(hash=archive_hash, fp=fp) + run_spec.file_archives.append( + FileArchiveMapping(id=archive.id, path=file_mapping.path) + ) + if repo is None: repo = VirtualRepo() else: # Do not upload the diff without a repo (a default virtual repo) # since upload_code() requires a repo to be initialized. 
def _validate_configuration_files(
    self, configuration: AnyRunConfiguration, configuration_path: Optional[PathLike]
) -> None:
    """
    Expands and validates local paths from the `files` configuration property,
    writing the resulting absolute paths back to the file mappings.

    Relative paths are resolved against the directory of the configuration file.

    Raises:
        ConfigurationError: If a path is relative and `configuration_path` is not
            provided, or if a path does not exist.
    """
    base_dir: Optional[Path] = (
        None
        if configuration_path is None
        else Path(configuration_path).expanduser().resolve().parent
    )
    for file_mapping in configuration.files:
        assert isinstance(file_mapping, FilePathMapping)
        path = Path(file_mapping.local_path).expanduser()
        if not path.is_absolute():
            if base_dir is None:
                raise ConfigurationError(
                    f"Path '{path}' is relative but `configuration_path` is not provided"
                )
            path = base_dir / path
        if not path.exists():
            raise ConfigurationError(f"Path '{path}' specified in `files` does not exist")
        file_mapping.local_path = str(path)
class FilesAPIClient(APIClientGroup):
    """Client for the `/api/files` server endpoints."""

    def get_archive_by_hash(self, hash: str) -> FileArchive:
        """Requests the archive with the given hash and parses the response."""
        payload = GetFileArchiveByHashRequest(hash=hash).json()
        resp = self._request("/api/files/get_archive_by_hash", body=payload)
        return parse_obj_as(FileArchive.__response__, resp.json())

    def upload_archive(self, hash: str, fp: BinaryIO) -> FileArchive:
        """Uploads the archive as the `file` field, with the hash as the file name."""
        upload = {"file": (hash, fp)}
        resp = self._request("/api/files/upload_archive", files=upload)
        return parse_obj_as(FileArchive.__response__, resp.json())
def test_parse_both_paths(self):
    """A `local:container` string is split into the two paths at the colon."""
    expected = FilePathMapping(local_path="./foo", path="./bar")
    parsed = FilePathMapping.parse("./foo:./bar")
    assert parsed == expected
async def test_returns_archive(self, session: AsyncSession, client: AsyncClient):
    """An existing archive of the user is returned with its ID and hash."""
    user = await create_user(session=session, global_role=GlobalRole.USER)
    archive = await create_file_archive(
        session=session, user_id=user.id, blob_hash="blob_hash", blob=b"blob_content"
    )
    request_body = {"hash": archive.blob_hash}
    response = await client.post(
        "/api/files/get_archive_by_hash",
        headers=get_auth_headers(user.token),
        json=request_body,
    )
    assert response.status_code == 200, response.json()
    expected = {
        "id": str(archive.id),
        "hash": archive.blob_hash,
    }
    assert response.json() == expected
@pytest.mark.usefixtures("no_default_storage")
async def test_uploads_archive_to_db(self, session: AsyncSession, client: AsyncClient):
    """Without a default storage the uploaded content is kept in the database row."""
    user = await create_user(session=session, global_role=GlobalRole.USER)
    response = await client.post(
        "/api/files/upload_archive",
        headers=get_auth_headers(user.token),
        files={"file": self.file},
    )
    assert response.status_code == 200, response.json()
    assert response.json()["hash"] == self.file_hash
    user_archives = select(FileArchiveModel).where(FileArchiveModel.user_id == user.id)
    archive = (await session.execute(user_archives)).scalar_one()
    assert archive.blob_hash == self.file_hash
    assert archive.blob == self.file_content
str(user.id), self.file_hash, self.file_content + ) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index c2e50794a6..880c3be5df 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -153,6 +153,7 @@ def get_dev_env_run_plan_dict( "shm_size": None, }, "volumes": [json.loads(v.json()) for v in volumes], + "files": [], "backends": ["local", "aws", "azure", "gcp", "lambda", "runpod"], "regions": ["us"], "availability_zones": None, @@ -174,6 +175,7 @@ def get_dev_env_run_plan_dict( "priority": 0, }, "configuration_path": "dstack.yaml", + "file_archives": [], "profile": { "backends": ["local", "aws", "azure", "gcp", "lambda", "runpod"], "regions": ["us"], @@ -348,6 +350,7 @@ def get_dev_env_run_dict( "shm_size": None, }, "volumes": [], + "files": [], "backends": ["local", "aws", "azure", "gcp", "lambda"], "regions": ["us"], "availability_zones": None, @@ -369,6 +372,7 @@ def get_dev_env_run_dict( "priority": 0, }, "configuration_path": "dstack.yaml", + "file_archives": [], "profile": { "backends": ["local", "aws", "azure", "gcp", "lambda"], "regions": ["us"], @@ -494,6 +498,7 @@ def get_service_run_spec( "model": "test-model", }, "configuration_path": "dstack.yaml", + "file_archives": [], "profile": { "name": "string", }, diff --git a/src/tests/_internal/utils/test_path.py b/src/tests/_internal/utils/test_path.py index 04fb5dfd2d..cb707892bc 100644 --- a/src/tests/_internal/utils/test_path.py +++ b/src/tests/_internal/utils/test_path.py @@ -2,7 +2,19 @@ import pytest -from dstack._internal.utils.path import resolve_relative_path +from dstack._internal.utils.path import normalize_path, resolve_relative_path + + +class TestNormalizePath: + def test_escape_top(self): + with pytest.raises(ValueError): + normalize_path("dir/../..") + + def test_normalize_rel(self): + assert normalize_path("dir/.///..///sibling") == PurePath("sibling") + + def 
test_normalize_abs(self): + assert normalize_path("/dir/.///..///sibling") == PurePath("/sibling") class TestResolveRelativePath: