diff --git a/docs/docs/reference/dstack.yml/dev-environment.md b/docs/docs/reference/dstack.yml/dev-environment.md index 0c872e1d18..9fb52b4893 100644 --- a/docs/docs/reference/dstack.yml/dev-environment.md +++ b/docs/docs/reference/dstack.yml/dev-environment.md @@ -90,4 +90,21 @@ The `dev-environment` configuration type allows running [dev environments](../.. The short syntax for volumes is a colon-separated string in the form of `source:destination` * `volume-name:/container/path` for network volumes - * `/instance/path:/container/path` for instance volumes + * `/instance/path:/container/path` for instance volumes + +### `files[n]` { #_files data-toc-label="files" } + +#SCHEMA# dstack._internal.core.models.files.FilePathMapping + overrides: + show_root_heading: false + type: + required: true + +??? info "Short syntax" + + The short syntax for files is a colon-separated string in the form of `local_path[:path]` where + `path` is optional and can be omitted if it's equal to `local_path`. + + * `~/.bashrc`, same as `~/.bashrc:~/.bashrc` + * `/opt/myorg`, same as `/opt/myorg/` and `/opt/myorg:/opt/myorg` + * `libs/patched_libibverbs.so.1:/lib/x86_64-linux-gnu/libibverbs.so.1` diff --git a/docs/docs/reference/dstack.yml/service.md b/docs/docs/reference/dstack.yml/service.md index d7cb13acd5..041c9d43fe 100644 --- a/docs/docs/reference/dstack.yml/service.md +++ b/docs/docs/reference/dstack.yml/service.md @@ -185,3 +185,20 @@ The `service` configuration type allows running [services](../../concepts/servic * `volume-name:/container/path` for network volumes * `/instance/path:/container/path` for instance volumes + +### `files[n]` { #_files data-toc-label="files" } + +#SCHEMA# dstack._internal.core.models.files.FilePathMapping + overrides: + show_root_heading: false + type: + required: true + +??? 
info "Short syntax" + + The short syntax for files is a colon-separated string in the form of `local_path[:path]` where + `path` is optional and can be omitted if it's equal to `local_path`. + + * `~/.bashrc`, same as `~/.bashrc:~/.bashrc` + * `/opt/myorg`, same as `/opt/myorg/` and `/opt/myorg:/opt/myorg` + * `libs/patched_libibverbs.so.1:/lib/x86_64-linux-gnu/libibverbs.so.1` diff --git a/docs/docs/reference/dstack.yml/task.md b/docs/docs/reference/dstack.yml/task.md index 0565dbf6f9..7a2cf791a2 100644 --- a/docs/docs/reference/dstack.yml/task.md +++ b/docs/docs/reference/dstack.yml/task.md @@ -91,3 +91,20 @@ The `task` configuration type allows running [tasks](../../concepts/tasks.md). * `volume-name:/container/path` for network volumes * `/instance/path:/container/path` for instance volumes + +### `files[n]` { #_files data-toc-label="files" } + +#SCHEMA# dstack._internal.core.models.files.FilePathMapping + overrides: + show_root_heading: false + type: + required: true + +??? info "Short syntax" + + The short syntax for files is a colon-separated string in the form of `local_path[:path]` where + `path` is optional and can be omitted if it's equal to `local_path`. 
+ + * `~/.bashrc`, same as `~/.bashrc:~/.bashrc` + * `/opt/myorg`, same as `/opt/myorg/` and `/opt/myorg:/opt/myorg` + * `libs/patched_libibverbs.so.1:/lib/x86_64-linux-gnu/libibverbs.so.1` diff --git a/runner/go.mod b/runner/go.mod index 2eeee85393..22dad6466e 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -5,13 +5,16 @@ go 1.23 require ( github.com/alexellis/go-execute/v2 v2.2.1 github.com/bluekeyes/go-gitdiff v0.7.2 + github.com/codeclysm/extract/v4 v4.0.0 github.com/creack/pty v1.1.24 github.com/docker/docker v26.0.0+incompatible github.com/docker/go-connections v0.5.0 github.com/docker/go-units v0.5.0 github.com/go-git/go-git/v5 v5.12.0 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f + github.com/gorilla/websocket v1.5.1 github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf + github.com/prometheus/procfs v0.15.1 github.com/shirou/gopsutil/v4 v4.24.11 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.10.0 @@ -78,12 +81,6 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gotest.tools/v3 v3.5.0 // indirect -) - -require ( - github.com/codeclysm/extract/v3 v3.1.1 - github.com/gorilla/websocket v1.5.1 - github.com/prometheus/procfs v0.15.1 gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools/v3 v3.5.0 // indirect ) diff --git a/runner/go.sum b/runner/go.sum index c3cdaf43f8..41e133c465 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -13,8 +13,8 @@ github.com/alexellis/go-execute/v2 v2.2.1 h1:4Ye3jiCKQarstODOEmqDSRCqxMHLkC92Bhs github.com/alexellis/go-execute/v2 v2.2.1/go.mod h1:FMdRnUTiFAmYXcv23txrp3VYZfLo24nMpiIneWgKHTQ= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= 
-github.com/arduino/go-paths-helper v1.2.0 h1:qDW93PR5IZUN/jzO4rCtexiwF8P4OIcOmcSgAYLZfY4= -github.com/arduino/go-paths-helper v1.2.0/go.mod h1:HpxtKph+g238EJHq4geEPv9p+gl3v5YYu35Yb+w31Ck= +github.com/arduino/go-paths-helper v1.12.1 h1:WkxiVUxBjKWlLMiMuYy8DcmVrkxdP7aKxQOAq7r2lVM= +github.com/arduino/go-paths-helper v1.12.1/go.mod h1:jcpW4wr0u69GlXhTYydsdsqAjLaYK5n7oWHfKqOG6LM= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/bluekeyes/go-gitdiff v0.7.2 h1:42jrcVZdjjxXtVsFNYTo/I6T1ZvIiQL+iDDLiH904hw= @@ -26,8 +26,8 @@ github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyY github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= -github.com/codeclysm/extract/v3 v3.1.1 h1:iHZtdEAwSTqPrd+1n4jfhr1qBhUWtHlMTjT90+fJVXg= -github.com/codeclysm/extract/v3 v3.1.1/go.mod h1:ZJi80UG2JtfHqJI+lgJSCACttZi++dHxfWuPaMhlOfQ= +github.com/codeclysm/extract/v4 v4.0.0 h1:H87LFsUNaJTu2e/8p/oiuiUsOK/TaPQ5wxsjPnwPEIY= +github.com/codeclysm/extract/v4 v4.0.0/go.mod h1:SFju1lj6as7FvUgalpSct7torJE0zttbJUWtryPRG6s= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= diff --git a/runner/internal/api/common.go b/runner/internal/api/common.go index fcb5f40e6d..ae15369d73 100644 --- a/runner/internal/api/common.go +++ b/runner/internal/api/common.go @@ -105,18 +105,18 @@ func DecodeJSONBody(w http.ResponseWriter, r *http.Request, dst interface{}, all func JSONResponseHandler(handler 
func(http.ResponseWriter, *http.Request) (interface{}, error)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { status := 200 - msg := "" + errMsg := "" var apiErr *Error body, err := handler(w, r) if err != nil { if errors.As(err, &apiErr) { status = apiErr.Status - msg = apiErr.Error() - log.Warning(r.Context(), "API error", "err", apiErr.Err) + errMsg = apiErr.Error() + log.Warning(r.Context(), "API error", "err", errMsg, "status", status) } else { status = http.StatusInternalServerError - log.Error(r.Context(), "Unexpected API error", "err", err) + log.Error(r.Context(), "Unexpected API error", "err", err, "status", status) } } @@ -125,7 +125,7 @@ func JSONResponseHandler(handler func(http.ResponseWriter, *http.Request) (inter w.WriteHeader(status) _ = json.NewEncoder(w).Encode(body) } else { - http.Error(w, msg, status) + http.Error(w, errMsg, status) } log.Debug(r.Context(), "", "method", r.Method, "endpoint", r.URL.Path, "status", status) diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go index a8e1505520..2163ca9204 100644 --- a/runner/internal/executor/base.go +++ b/runner/internal/executor/base.go @@ -2,6 +2,7 @@ package executor import ( "context" + "io" "github.com/dstackai/dstack/runner/internal/schemas" "github.com/dstackai/dstack/runner/internal/types" @@ -22,6 +23,7 @@ type Executor interface { termination_message string, ) SetRunnerState(state string) + AddFileArchive(id string, src io.Reader) error Lock() RLock() RUnlock() diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 9fad1f8bf8..14f46d035c 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -31,8 +31,9 @@ type RunExecutor struct { tempDir string homeDir string workingDir string + archiveDir string sshPort int - uid uint32 + currentUid uint32 run schemas.Run jobSpec schemas.JobSpec @@ -77,8 +78,9 @@ func NewRunExecutor(tempDir 
string, homeDir string, workingDir string, sshPort i tempDir: tempDir, homeDir: homeDir, workingDir: workingDir, + archiveDir: filepath.Join(tempDir, "file_archives"), sshPort: sshPort, - uid: uid, + currentUid: uid, mu: mu, state: WaitSubmit, @@ -131,6 +133,28 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { ctx = log.WithLogger(ctx, log.NewEntry(logger, int(log.DefaultEntry.Logger.Level))) // todo loglevel log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String()) + if ex.jobSpec.User != nil { + if err := fillUser(ex.jobSpec.User); err != nil { + ex.SetJobStateWithTerminationReason( + ctx, + types.JobStateFailed, + types.TerminationReasonExecutorError, + fmt.Sprintf("Failed to fill in the job user fields (%s)", err), + ) + return gerrors.Wrap(err) + } + } + + if err := ex.setupFiles(ctx); err != nil { + ex.SetJobStateWithTerminationReason( + ctx, + types.JobStateFailed, + types.TerminationReasonExecutorError, + fmt.Sprintf("Failed to set up files (%s)", err), + ) + return gerrors.Wrap(err) + } + if err := ex.setupRepo(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, @@ -140,6 +164,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { ) return gerrors.Wrap(err) } + cleanupCredentials, err := ex.setupCredentials(ctx) if err != nil { ex.SetJobState(ctx, types.JobStateFailed) @@ -300,16 +325,13 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error user := ex.jobSpec.User if user != nil { - if err := fillUser(user); err != nil { - return gerrors.Wrap(err) - } log.Trace( ctx, "Using credentials", "uid", *user.Uid, "gid", *user.Gid, "groups", user.GroupIds, "username", user.GetUsername(), "groupname", user.GetGroupname(), "home", user.HomeDir, ) - log.Trace(ctx, "Current user", "uid", ex.uid) + log.Trace(ctx, "Current user", "uid", ex.currentUid) // 1. Ideally, We should check uid, gid, and supplementary groups mismatches, // but, for the sake of simplicity, we only check uid. 
Unprivileged runner @@ -318,8 +340,8 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error // 2. Strictly speaking, we need CAP_SETUID and CAP_GUID (for Cmd.Start()-> // Cmd.SysProcAttr.Credential) and CAP_CHOWN (for startCommand()->os.Chown()), // but for the sake of simplicity we instead check if we are root or not - if *user.Uid != ex.uid && ex.uid != 0 { - return gerrors.Newf("cannot start job as %d, current uid is %d", *user.Uid, ex.uid) + if *user.Uid != ex.currentUid && ex.currentUid != 0 { + return gerrors.Newf("cannot start job as %d, current uid is %d", *user.Uid, ex.currentUid) } if cmd.SysProcAttr == nil { diff --git a/runner/internal/executor/files.go b/runner/internal/executor/files.go new file mode 100644 index 0000000000..89de206338 --- /dev/null +++ b/runner/internal/executor/files.go @@ -0,0 +1,132 @@ +package executor + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path" + "regexp" + "slices" + "strings" + + "github.com/codeclysm/extract/v4" + "github.com/dstackai/dstack/runner/internal/gerrors" + "github.com/dstackai/dstack/runner/internal/log" +) + +var renameRegex = regexp.MustCompile(`^([^/]*)(/|$)`) + +func (ex *RunExecutor) AddFileArchive(id string, src io.Reader) error { + if err := os.MkdirAll(ex.archiveDir, 0o755); err != nil { + return gerrors.Wrap(err) + } + archivePath := path.Join(ex.archiveDir, id) + archive, err := os.Create(archivePath) + if err != nil { + return gerrors.Wrap(err) + } + defer func() { _ = archive.Close() }() + if _, err = io.Copy(archive, src); err != nil { + return gerrors.Wrap(err) + } + return nil +} + +// setupFiles must be called from Run +func (ex *RunExecutor) setupFiles(ctx context.Context) error { + homeDir := ex.workingDir + uid := -1 + gid := -1 + if ex.jobSpec.User != nil { + if ex.jobSpec.User.HomeDir != "" { + homeDir = ex.jobSpec.User.HomeDir + } + if ex.jobSpec.User.Uid != nil { + uid = int(*ex.jobSpec.User.Uid) + } + if ex.jobSpec.User.Gid != nil { + gid = 
int(*ex.jobSpec.User.Gid) + } + } + + for _, fa := range ex.run.RunSpec.FileArchives { + log.Trace(ctx, "Extracting file archive", "id", fa.Id, "path", fa.Path) + + p := path.Clean(fa.Path) + // `~username[/path/to]` is not supported + if p == "~" { + p = homeDir + } else if rest, found := strings.CutPrefix(p, "~/"); found { + p = path.Join(homeDir, rest) + } else if !path.IsAbs(p) { + p = path.Join(ex.workingDir, p) + } + dir, root := path.Split(p) + if err := mkdirAll(ctx, dir, uid, gid); err != nil { + return gerrors.Wrap(err) + } + + if err := os.RemoveAll(p); err != nil { + log.Warning(ctx, "Failed to remove", "path", p, "err", err) + } + + archivePath := path.Join(ex.archiveDir, fa.Id) + archive, err := os.Open(archivePath) + if err != nil { + return gerrors.Wrap(err) + } + defer func() { + _ = archive.Close() + if err := os.Remove(archivePath); err != nil { + log.Warning(ctx, "Failed to remove archive", "path", archivePath, "err", err) + } + }() + + var paths []string + repl := fmt.Sprintf("%s$2", root) + renameAndRemember := func(s string) string { + s = renameRegex.ReplaceAllString(s, repl) + paths = append(paths, s) + return s + } + if err := extract.Tar(ctx, archive, dir, renameAndRemember); err != nil { + return gerrors.Wrap(err) + } + + if uid != -1 || gid != -1 { + for _, p := range paths { + if err := os.Chown(path.Join(dir, p), uid, gid); err != nil { + log.Warning(ctx, "Failed to chown", "path", p, "err", err) + } + } + } + } + + return nil +} + +func mkdirAll(ctx context.Context, p string, uid int, gid int) error { + var paths []string + for { + p = path.Dir(p) + if p == "/" { + break + } + paths = append(paths, p) + } + for _, p := range slices.Backward(paths) { + if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) { + if err := os.Mkdir(p, 0o755); err != nil { + return err + } + if err := os.Chown(p, uid, gid); err != nil { + log.Warning(ctx, "Failed to chown", "path", p, "err", err) + } + } else if err != nil { + return err + } + } + return 
nil +} diff --git a/runner/internal/executor/repo.go b/runner/internal/executor/repo.go index bb4121d726..5afbb2515a 100644 --- a/runner/internal/executor/repo.go +++ b/runner/internal/executor/repo.go @@ -7,7 +7,7 @@ import ( "os/exec" "path/filepath" - "github.com/codeclysm/extract/v3" + "github.com/codeclysm/extract/v4" "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/repo" diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 36c63527fb..f323de86ed 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -2,8 +2,12 @@ package api import ( "context" + "errors" + "fmt" "io" "math" + "mime" + "mime/multipart" "net/http" "os" "path/filepath" @@ -58,6 +62,60 @@ func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (inte return nil, nil } +// uploadArchivePostHandler may be called 0 or more times, and must be called after submitPostHandler +// and before uploadCodePostHandler +func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { + s.executor.Lock() + defer s.executor.Unlock() + if s.executor.GetRunnerState() != executor.WaitCode { + return nil, &api.Error{Status: http.StatusConflict} + } + + contentType := r.Header.Get("Content-Type") + if contentType == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing content-type header"} + } + mediaType, params, err := mime.ParseMediaType(contentType) + if err != nil { + return nil, gerrors.Wrap(err) + } + if mediaType != "multipart/form-data" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: fmt.Sprintf("multipart/form-data expected, got %s", mediaType)} + } + boundary := params["boundary"] + if boundary == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing boundary"} + } + + formReader := multipart.NewReader(r.Body, boundary) + 
part, err := formReader.NextPart() + if errors.Is(err, io.EOF) { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty form"} + } + if err != nil { + return nil, gerrors.Wrap(err) + } + fieldName := part.FormName() + if fieldName == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing field name"} + } + if fieldName != "archive" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: fmt.Sprintf("unexpected field %s", fieldName)} + } + archiveId := part.FileName() + if archiveId == "" { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing file name"} + } + if err := s.executor.AddFileArchive(archiveId, part); err != nil { + return nil, gerrors.Wrap(err) + } + if _, err := formReader.NextPart(); !errors.Is(err, io.EOF) { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "extra form field(s)"} + } + + return nil, nil +} + func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.executor.Lock() defer s.executor.Unlock() diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 732d5ed50f..a204c1c468 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -61,6 +61,7 @@ func NewServer(tempDir string, homeDir string, workingDir string, address string r.AddHandler("GET", "/api/healthcheck", s.healthcheckGetHandler) r.AddHandler("GET", "/api/metrics", s.metricsGetHandler) r.AddHandler("POST", "/api/submit", s.submitPostHandler) + r.AddHandler("POST", "/api/upload_archive", s.uploadArchivePostHandler) r.AddHandler("POST", "/api/upload_code", s.uploadCodePostHandler) r.AddHandler("POST", "/api/run", s.runPostHandler) r.AddHandler("GET", "/api/pull", s.pullGetHandler) diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index 4391874007..8d4cef2f1c 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -47,6 +47,7 
@@ type RunSpec struct { RunName string `json:"run_name"` RepoId string `json:"repo_id"` RepoData RepoData `json:"repo_data"` + FileArchives []FileArchive `json:"file_archives"` Configuration Configuration `json:"configuration"` ConfigurationPath string `json:"configuration_path"` } @@ -96,6 +97,11 @@ type RepoData struct { RepoConfigEmail string `json:"repo_config_email"` } +type FileArchive struct { + Id string `json:"id"` + Path string `json:"path"` +} + type Configuration struct { Type string `json:"type"` } diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 97f90c8d2e..696cd7f025 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -109,6 +109,10 @@ def get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]: configuration_excludes["stop_criteria"] = True if profile is not None and profile.stop_criteria is None: profile_excludes.add("stop_criteria") + if not configuration.files: + configuration_excludes["files"] = True + if not run_spec.file_archives: + spec_excludes["file_archives"] = True if configuration_excludes: spec_excludes["configuration"] = configuration_excludes diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 92ae999ba0..97be403ca1 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -10,6 +10,7 @@ from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth from dstack._internal.core.models.envs import Env +from dstack._internal.core.models.files import FilePathMapping from dstack._internal.core.models.fleets import FleetConfiguration from dstack._internal.core.models.gateways import GatewayConfiguration from dstack._internal.core.models.profiles import ProfileParams, parse_off_duration @@ 
-252,6 +253,10 @@ class BaseRunConfiguration(CoreModel): description="Use Docker inside the container. Mutually exclusive with `image`, `python`, and `nvcc`. Overrides `privileged`" ), ] = None + files: Annotated[ + list[Union[FilePathMapping, str]], + Field(description="The local to container file path mappings"), + ] = [] # deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init` setup: CommandsList = [] @@ -285,6 +290,12 @@ def convert_volumes(cls, v) -> MountPoint: return parse_mount_point(v) return v + @validator("files", each_item=True) + def convert_files(cls, v) -> FilePathMapping: + if isinstance(v, str): + return FilePathMapping.parse(v) + return v + @validator("user") def validate_user(cls, v) -> Optional[str]: if v is None: diff --git a/src/dstack/_internal/core/models/files.py b/src/dstack/_internal/core/models/files.py new file mode 100644 index 0000000000..f2e4f6826d --- /dev/null +++ b/src/dstack/_internal/core/models/files.py @@ -0,0 +1,67 @@ +import pathlib +import string +from uuid import UUID + +from pydantic import Field, validator +from typing_extensions import Annotated, Self + +from dstack._internal.core.models.common import CoreModel + + +class FileArchive(CoreModel): + id: UUID + hash: str + + +class FilePathMapping(CoreModel): + local_path: Annotated[ + str, + Field( + description=( + "The path on the user's machine. Relative paths are resolved relative to" + " the parent directory of the the configuration file" + ) + ), + ] + path: Annotated[ + str, + Field( + description=( + "The path in the container. 
Relative paths are resolved relative to" + " the repo directory (`/workflow`)" + ) + ), + ] + + @classmethod + def parse(cls, v: str) -> Self: + local_path: str + path: str + parts = v.split(":") + # A special case for Windows paths, e.g., `C:\path\to`, 'c:/path/to' + if ( + len(parts) > 1 + and len(parts[0]) == 1 + and parts[0] in string.ascii_letters + and parts[1][:1] in ["\\", "/"] + ): + parts = [f"{parts[0]}:{parts[1]}", *parts[2:]] + if len(parts) == 1: + local_path = path = parts[0] + elif len(parts) == 2: + local_path, path = parts + else: + raise ValueError(f"invalid file path mapping: {v}") + return cls(local_path=local_path, path=path) + + @validator("path") + def validate_path(cls, v) -> str: + # True for `C:/.*`, False otherwise, including `/abs/unix/path`, `rel\windows\path`, etc. + if pathlib.PureWindowsPath(v).is_absolute(): + raise ValueError(f"path must be a Unix file path: {v}") + return v + + +class FileArchiveMapping(CoreModel): + id: Annotated[UUID, Field(description="The File archive ID")] + path: Annotated[str, Field(description="The path in the container")] diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 860aa5238e..df53247338 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -12,6 +12,7 @@ AnyRunConfiguration, RunConfiguration, ) +from dstack._internal.core.models.files import FileArchiveMapping from dstack._internal.core.models.instances import ( InstanceOfferWithAvailability, InstanceType, @@ -413,6 +414,10 @@ class RunSpec(CoreModel): Optional[str], Field(description="The hash of the repo diff. 
Can be omitted if there is no repo diff."), ] = None + file_archives: Annotated[ + list[FileArchiveMapping], + Field(description="The list of file archive ID to container path mappings"), + ] = [] working_dir: Annotated[ Optional[str], Field( diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 6377e4c261..6a76dadf48 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -23,6 +23,7 @@ from dstack._internal.server.db import get_db, get_session_ctx, migrate from dstack._internal.server.routers import ( backends, + files, fleets, gateways, instances, @@ -197,6 +198,7 @@ def register_routes(app: FastAPI, ui: bool = True): app.include_router(service_proxy.router, prefix="/proxy/services", tags=["service-proxy"]) app.include_router(model_proxy.router, prefix="/proxy/models", tags=["model-proxy"]) app.include_router(prometheus.router) + app.include_router(files.router) @app.exception_handler(ForbiddenError) async def forbidden_error_handler(request: Request, exc: ForbiddenError): diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index f287d11128..ef07becdfd 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -1,5 +1,6 @@ import asyncio import re +import uuid from collections.abc import Iterable from datetime import timedelta, timezone from typing import Dict, List, Optional @@ -14,6 +15,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode, RegistryAuth from dstack._internal.core.models.configurations import DevEnvironmentConfiguration +from dstack._internal.core.models.files import FileArchiveMapping from dstack._internal.core.models.instances import ( InstanceStatus, RemoteConnectionInfo, @@ -42,8 +44,10 @@ 
ProjectModel, RepoModel, RunModel, + UserModel, ) from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus +from dstack._internal.server.services import files as files_services from dstack._internal.server.services import logs as logs_services from dstack._internal.server.services import services from dstack._internal.server.services.instances import get_instance_ssh_private_keys @@ -226,12 +230,20 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): fmt(job_model), job_submission.age, ) + # FIXME: downloading file archives and code here is a waste of time if + # the runner is not ready yet + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=run.run_spec.file_archives, + user=run_model.user, + ) code = await _get_job_code( session=session, project=project, repo=repo_model, code_hash=run.run_spec.repo_code_hash, ) + success = await common_utils.run_async( _submit_job_to_runner, server_ssh_private_keys, @@ -242,6 +254,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job, cluster_info, code, + file_archives, secrets, repo_creds, success_if_not_available=False, @@ -269,6 +282,13 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): logger.debug( "%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age ) + # FIXME: downloading file archives and code here is a waste of time if + # the runner is not ready yet + file_archives = await _get_job_file_archives( + session=session, + archive_mappings=run.run_spec.file_archives, + user=run_model.user, + ) code = await _get_job_code( session=session, project=project, @@ -285,6 +305,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job, cluster_info, code, + file_archives, secrets, repo_creds, server_ssh_private_keys, @@ -588,6 +609,7 @@ def _process_pulling_with_shim( job: Job, cluster_info: ClusterInfo, code: bytes, + file_archives: 
Iterable[tuple[uuid.UUID, bytes]], secrets: Dict[str, str], repo_credentials: Optional[RemoteRepoCreds], server_ssh_private_keys: tuple[str, Optional[str]], @@ -663,6 +685,7 @@ def _process_pulling_with_shim( job=job, cluster_info=cluster_info, code=code, + file_archives=file_archives, secrets=secrets, repo_credentials=repo_credentials, success_if_not_available=True, @@ -853,6 +876,43 @@ async def _get_job_code( return blob +async def _get_job_file_archives( + session: AsyncSession, + archive_mappings: Iterable[FileArchiveMapping], + user: UserModel, +) -> list[tuple[uuid.UUID, bytes]]: + archives: list[tuple[uuid.UUID, bytes]] = [] + for archive_mapping in archive_mappings: + archive_id = archive_mapping.id + archive_blob = await _get_job_file_archive( + session=session, archive_id=archive_id, user=user + ) + archives.append((archive_id, archive_blob)) + return archives + + +async def _get_job_file_archive( + session: AsyncSession, archive_id: uuid.UUID, user: UserModel +) -> bytes: + archive_model = await files_services.get_archive_model(session, id=archive_id, user=user) + if archive_model is None: + return b"" + if archive_model.blob is not None: + return archive_model.blob + storage = get_default_storage() + if storage is None: + return b"" + blob = await common_utils.run_async( + storage.get_archive, + str(archive_model.user_id), + archive_model.blob_hash, + ) + if blob is None: + logger.error("Failed to get file archive %s from storage", archive_id) + return b"" + return blob + + @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) def _submit_job_to_runner( ports: Dict[int, int], @@ -861,6 +921,7 @@ def _submit_job_to_runner( job: Job, cluster_info: ClusterInfo, code: bytes, + file_archives: Iterable[tuple[uuid.UUID, bytes]], secrets: Dict[str, str], repo_credentials: Optional[RemoteRepoCreds], success_if_not_available: bool, @@ -900,6 +961,9 @@ def _submit_job_to_runner( repo_credentials=repo_credentials, instance_env=instance_env, ) + 
logger.debug("%s: uploading file archive(s)", fmt(job_model)) + for archive_id, archive in file_archives: + runner_client.upload_archive(archive_id, archive) logger.debug("%s: uploading code", fmt(job_model)) runner_client.upload_code(code) logger.debug("%s: starting job", fmt(job_model)) diff --git a/src/dstack/_internal/server/migrations/versions/5f1707c525d2_add_filearchivemodel.py b/src/dstack/_internal/server/migrations/versions/5f1707c525d2_add_filearchivemodel.py new file mode 100644 index 0000000000..a73d9db250 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/5f1707c525d2_add_filearchivemodel.py @@ -0,0 +1,39 @@ +"""Add FileArchiveModel + +Revision ID: 5f1707c525d2 +Revises: 35e90e1b0d3e +Create Date: 2025-06-12 12:28:26.678380 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "5f1707c525d2" +down_revision = "35e90e1b0d3e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "file_archives", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("user_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("blob_hash", sa.Text(), nullable=False), + sa.Column("blob", sa.LargeBinary(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + name=op.f("fk_file_archives_user_id_users"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_file_archives")), + sa.UniqueConstraint("user_id", "blob_hash", name="uq_file_archives_user_id_blob_hash"), + ) + + +def downgrade() -> None: + op.drop_table("file_archives") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index c5e4749c99..dd8fd5f125 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -315,6 +315,21 @@ class CodeModel(BaseModel): blob: Mapped[Optional[bytes]] = 
mapped_column(LargeBinary) # None means blob is stored on s3 +class FileArchiveModel(BaseModel): + __tablename__ = "file_archives" + __table_args__ = ( + UniqueConstraint("user_id", "blob_hash", name="uq_file_archives_user_id_blob_hash"), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE")) + user: Mapped["UserModel"] = relationship() + blob_hash: Mapped[str] = mapped_column(Text) + blob: Mapped[Optional[bytes]] = mapped_column(LargeBinary) # None means blob is stored on s3 + + class RunModel(BaseModel): __tablename__ = "runs" diff --git a/src/dstack/_internal/server/routers/files.py b/src/dstack/_internal/server/routers/files.py new file mode 100644 index 0000000000..574ef01776 --- /dev/null +++ b/src/dstack/_internal/server/routers/files.py @@ -0,0 +1,67 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Request, UploadFile +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ResourceNotExistsError, ServerClientError +from dstack._internal.core.models.files import FileArchive +from dstack._internal.server.db import get_session +from dstack._internal.server.models import UserModel +from dstack._internal.server.schemas.files import GetFileArchiveByHashRequest +from dstack._internal.server.security.permissions import Authenticated +from dstack._internal.server.services import files +from dstack._internal.server.settings import SERVER_CODE_UPLOAD_LIMIT +from dstack._internal.server.utils.routers import ( + get_base_api_additional_responses, + get_request_size, +) +from dstack._internal.utils.common import sizeof_fmt + +router = APIRouter( + prefix="/api/files", + tags=["files"], + responses=get_base_api_additional_responses(), +) + + +@router.post("/get_archive_by_hash") +async def get_archive_by_hash( + body: GetFileArchiveByHashRequest, + session: 
Annotated[AsyncSession, Depends(get_session)], + user: Annotated[UserModel, Depends(Authenticated())], +) -> FileArchive: + archive = await files.get_archive_by_hash( + session=session, + user=user, + hash=body.hash, + ) + if archive is None: + raise ResourceNotExistsError() + return archive + + +@router.post("/upload_archive") +async def upload_archive( + request: Request, + file: UploadFile, + session: Annotated[AsyncSession, Depends(get_session)], + user: Annotated[UserModel, Depends(Authenticated())], +) -> FileArchive: + request_size = get_request_size(request) + if SERVER_CODE_UPLOAD_LIMIT > 0 and request_size > SERVER_CODE_UPLOAD_LIMIT: + diff_size_fmt = sizeof_fmt(request_size) + limit_fmt = sizeof_fmt(SERVER_CODE_UPLOAD_LIMIT) + if diff_size_fmt == limit_fmt: + diff_size_fmt = f"{request_size}B" + limit_fmt = f"{SERVER_CODE_UPLOAD_LIMIT}B" + raise ServerClientError( + f"Archive size is {diff_size_fmt}, which exceeds the limit of {limit_fmt}." + " Use .gitignore/.dstackignore to exclude large files." + " This limit can be modified by setting the DSTACK_SERVER_CODE_UPLOAD_LIMIT environment variable." 
+ ) + archive = await files.upload_archive( + session=session, + user=user, + file=file, + ) + return archive diff --git a/src/dstack/_internal/server/schemas/files.py b/src/dstack/_internal/server/schemas/files.py new file mode 100644 index 0000000000..8cab50c9cf --- /dev/null +++ b/src/dstack/_internal/server/schemas/files.py @@ -0,0 +1,5 @@ +from dstack._internal.core.models.common import CoreModel + + +class GetFileArchiveByHashRequest(CoreModel): + hash: str diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 2572aa4b38..3cd5a92809 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -57,6 +57,7 @@ class SubmitBody(CoreModel): "repo_data", "configuration", "configuration_path", + "file_archives", }, } ), diff --git a/src/dstack/_internal/server/services/files.py b/src/dstack/_internal/server/services/files.py new file mode 100644 index 0000000000..d77ad94c78 --- /dev/null +++ b/src/dstack/_internal/server/services/files.py @@ -0,0 +1,91 @@ +import uuid +from typing import Optional + +from fastapi import UploadFile +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.models.files import FileArchive +from dstack._internal.server.models import FileArchiveModel, UserModel +from dstack._internal.server.services.storage import get_default_storage +from dstack._internal.utils.common import run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +async def get_archive_model( + session: AsyncSession, + id: uuid.UUID, + user: Optional[UserModel] = None, +) -> Optional[FileArchiveModel]: + stmt = select(FileArchiveModel).where(FileArchiveModel.id == id) + if user is not None: + stmt = stmt.where(FileArchiveModel.user_id == user.id) + res = await session.execute(stmt) + return res.scalar() + 
+ +async def get_archive_model_by_hash( + session: AsyncSession, + user: UserModel, + hash: str, +) -> Optional[FileArchiveModel]: + res = await session.execute( + select(FileArchiveModel).where( + FileArchiveModel.user_id == user.id, + FileArchiveModel.blob_hash == hash, + ) + ) + return res.scalar() + + +async def get_archive_by_hash( + session: AsyncSession, + user: UserModel, + hash: str, +) -> Optional[FileArchive]: + archive_model = await get_archive_model_by_hash( + session=session, + user=user, + hash=hash, + ) + if archive_model is None: + return None + return archive_model_to_archive(archive_model) + + +async def upload_archive( + session: AsyncSession, + user: UserModel, + file: UploadFile, +) -> FileArchive: + if file.filename is None: + raise ServerClientError("filename not specified") + archive_hash = file.filename + archive_model = await get_archive_model_by_hash( + session=session, + user=user, + hash=archive_hash, + ) + if archive_model is not None: + logger.debug("File archive (user_id=%s, hash=%s) already uploaded", user.id, archive_hash) + return archive_model_to_archive(archive_model) + blob = await file.read() + storage = get_default_storage() + if storage is not None: + await run_async(storage.upload_archive, str(user.id), archive_hash, blob) + archive_model = FileArchiveModel( + user_id=user.id, + blob_hash=archive_hash, + blob=blob if storage is None else None, + ) + session.add(archive_model) + await session.commit() + logger.debug("File archive (user_id=%s, hash=%s) has been uploaded", user.id, archive_hash) + return archive_model_to_archive(archive_model) + + +def archive_model_to_archive(archive_model: FileArchiveModel) -> FileArchive: + return FileArchive(id=archive_model.id, hash=archive_model.blob_hash) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 157090dd0e..979daa44c2 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ 
b/src/dstack/_internal/server/services/jobs/__init__.py @@ -33,6 +33,7 @@ RunSpec, ) from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus +from dstack._internal.server import settings from dstack._internal.server.models import ( InstanceModel, JobModel, @@ -380,8 +381,10 @@ def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel): message=job_model.termination_reason_message, timeout=0, ) - # maybe somehow postpone removing old tasks to allow inspecting failed jobs? - shim_client.remove_task(task_id=job_model.id) + # maybe somehow postpone removing old tasks to allow inspecting failed jobs without + # the following setting? + if not settings.SERVER_KEEP_SHIM_TASKS: + shim_client.remove_task(task_id=job_model.id) else: shim_client.stop(force=True) diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index e205990aa8..be6a11be83 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -109,6 +109,14 @@ def submit_job( ) resp.raise_for_status() + def upload_archive(self, id: uuid.UUID, file: Union[BinaryIO, bytes]): + resp = requests.post( + self._url("/api/upload_archive"), + files={"archive": (str(id), file)}, + timeout=UPLOAD_CODE_REQUEST_TIMEOUT, + ) + resp.raise_for_status() + def upload_code(self, file: Union[BinaryIO, bytes]): resp = requests.post( self._url("/api/upload_code"), data=file, timeout=UPLOAD_CODE_REQUEST_TIMEOUT diff --git a/src/dstack/_internal/server/services/storage/base.py b/src/dstack/_internal/server/services/storage/base.py index 5864f061d5..de11599693 100644 --- a/src/dstack/_internal/server/services/storage/base.py +++ b/src/dstack/_internal/server/services/storage/base.py @@ -22,6 +22,27 @@ def get_code( ) -> Optional[bytes]: pass + @abstractmethod + def upload_archive( + self, + user_id: str, + archive_hash: str, + blob: bytes, + ): + pass + + 
@abstractmethod + def get_archive( + self, + user_id: str, + archive_hash: str, + ) -> Optional[bytes]: + pass + @staticmethod def _get_code_key(project_id: str, repo_id: str, code_hash: str) -> str: return f"data/projects/{project_id}/codes/{repo_id}/{code_hash}" + + @staticmethod + def _get_archive_key(user_id: str, archive_hash: str) -> str: + return f"data/users/{user_id}/file_archives/{archive_hash}" diff --git a/src/dstack/_internal/server/services/storage/gcs.py b/src/dstack/_internal/server/services/storage/gcs.py index 1075aba01f..6c565625e2 100644 --- a/src/dstack/_internal/server/services/storage/gcs.py +++ b/src/dstack/_internal/server/services/storage/gcs.py @@ -25,9 +25,8 @@ def upload_code( code_hash: str, blob: bytes, ): - blob_name = self._get_code_key(project_id, repo_id, code_hash) - blob_obj = self._bucket.blob(blob_name) - blob_obj.upload_from_string(blob) + key = self._get_code_key(project_id, repo_id, code_hash) + self._upload(key, blob) def get_code( self, @@ -35,10 +34,33 @@ def get_code( repo_id: str, code_hash: str, ) -> Optional[bytes]: + key = self._get_code_key(project_id, repo_id, code_hash) + return self._get(key) + + def upload_archive( + self, + user_id: str, + archive_hash: str, + blob: bytes, + ): + key = self._get_archive_key(user_id, archive_hash) + self._upload(key, blob) + + def get_archive( + self, + user_id: str, + archive_hash: str, + ) -> Optional[bytes]: + key = self._get_archive_key(user_id, archive_hash) + return self._get(key) + + def _upload(self, key: str, blob: bytes): + blob_obj = self._bucket.blob(key) + blob_obj.upload_from_string(blob) + + def _get(self, key: str) -> Optional[bytes]: try: - blob_name = self._get_code_key(project_id, repo_id, code_hash) - blob = self._bucket.blob(blob_name) + blob = self._bucket.blob(key) except NotFound: return None - return blob.download_as_bytes() diff --git a/src/dstack/_internal/server/services/storage/s3.py b/src/dstack/_internal/server/services/storage/s3.py index 
8c67f28c2b..a0b993c731 100644 --- a/src/dstack/_internal/server/services/storage/s3.py +++ b/src/dstack/_internal/server/services/storage/s3.py @@ -27,11 +27,8 @@ def upload_code( code_hash: str, blob: bytes, ): - self._client.put_object( - Bucket=self.bucket, - Key=self._get_code_key(project_id, repo_id, code_hash), - Body=blob, - ) + key = self._get_code_key(project_id, repo_id, code_hash) + self._upload(key, blob) def get_code( self, @@ -39,11 +36,32 @@ def get_code( repo_id: str, code_hash: str, ) -> Optional[bytes]: + key = self._get_code_key(project_id, repo_id, code_hash) + return self._get(key) + + def upload_archive( + self, + user_id: str, + archive_hash: str, + blob: bytes, + ): + key = self._get_archive_key(user_id, archive_hash) + self._upload(key, blob) + + def get_archive( + self, + user_id: str, + archive_hash: str, + ) -> Optional[bytes]: + key = self._get_archive_key(user_id, archive_hash) + return self._get(key) + + def _upload(self, key: str, blob: bytes): + self._client.put_object(Bucket=self.bucket, Key=key, Body=blob) + + def _get(self, key: str) -> Optional[bytes]: try: - response = self._client.get_object( - Bucket=self.bucket, - Key=self._get_code_key(project_id, repo_id, code_hash), - ) + response = self._client.get_object(Bucket=self.bucket, Key=key) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] == "NoSuchKey": return None diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 5df67123bb..e71c8ce657 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -70,6 +70,8 @@ os.getenv("DSTACK_SERVER_METRICS_FINISHED_TTL_SECONDS", 7 * 24 * 3600) ) +SERVER_KEEP_SHIM_TASKS = os.getenv("DSTACK_SERVER_KEEP_SHIM_TASKS") is not None + DEFAULT_PROJECT_NAME = "main" SENTRY_DSN = os.getenv("DSTACK_SENTRY_DSN") diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 
31045a036b..7aadb48979 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -77,6 +77,7 @@ from dstack._internal.server.models import ( BackendModel, DecryptedString, + FileArchiveModel, FleetModel, GatewayComputeModel, GatewayModel, @@ -232,6 +233,22 @@ async def create_repo_creds( return repo_creds +async def create_file_archive( + session: AsyncSession, + user_id: UUID, + blob_hash: str = "blob_hash", + blob: bytes = b"blob_content", +) -> FileArchiveModel: + archive = FileArchiveModel( + user_id=user_id, + blob_hash=blob_hash, + blob=blob, + ) + session.add(archive) + await session.commit() + return archive + + def get_run_spec( run_name: str, repo_id: str, diff --git a/src/dstack/_internal/utils/files.py b/src/dstack/_internal/utils/files.py new file mode 100644 index 0000000000..71a0f8ef70 --- /dev/null +++ b/src/dstack/_internal/utils/files.py @@ -0,0 +1,69 @@ +import tarfile +from pathlib import Path +from typing import BinaryIO + +import ignore +import ignore.overrides + +from dstack._internal.utils.hash import get_sha256 +from dstack._internal.utils.path import PathLike, normalize_path + + +def create_file_archive(root: PathLike, fp: BinaryIO) -> str: + """ + Packs the directory or file to a tar archive and writes it to the file-like object. + + Archives can be used to transfer file(s) (e.g., over the network) preserving + file properties such as permissions, timestamps, etc. + + NOTE: `.gitignore` and `.dstackignore` are respected. + + Args: + root: The absolute path to the directory or file. + fp: The binary file-like object. + + Returns: + The SHA-256 hash of the archive as a hex string. + + Raises: + ValueError: If the path is not absolute. 
+ OSError: Underlying errors from the tarfile module + """ + root = Path(root) + if not root.is_absolute(): + raise ValueError(f"path must be absolute: {root}") + walk = ( + ignore.WalkBuilder(root) + .overrides(ignore.overrides.OverrideBuilder(root).add("!/.git/").build()) + .hidden(False) # do not ignore files that start with a dot + .require_git(False) # respect git ignore rules even if not a git repo + .add_custom_ignore_filename(".dstackignore") + .build() + ) + # sort paths to ensure archive reproducibility + paths = sorted(entry.path() for entry in walk) + with tarfile.TarFile(mode="w", fileobj=fp) as t: + for path in paths: + arcname = str(path.relative_to(root.parent)) + info = t.gettarinfo(path, arcname) + if info.issym(): + # Symlinks are handled as follows: each symlink in the chain is checked, and + # * if the target is inside the root: keep relative links as is, replace absolute + # links with relative ones; + # * if the target is outside the root: replace the link with the actual file. 
+ target = Path(info.linkname) + if not target.is_absolute(): + target = path.parent / target + target = normalize_path(target) + try: + target.relative_to(root) + except ValueError: + # Adding as a file + t.add(path.resolve(), arcname, recursive=False) + else: + # Adding as a relative symlink + info.linkname = str(target.relative_to(path.parent, walk_up=True)) + t.addfile(info) + else: + t.add(path, arcname, recursive=False) + return get_sha256(fp) diff --git a/src/dstack/_internal/utils/path.py b/src/dstack/_internal/utils/path.py index 8341407dcf..fc51c488a1 100644 --- a/src/dstack/_internal/utils/path.py +++ b/src/dstack/_internal/utils/path.py @@ -27,16 +27,24 @@ def path_in_dir(path: PathLike, directory: PathLike) -> bool: return False -def resolve_relative_path(path: str) -> PurePath: +def normalize_path(path: PathLike) -> PurePath: path = PurePath(path) - if path.is_absolute(): - raise ValueError("Path should be relative") stack = [] for part in path.parts: if part == "..": if not stack: - raise ValueError("Path is outside of the repo") + raise ValueError("Path is outside of the top directory") stack.pop() else: stack.append(part) return PurePath(*stack) + + +def resolve_relative_path(path: PathLike) -> PurePath: + path = PurePath(path) + if path.is_absolute(): + raise ValueError("Path should be relative") + try: + return normalize_path(path) + except ValueError: + raise ValueError("Path is outside of the repo") diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index da16fc3c6a..36c5a5d11e 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -17,6 +17,7 @@ from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration, PortMapping +from dstack._internal.core.models.files import FileArchiveMapping, FilePathMapping from 
dstack._internal.core.models.profiles import ( CreationPolicy, Profile, @@ -42,6 +43,7 @@ from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.utils.common import get_or_error, make_proxy_url +from dstack._internal.utils.files import create_file_archive from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike, path_in_dir from dstack.api.server import APIClient @@ -476,20 +478,38 @@ def apply_plan( # TODO handle multiple jobs ports_lock = _reserve_ports(run_plan.job_plans[0].job_spec) + run_spec = run_plan.run_spec + configuration = run_spec.configuration + + self._validate_configuration_files(configuration, run_spec.configuration_path) + for file_mapping in configuration.files: + assert isinstance(file_mapping, FilePathMapping) + with tempfile.TemporaryFile("w+b") as fp: + try: + archive_hash = create_file_archive(file_mapping.local_path, fp) + except OSError as e: + raise ClientError(f"failed to archive '{file_mapping.local_path}': {e}") from e + fp.seek(0) + archive = self._api_client.files.upload_archive(hash=archive_hash, fp=fp) + run_spec.file_archives.append( + FileArchiveMapping(id=archive.id, path=file_mapping.path) + ) + if repo is None: repo = VirtualRepo() else: # Do not upload the diff without a repo (a default virtual repo) # since upload_code() requires a repo to be initialized. 
with tempfile.TemporaryFile("w+b") as fp: - run_plan.run_spec.repo_code_hash = repo.write_code_file(fp) + run_spec.repo_code_hash = repo.write_code_file(fp) fp.seek(0) self._api_client.repos.upload_code( project_name=self._project, repo_id=repo.repo_id, - code_hash=run_plan.run_spec.repo_code_hash, + code_hash=run_spec.repo_code_hash, fp=fp, ) + run = self._api_client.runs.apply_plan(self._project, run_plan) return self._model_to_submitted_run(run, ports_lock) @@ -762,6 +782,30 @@ def _model_to_submitted_run(self, run: RunModel, ports_lock: Optional[PortsLock] ports_lock, ) + def _validate_configuration_files( + self, configuration: AnyRunConfiguration, configuration_path: Optional[PathLike] + ) -> None: + """ + Expands, normalizes and validates local paths specified in + the `files` configuration property. + """ + base_dir: Optional[Path] = None + if configuration_path is not None: + base_dir = Path(configuration_path).expanduser().resolve().parent + for file_mapping in configuration.files: + assert isinstance(file_mapping, FilePathMapping) + path = Path(file_mapping.local_path).expanduser() + if not path.is_absolute(): + if base_dir is None: + raise ConfigurationError( + f"Path '{path}' is relative but `configuration_path` is not provided" + ) + else: + path = base_dir / path + if not path.exists(): + raise ConfigurationError(f"Path '{path}' specified in `files` does not exist") + file_mapping.local_path = str(path) + def _reserve_ports( job_spec: JobSpec, diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index f09e8672c2..5cba94f552 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -14,6 +14,7 @@ ) from dstack._internal.utils.logging import get_logger from dstack.api.server._backends import BackendsAPIClient +from dstack.api.server._files import FilesAPIClient from dstack.api.server._fleets import FleetsAPIClient from dstack.api.server._gateways import GatewaysAPIClient from 
dstack.api.server._logs import LogsAPIClient @@ -47,6 +48,7 @@ class APIClient: logs: operations with logs gateways: operations with gateways volumes: operations with volumes + files: operations with files """ def __init__(self, base_url: str, token: str): @@ -111,6 +113,10 @@ def gateways(self) -> GatewaysAPIClient: def volumes(self) -> VolumesAPIClient: return VolumesAPIClient(self._request) + @property + def files(self) -> FilesAPIClient: + return FilesAPIClient(self._request) + def _request( self, path: str, diff --git a/src/dstack/api/server/_files.py b/src/dstack/api/server/_files.py new file mode 100644 index 0000000000..e7bdde91a3 --- /dev/null +++ b/src/dstack/api/server/_files.py @@ -0,0 +1,18 @@ +from typing import BinaryIO + +from pydantic import parse_obj_as + +from dstack._internal.core.models.files import FileArchive +from dstack._internal.server.schemas.files import GetFileArchiveByHashRequest +from dstack.api.server._group import APIClientGroup + + +class FilesAPIClient(APIClientGroup): + def get_archive_by_hash(self, hash: str) -> FileArchive: + body = GetFileArchiveByHashRequest(hash=hash) + resp = self._request("/api/files/get_archive_by_hash", body=body.json()) + return parse_obj_as(FileArchive.__response__, resp.json()) + + def upload_archive(self, hash: str, fp: BinaryIO) -> FileArchive: + resp = self._request("/api/files/upload_archive", files={"file": (hash, fp)}) + return parse_obj_as(FileArchive.__response__, resp.json()) diff --git a/src/tests/_internal/core/models/test_files.py b/src/tests/_internal/core/models/test_files.py new file mode 100644 index 0000000000..d2761d4e92 --- /dev/null +++ b/src/tests/_internal/core/models/test_files.py @@ -0,0 +1,29 @@ +import pytest +from pydantic import ValidationError + +from dstack._internal.core.models.files import FilePathMapping + + +class TestFilePathMapping: + @pytest.mark.parametrize("value", ["./file", "file", "~/file", "/file"]) + def test_parse_only_local_path(self, value: str): + assert 
FilePathMapping.parse(value) == FilePathMapping(local_path=value, path=value) + + def test_parse_both_paths(self): + assert FilePathMapping.parse("./foo:./bar") == FilePathMapping( + local_path="./foo", path="./bar" + ) + + def test_parse_windows_abs_path(self): + assert FilePathMapping.parse("C:\\dir:dir") == FilePathMapping( + local_path="C:\\dir", path="dir" + ) + + def test_error_invalid_mapping_if_more_than_two_parts(self): + with pytest.raises(ValueError, match="invalid file path mapping"): + FilePathMapping.parse("./foo:bar:baz") + + @pytest.mark.parametrize("value", ["C:\\", "d:/path/to"]) + def test_error_must_be_unix_path(self, value: str): + with pytest.raises(ValidationError, match="path must be a Unix file path"): + FilePathMapping.parse(value) diff --git a/src/tests/_internal/server/routers/test_files.py b/src/tests/_internal/server/routers/test_files.py new file mode 100644 index 0000000000..931a24b568 --- /dev/null +++ b/src/tests/_internal/server/routers/test_files.py @@ -0,0 +1,148 @@ +from unittest.mock import Mock + +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.users import GlobalRole +from dstack._internal.server.models import FileArchiveModel +from dstack._internal.server.services.storage import BaseStorage +from dstack._internal.server.testing.common import ( + create_file_archive, + create_user, + get_auth_headers, +) + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.usefixtures("test_db"), + pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True), +] + + +class TestGetArchiveByHash: + async def test_returns_403_if_not_authenticated(self, client: AsyncClient): + response = await client.post( + "/api/files/get_archive_by_hash", + json={"hash": "blob_hash"}, + ) + assert response.status_code == 403 + + async def test_returns_400_if_archive_does_not_exist( + self, session: AsyncSession, client: AsyncClient 
+ ): + user = await create_user(session=session, global_role=GlobalRole.USER) + response = await client.post( + "/api/files/get_archive_by_hash", + headers=get_auth_headers(user.token), + json={"hash": "blob_hash"}, + ) + assert response.status_code == 400, response.json() + + async def test_returns_archive(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + archive = await create_file_archive( + session=session, user_id=user.id, blob_hash="blob_hash", blob=b"blob_content" + ) + response = await client.post( + "/api/files/get_archive_by_hash", + headers=get_auth_headers(user.token), + json={"hash": archive.blob_hash}, + ) + assert response.status_code == 200, response.json() + assert response.json() == { + "id": str(archive.id), + "hash": archive.blob_hash, + } + + +class TestUploadArchive: + file_hash = "blob_hash" + file_content = b"blob_content" + file = (file_hash, file_content) + + @pytest.fixture + def default_storage_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + storage_mock = Mock(spec_set=BaseStorage) + monkeypatch.setattr( + "dstack._internal.server.services.files.get_default_storage", lambda: storage_mock + ) + return storage_mock + + @pytest.fixture + def no_default_storage(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + "dstack._internal.server.services.files.get_default_storage", lambda: None + ) + + async def test_returns_403_if_not_authenticated(self, client: AsyncClient): + response = await client.post( + "/api/files/upload_archive", + files={"file": self.file}, + ) + assert response.status_code == 403 + + async def test_returns_existing_archive( + self, session: AsyncSession, client: AsyncClient, default_storage_mock: Mock + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + existing_archive = await create_file_archive( + session=session, user_id=user.id, blob_hash=self.file_hash, blob=b"existing_blob" + ) + response = 
await client.post( + "/api/files/upload_archive", + headers=get_auth_headers(user.token), + files={"file": self.file}, + ) + assert response.status_code == 200, response.json() + assert response.json() == { + "id": str(existing_archive.id), + "hash": self.file_hash, + } + res = await session.execute( + select(FileArchiveModel).where(FileArchiveModel.user_id == user.id) + ) + archive = res.scalar_one() + assert archive.id == existing_archive.id + assert archive.blob_hash == self.file_hash + assert archive.blob == existing_archive.blob + default_storage_mock.upload_archive.assert_not_called() + + @pytest.mark.usefixtures("no_default_storage") + async def test_uploads_archive_to_db(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + response = await client.post( + "/api/files/upload_archive", + headers=get_auth_headers(user.token), + files={"file": self.file}, + ) + assert response.status_code == 200, response.json() + assert response.json()["hash"] == self.file_hash + res = await session.execute( + select(FileArchiveModel).where(FileArchiveModel.user_id == user.id) + ) + archive = res.scalar_one() + assert archive.blob_hash == self.file_hash + assert archive.blob == self.file_content + + async def test_uploads_archive_to_storage( + self, session: AsyncSession, client: AsyncClient, default_storage_mock: Mock + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + response = await client.post( + "/api/files/upload_archive", + headers=get_auth_headers(user.token), + files={"file": self.file}, + ) + assert response.status_code == 200, response.json() + assert response.json()["hash"] == self.file_hash + res = await session.execute( + select(FileArchiveModel).where(FileArchiveModel.user_id == user.id) + ) + archive = res.scalar_one() + assert archive.blob_hash == self.file_hash + assert archive.blob is None + default_storage_mock.upload_archive.assert_called_once_with( + 
str(user.id), self.file_hash, self.file_content + ) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index c2e50794a6..880c3be5df 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -153,6 +153,7 @@ def get_dev_env_run_plan_dict( "shm_size": None, }, "volumes": [json.loads(v.json()) for v in volumes], + "files": [], "backends": ["local", "aws", "azure", "gcp", "lambda", "runpod"], "regions": ["us"], "availability_zones": None, @@ -174,6 +175,7 @@ def get_dev_env_run_plan_dict( "priority": 0, }, "configuration_path": "dstack.yaml", + "file_archives": [], "profile": { "backends": ["local", "aws", "azure", "gcp", "lambda", "runpod"], "regions": ["us"], @@ -348,6 +350,7 @@ def get_dev_env_run_dict( "shm_size": None, }, "volumes": [], + "files": [], "backends": ["local", "aws", "azure", "gcp", "lambda"], "regions": ["us"], "availability_zones": None, @@ -369,6 +372,7 @@ def get_dev_env_run_dict( "priority": 0, }, "configuration_path": "dstack.yaml", + "file_archives": [], "profile": { "backends": ["local", "aws", "azure", "gcp", "lambda"], "regions": ["us"], @@ -494,6 +498,7 @@ def get_service_run_spec( "model": "test-model", }, "configuration_path": "dstack.yaml", + "file_archives": [], "profile": { "name": "string", }, diff --git a/src/tests/_internal/utils/test_path.py b/src/tests/_internal/utils/test_path.py index 04fb5dfd2d..cb707892bc 100644 --- a/src/tests/_internal/utils/test_path.py +++ b/src/tests/_internal/utils/test_path.py @@ -2,7 +2,19 @@ import pytest -from dstack._internal.utils.path import resolve_relative_path +from dstack._internal.utils.path import normalize_path, resolve_relative_path + + +class TestNormalizePath: + def test_escape_top(self): + with pytest.raises(ValueError): + normalize_path("dir/../..") + + def test_normalize_rel(self): + assert normalize_path("dir/.///..///sibling") == PurePath("sibling") + + def 
test_normalize_abs(self): + assert normalize_path("/dir/.///..///sibling") == PurePath("/sibling") class TestResolveRelativePath: