diff --git a/src/dstack/_internal/cli/commands/apply.py b/src/dstack/_internal/cli/commands/apply.py
index 3ab26c5f4d..dba99f08ff 100644
--- a/src/dstack/_internal/cli/commands/apply.py
+++ b/src/dstack/_internal/cli/commands/apply.py
@@ -1,5 +1,4 @@
import argparse
-from pathlib import Path
from argcomplete import FilesCompleter # type: ignore[attr-defined]
@@ -9,12 +8,7 @@
get_apply_configurator_class,
load_apply_configuration,
)
-from dstack._internal.cli.services.repos import (
- init_default_virtual_repo,
- init_repo,
- register_init_repo_args,
-)
-from dstack._internal.cli.utils.common import console, warn
+from dstack._internal.cli.utils.common import console
from dstack._internal.core.errors import CLIError
from dstack._internal.core.models.configurations import ApplyConfigurationType
@@ -66,37 +60,6 @@ def _register(self):
help="Exit immediately after submitting configuration",
action="store_true",
)
- self._parser.add_argument(
- "--ssh-identity",
- metavar="SSH_PRIVATE_KEY",
- help="The private SSH key path for SSH tunneling",
- type=Path,
- dest="ssh_identity_file",
- )
- repo_group = self._parser.add_argument_group("Repo Options")
- repo_group.add_argument(
- "-P",
- "--repo",
- help=("The repo to use for the run. Can be a local path or a Git repo URL."),
- dest="repo",
- )
- repo_group.add_argument(
- "--repo-branch",
- help="The repo branch to use for the run",
- dest="repo_branch",
- )
- repo_group.add_argument(
- "--repo-hash",
- help="The hash of the repo commit to use for the run",
- dest="repo_hash",
- )
- repo_group.add_argument(
- "--no-repo",
- help="Do not use any repo for the run",
- dest="no_repo",
- action="store_true",
- )
- register_init_repo_args(repo_group)
def _command(self, args: argparse.Namespace):
try:
@@ -117,26 +80,6 @@ def _command(self, args: argparse.Namespace):
super()._command(args)
if not args.yes and args.configuration_file == APPLY_STDIN_NAME:
raise CLIError("Cannot read configuration from stdin if -y/--yes is not specified")
- if args.repo and args.no_repo:
- raise CLIError("Either --repo or --no-repo can be specified")
- if args.local:
- warn(
- "Local repos are deprecated since 0.19.25 and will be removed soon."
- " Consider using `files` instead: https://dstack.ai/docs/concepts/tasks/#files"
- )
- repo = None
- if args.repo:
- repo = init_repo(
- api=self.api,
- repo_path=args.repo,
- repo_branch=args.repo_branch,
- repo_hash=args.repo_hash,
- local=args.local,
- git_identity_file=args.git_identity_file,
- oauth_token=args.gh_token,
- )
- elif args.no_repo:
- repo = init_default_virtual_repo(api=self.api)
configuration_path, configuration = load_apply_configuration(args.configuration_file)
configurator_class = get_apply_configurator_class(configuration.type)
configurator = configurator_class(api_client=self.api)
@@ -148,7 +91,6 @@ def _command(self, args: argparse.Namespace):
command_args=args,
configurator_args=known,
unknown_args=unknown,
- repo=repo,
)
except KeyboardInterrupt:
console.print("\nOperation interrupted by user. Exiting...")
diff --git a/src/dstack/_internal/cli/commands/init.py b/src/dstack/_internal/cli/commands/init.py
index 0076fc6ac1..1aab72d405 100644
--- a/src/dstack/_internal/cli/commands/init.py
+++ b/src/dstack/_internal/cli/commands/init.py
@@ -1,12 +1,17 @@
import argparse
import os
from pathlib import Path
+from typing import Optional
from dstack._internal.cli.commands import BaseCommand
-from dstack._internal.cli.services.repos import init_repo, register_init_repo_args
+from dstack._internal.cli.services.repos import (
+ get_repo_from_dir,
+ get_repo_from_url,
+ is_git_repo_url,
+ register_init_repo_args,
+)
from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
from dstack._internal.core.errors import ConfigurationError
-from dstack._internal.core.models.repos.base import RepoType
from dstack._internal.core.services.configs import ConfigManager
from dstack.api import Client
@@ -21,6 +26,15 @@ def _register(self):
help="The name of the project",
default=os.getenv("DSTACK_PROJECT"),
)
+ self._parser.add_argument(
+ "-P",
+ "--repo",
+ help=(
+ "The repo to initialize. Can be a local path or a Git repo URL."
+ " Defaults to the current working directory."
+ ),
+ dest="repo",
+ )
register_init_repo_args(self._parser)
# Deprecated since 0.19.25, ignored
self._parser.add_argument(
@@ -30,7 +44,7 @@ def _register(self):
type=Path,
dest="ssh_identity_file",
)
- # A hidden mode for transitional period only, remove it with local repos
+ # A hidden mode for transitional period only, remove it with repos in `config.yml`
self._parser.add_argument(
"--remove",
help=argparse.SUPPRESS,
@@ -39,44 +53,62 @@ def _register(self):
def _command(self, args: argparse.Namespace):
configure_logging()
+
+ repo_path: Optional[Path] = None
+ repo_url: Optional[str] = None
+ repo_arg: Optional[str] = args.repo
+ if repo_arg is not None:
+ if is_git_repo_url(repo_arg):
+ repo_url = repo_arg
+ else:
+ repo_path = Path(repo_arg).expanduser().resolve()
+ else:
+ repo_path = Path.cwd()
+
if args.remove:
+ if repo_url is not None:
+ raise ConfigurationError(f"Local path expected, got URL: {repo_url}")
+ assert repo_path is not None
config_manager = ConfigManager()
- repo_path = Path.cwd()
repo_config = config_manager.get_repo_config(repo_path)
if repo_config is None:
- raise ConfigurationError("The repo is not initialized, nothing to remove")
- if repo_config.repo_type != RepoType.LOCAL:
- raise ConfigurationError("`dstack init --remove` is for local repos only")
+ raise ConfigurationError("Repo record not found, nothing to remove")
console.print(
- f"You are about to remove the local repo {repo_path}\n"
+ f"You are about to remove the repo {repo_path}\n"
"Only the record about the repo will be removed,"
" the repo files will remain intact\n"
)
- if not confirm_ask("Remove the local repo?"):
+ if not confirm_ask("Remove the repo?"):
return
config_manager.delete_repo_config(repo_config.repo_id)
config_manager.save()
- console.print("Local repo has been removed")
+ console.print("Repo has been removed")
return
- api = Client.from_config(
- project_name=args.project, ssh_identity_file=args.ssh_identity_file
- )
- if args.local:
+
+ local: bool = args.local
+ if local:
warn(
- "Local repos are deprecated since 0.19.25 and will be removed soon."
- " Consider using `files` instead: https://dstack.ai/docs/concepts/tasks/#files"
+ "Local repos are deprecated since 0.19.25 and will be removed soon. Consider"
+ " using [code]files[/code] instead: https://dstack.ai/docs/concepts/tasks/#files"
)
if args.ssh_identity_file:
warn(
- "`--ssh-identity` in `dstack init` is deprecated and ignored since 0.19.25."
- " Use this option with `dstack apply` and `dstack attach` instead"
+ "[code]--ssh-identity[/code] in [code]dstack init[/code] is deprecated and ignored"
+ " since 0.19.25. Use this option with [code]dstack apply[/code]"
+ " and [code]dstack attach[/code] instead"
)
- init_repo(
- api=api,
- repo_path=Path.cwd(),
- repo_branch=None,
- repo_hash=None,
- local=args.local,
+
+ if repo_url is not None:
+ # Dummy repo branch to avoid autodetection that fails on private repos.
+ # We don't need branch/hash for repo_id anyway.
+ repo = get_repo_from_url(repo_url, repo_branch="master")
+ elif repo_path is not None:
+ repo = get_repo_from_dir(repo_path, local=local)
+ else:
+ assert False, "should not reach here"
+ api = Client.from_config(project_name=args.project)
+ api.repos.init(
+ repo=repo,
git_identity_file=args.git_identity_file,
oauth_token=args.gh_token,
)
diff --git a/src/dstack/_internal/cli/services/configurators/base.py b/src/dstack/_internal/cli/services/configurators/base.py
index 440a31d6c2..35414801ee 100644
--- a/src/dstack/_internal/cli/services/configurators/base.py
+++ b/src/dstack/_internal/cli/services/configurators/base.py
@@ -1,7 +1,7 @@
import argparse
import os
from abc import ABC, abstractmethod
-from typing import Generic, List, Optional, TypeVar, Union, cast
+from typing import Generic, List, TypeVar, Union, cast
from dstack._internal.cli.services.args import env_var
from dstack._internal.core.errors import ConfigurationError
@@ -10,7 +10,6 @@
ApplyConfigurationType,
)
from dstack._internal.core.models.envs import Env, EnvSentinel, EnvVarTuple
-from dstack._internal.core.models.repos.base import Repo
from dstack.api._public import Client
ArgsParser = Union[argparse._ArgumentGroup, argparse.ArgumentParser]
@@ -32,7 +31,6 @@ def apply_configuration(
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
- repo: Optional[Repo] = None,
):
"""
Implements `dstack apply` for a given configuration type.
@@ -43,7 +41,6 @@ def apply_configuration(
command_args: The args parsed by `dstack apply`.
configurator_args: The known args parsed by `cls.get_parser()`.
unknown_args: The unknown args after parsing by `cls.get_parser()`.
- repo: The repo to use with apply.
"""
pass
diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py
index 6718f4f0f2..a250058bc3 100644
--- a/src/dstack/_internal/cli/services/configurators/fleet.py
+++ b/src/dstack/_internal/cli/services/configurators/fleet.py
@@ -35,7 +35,6 @@
InstanceGroupPlacement,
)
from dstack._internal.core.models.instances import InstanceAvailability, InstanceStatus, SSHKey
-from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.services.diff import diff_models
from dstack._internal.utils.common import local_time
from dstack._internal.utils.logging import get_logger
@@ -56,7 +55,6 @@ def apply_configuration(
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
- repo: Optional[Repo] = None,
):
self.apply_args(conf, configurator_args, unknown_args)
profile = load_profile(Path.cwd(), None)
diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py
index 8a22277b17..7d26e220ba 100644
--- a/src/dstack/_internal/cli/services/configurators/gateway.py
+++ b/src/dstack/_internal/cli/services/configurators/gateway.py
@@ -1,6 +1,6 @@
import argparse
import time
-from typing import List, Optional
+from typing import List
from rich.table import Table
@@ -21,7 +21,6 @@
GatewaySpec,
GatewayStatus,
)
-from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.services.diff import diff_models
from dstack._internal.utils.common import local_time
from dstack.api._public import Client
@@ -37,7 +36,6 @@ def apply_configuration(
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
- repo: Optional[Repo] = None,
):
self.apply_args(conf, configurator_args, unknown_args)
spec = GatewaySpec(
diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py
index 02f02f7b42..4a4a4d453b 100644
--- a/src/dstack/_internal/cli/services/configurators/run.py
+++ b/src/dstack/_internal/cli/services/configurators/run.py
@@ -15,12 +15,14 @@
BaseApplyConfigurator,
)
from dstack._internal.cli.services.profile import apply_profile_args, register_profile_args
-from dstack._internal.cli.services.repos import init_default_virtual_repo
-from dstack._internal.cli.utils.common import (
- confirm_ask,
- console,
- warn,
+from dstack._internal.cli.services.repos import (
+ get_repo_from_dir,
+ get_repo_from_url,
+ init_default_virtual_repo,
+ is_git_repo_url,
+ register_init_repo_args,
)
+from dstack._internal.cli.utils.common import confirm_ask, console, warn
from dstack._internal.cli.utils.rich import MultiItemStatus
from dstack._internal.cli.utils.run import get_runs_table, print_run_plan
from dstack._internal.core.errors import (
@@ -46,6 +48,7 @@
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.core.services.diff import diff_models
+from dstack._internal.core.services.repos import load_repo
from dstack._internal.utils.common import local_time
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack._internal.utils.logging import get_logger
@@ -78,42 +81,18 @@ def apply_configuration(
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
- repo: Optional[Repo] = None,
):
+ if configurator_args.repo and configurator_args.no_repo:
+ raise CLIError("Either --repo or --no-repo can be specified")
+
self.apply_args(conf, configurator_args, unknown_args)
self.validate_gpu_vendor_and_image(conf)
self.validate_cpu_arch_and_image(conf)
+
config_manager = ConfigManager()
- if repo is None:
- repo_path = Path.cwd()
- repo_config = config_manager.get_repo_config(repo_path)
- if repo_config is None:
- warn(
- "Repo is not initialized. "
- "Use [code]--repo
[/code] or [code]--no-repo[/code] to initialize it.\n"
- "Starting from 0.19.26, repos will be configured via YAML and this message won't appear."
- )
- if not command_args.yes and not confirm_ask("Continue without the repo?"):
- console.print("\nExiting...")
- return
- repo = init_default_virtual_repo(self.api)
- else:
- # Unlikely, but may raise ConfigurationError if the repo does not exist
- # on the server side (stale entry in `config.yml`)
- repo = self.api.repos.load(repo_path)
- if isinstance(repo, LocalRepo):
- warn(
- f"{repo.repo_dir} is a local repo.\n"
- "Local repos are deprecated since 0.19.25"
- " and will be removed soon\n"
- "There are two options:\n"
- " - Migrate to `files`: https://dstack.ai/docs/concepts/tasks/#files\n"
- " - Specify `--no-repo` if you don't need the repo at all\n"
- "In either case, you can run `dstack init --remove` to remove the repo"
- " (only the record about the repo, not its files) and this warning"
- )
+ repo = self.get_repo(conf, configuration_path, configurator_args, config_manager)
self.api.ssh_identity_file = get_ssh_keypair(
- command_args.ssh_identity_file,
+ configurator_args.ssh_identity_file,
config_manager.dstack_key_path,
)
profile = load_profile(Path.cwd(), configurator_args.profile)
@@ -298,6 +277,13 @@ def delete_configuration(
@classmethod
def register_args(cls, parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--ssh-identity",
+ metavar="SSH_PRIVATE_KEY",
+ help="The private SSH key path for SSH tunneling",
+ type=Path,
+ dest="ssh_identity_file",
+ )
configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options")
configuration_group.add_argument(
"-n",
@@ -336,6 +322,30 @@ def register_args(cls, parser: argparse.ArgumentParser):
dest="disk_spec",
)
register_profile_args(parser)
+ repo_group = parser.add_argument_group("Repo Options")
+ repo_group.add_argument(
+ "-P",
+ "--repo",
+ help=("The repo to use for the run. Can be a local path or a Git repo URL."),
+ dest="repo",
+ )
+ repo_group.add_argument(
+ "--repo-branch",
+ help="The repo branch to use for the run",
+ dest="repo_branch",
+ )
+ repo_group.add_argument(
+ "--repo-hash",
+ help="The hash of the repo commit to use for the run",
+ dest="repo_hash",
+ )
+ repo_group.add_argument(
+ "--no-repo",
+ help="Do not use any repo for the run",
+ dest="no_repo",
+ action="store_true",
+ )
+ register_init_repo_args(repo_group)
def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace, unknown: List[str]):
apply_profile_args(args, conf)
@@ -465,6 +475,118 @@ def validate_cpu_arch_and_image(self, conf: RunConfigurationT) -> None:
if arch == gpuhunt.CPUArchitecture.ARM and conf.image is None:
raise ConfigurationError("`image` is required if `resources.cpu.arch` is `arm`")
+ def get_repo(
+ self,
+ conf: RunConfigurationT,
+ configuration_path: str,
+ configurator_args: argparse.Namespace,
+ config_manager: ConfigManager,
+ ) -> Repo:
+ """
+ Resolves the repo to use for the run and (re)initializes it if needed.
+
+ The repo source is picked in the following order:
+ 1. `--no-repo`: a default virtual repo is initialized and returned right away.
+ 2. `--repo`: a Git repo URL or a local path (relative to the current working dir).
+ 3. The first item of `repos` in the configuration: a URL or a local path
+ (relative to the configuration file's parent dir).
+ 4. Legacy fallback: the repo record for the current working dir in `config.yml`.
+
+ Args:
+ conf: The run configuration; only `conf.repos` is read here.
+ configuration_path: The path to the configuration file, used to resolve
+ relative `local_path` values from `conf.repos`.
+ configurator_args: The parsed args; `no_repo`, `repo`, `repo_branch`, `repo_hash`,
+ `git_identity_file`, `gh_token`, and `local` are read.
+ config_manager: Used to look up the legacy repo record for the current working dir.
+
+ Returns:
+ The resolved repo, or a default virtual repo if `--no-repo` is specified
+ or the legacy fallback finds no usable repo.
+
+ Raises:
+ ConfigurationError: If the explicitly specified local repo path does not exist.
+ """
+ if configurator_args.no_repo:
+ return init_default_virtual_repo(api=self.api)
+
+ repo: Optional[Repo] = None
+ repo_branch: Optional[str] = configurator_args.repo_branch
+ repo_hash: Optional[str] = configurator_args.repo_hash
+ # Should we (re)initialize the repo?
+ # If any Git credentials provided, we reinitialize the repo, as the user may have provided
+ # updated credentials.
+ init = (
+ configurator_args.git_identity_file is not None
+ or configurator_args.gh_token is not None
+ )
+
+ url: Optional[str] = None
+ local_path: Optional[Path] = None
+ # dummy value, safe to join with any path
+ root_dir = Path(".")
+ # True if no repo specified, but we found one in `config.yml`
+ legacy_local_path = False
+ if repo_arg := configurator_args.repo:
+ if is_git_repo_url(repo_arg):
+ url = repo_arg
+ else:
+ local_path = Path(repo_arg)
+ # rel paths in `--repo` are resolved relative to the current working dir
+ root_dir = Path.cwd()
+ elif conf.repos:
+ repo_spec = conf.repos[0]
+ if repo_spec.url:
+ url = repo_spec.url
+ elif repo_spec.local_path:
+ local_path = Path(repo_spec.local_path)
+ # rel paths in the conf are resolved relative to the conf's parent dir
+ root_dir = Path(configuration_path).resolve().parent
+ else:
+ assert False, f"should not reach here: {repo_spec}"
+ # `--repo-branch`/`--repo-hash` take precedence over the values from the conf
+ if repo_branch is None:
+ repo_branch = repo_spec.branch
+ if repo_hash is None:
+ repo_hash = repo_spec.hash
+ else:
+ local_path = Path.cwd()
+ legacy_local_path = True
+ # Build the repo object and decide whether it has to be initialized
+ if url:
+ repo = get_repo_from_url(repo_url=url, repo_branch=repo_branch, repo_hash=repo_hash)
+ if not self.api.repos.is_initialized(repo, by_user=True):
+ init = True
+ elif local_path:
+ if legacy_local_path:
+ if repo_config := config_manager.get_repo_config(local_path):
+ repo = load_repo(repo_config)
+ # allow users with legacy configurations use shared repo creds
+ if self.api.repos.is_initialized(repo, by_user=False):
+ warn(
+ "The repo is not specified but found and will be used in the run\n"
+ "Future versions will not load repos automatically\n"
+ "To prepare for future versions and get rid of this warning:\n"
+ "- If you need the repo in the run, either specify [code]repos[/code]"
+ " in the configuration or use [code]--repo .[/code]\n"
+ "- If you don't need the repo in the run, either run"
+ " [code]dstack init --remove[/code] once (it removes only the record"
+ " about the repo, the repo files will remain intact)"
+ " or use [code]--no-repo[/code]"
+ )
+ else:
+ # ignore stale entries in `config.yml`
+ repo = None
+ init = False
+ else:
+ original_local_path = local_path
+ local_path = local_path.expanduser()
+ if not local_path.is_absolute():
+ local_path = (root_dir / local_path).resolve()
+ if not local_path.exists():
+ raise ConfigurationError(
+ f"Invalid repo path: {original_local_path} -> {local_path}"
+ )
+ local: bool = configurator_args.local
+ repo = get_repo_from_dir(local_path, local=local)
+ if not self.api.repos.is_initialized(repo, by_user=True):
+ init = True
+ else:
+ assert False, "should not reach here"
+
+ # `repo` is still unset only if the legacy lookup found no usable record
+ if repo is None:
+ return init_default_virtual_repo(api=self.api)
+
+ if init:
+ self.api.repos.init(
+ repo=repo,
+ git_identity_file=configurator_args.git_identity_file,
+ oauth_token=configurator_args.gh_token,
+ )
+ if isinstance(repo, LocalRepo):
+ warn(
+ f"{repo.repo_dir} is a local repo\n"
+ "Local repos are deprecated since 0.19.25 and will be removed soon\n"
+ "There are two options:\n"
+ "- Migrate to [code]files[/code]: https://dstack.ai/docs/concepts/tasks/#files\n"
+ "- Specify [code]--no-repo[/code] if you don't need the repo at all"
+ )
+
+ return repo
+
class RunWithPortsConfiguratorMixin:
@classmethod
diff --git a/src/dstack/_internal/cli/services/configurators/volume.py b/src/dstack/_internal/cli/services/configurators/volume.py
index 72b21e5bb4..b0e25e503c 100644
--- a/src/dstack/_internal/cli/services/configurators/volume.py
+++ b/src/dstack/_internal/cli/services/configurators/volume.py
@@ -1,6 +1,6 @@
import argparse
import time
-from typing import List, Optional
+from typing import List
from rich.table import Table
@@ -14,7 +14,6 @@
from dstack._internal.cli.utils.volume import get_volumes_table
from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.configurations import ApplyConfigurationType
-from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.models.volumes import (
Volume,
VolumeConfiguration,
@@ -36,7 +35,6 @@ def apply_configuration(
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
- repo: Optional[Repo] = None,
):
self.apply_args(conf, configurator_args, unknown_args)
spec = VolumeSpec(
diff --git a/src/dstack/_internal/cli/services/repos.py b/src/dstack/_internal/cli/services/repos.py
index 5e7a10589f..0abc9d9ef4 100644
--- a/src/dstack/_internal/cli/services/repos.py
+++ b/src/dstack/_internal/cli/services/repos.py
@@ -1,10 +1,11 @@
import argparse
-from pathlib import Path
-from typing import Optional
+from typing import Literal, Optional, Union, overload
+
+import git
from dstack._internal.cli.services.configurators.base import ArgsParser
from dstack._internal.core.errors import CLIError
-from dstack._internal.core.models.repos.base import Repo
+from dstack._internal.core.models.repos.local import LocalRepo
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
from dstack._internal.core.models.repos.virtual import VirtualRepo
from dstack._internal.core.services.repos import get_default_branch
@@ -36,51 +37,54 @@ def register_init_repo_args(parser: ArgsParser):
)
-def init_repo(
- api: Client,
- repo_path: PathLike,
- repo_branch: Optional[str],
- repo_hash: Optional[str],
- local: bool,
- git_identity_file: Optional[PathLike],
- oauth_token: Optional[str],
-) -> Repo:
- if Path(repo_path).exists():
- repo = api.repos.load(
- repo_dir=repo_path,
- local=local,
- init=True,
- git_identity_file=git_identity_file,
- oauth_token=oauth_token,
- )
- elif isinstance(repo_path, str):
- try:
- GitRepoURL.parse(repo_path)
- except RepoError as e:
- raise CLIError("Invalid repo path") from e
- if repo_branch is None and repo_hash is None:
- repo_branch = get_default_branch(repo_path)
- if repo_branch is None:
- raise CLIError(
- "Failed to automatically detect remote repo branch."
- " Specify --repo-branch or --repo-hash."
- )
- repo = RemoteRepo.from_url(
- repo_url=repo_path,
- repo_branch=repo_branch,
- repo_hash=repo_hash,
- )
- api.repos.init(
- repo=repo,
- git_identity_file=git_identity_file,
- oauth_token=oauth_token,
- )
- else:
- raise CLIError("Invalid repo path")
- return repo
-
-
def init_default_virtual_repo(api: Client) -> VirtualRepo:
repo = VirtualRepo()
api.repos.init(repo)
return repo
+
+
+def get_repo_from_url(
+    repo_url: str, repo_branch: Optional[str] = None, repo_hash: Optional[str] = None
+) -> RemoteRepo:
+    """
+    Builds a `RemoteRepo` from a Git repo URL.
+
+    If neither `repo_branch` nor `repo_hash` is given, the default branch
+    of the remote repo is detected and used.
+
+    Raises:
+        CLIError: If the default branch has to be detected but cannot be.
+    """
+    if repo_branch is not None or repo_hash is not None:
+        # The revision is pinned by the caller, no autodetection needed
+        return RemoteRepo.from_url(
+            repo_url=repo_url,
+            repo_branch=repo_branch,
+            repo_hash=repo_hash,
+        )
+    default_branch = get_default_branch(repo_url)
+    if default_branch is None:
+        raise CLIError(
+            "Failed to automatically detect remote repo branch. Specify branch or hash."
+        )
+    return RemoteRepo.from_url(
+        repo_url=repo_url,
+        repo_branch=default_branch,
+        repo_hash=repo_hash,
+    )
+
+
+@overload
+def get_repo_from_dir(repo_dir: PathLike, local: Literal[False] = False) -> RemoteRepo: ...
+
+
+@overload
+def get_repo_from_dir(repo_dir: PathLike, local: Literal[True]) -> LocalRepo: ...
+
+
+def get_repo_from_dir(repo_dir: PathLike, local: bool = False) -> Union[RemoteRepo, LocalRepo]:
+    """
+    Builds a repo object from a directory on the user's machine.
+
+    Args:
+        repo_dir: The path to the directory.
+        local: If `True`, build a `LocalRepo` from the directory;
+            otherwise, build a `RemoteRepo`.
+
+    Returns:
+        A `LocalRepo` if `local` is `True`, a `RemoteRepo` otherwise.
+
+    Raises:
+        CLIError: If `local` is `False` and `repo_dir` is not a Git repo
+            or `RemoteRepo.from_dir()` fails with `RepoError`.
+    """
+    if local:
+        return LocalRepo.from_dir(repo_dir)
+    try:
+        return RemoteRepo.from_dir(repo_dir)
+    except git.InvalidGitRepositoryError as e:
+        # Chain explicitly, the same way as the `RepoError` handler below does
+        raise CLIError(
+            f"Git repo not found: {repo_dir}\n"
+            "Use `files` to mount an arbitrary directory:"
+            " https://dstack.ai/docs/concepts/tasks/#files"
+        ) from e
+    except RepoError as e:
+        raise CLIError(str(e)) from e
+
+
+def is_git_repo_url(value: str) -> bool:
+    """Tells whether `value` is accepted by `GitRepoURL.parse()`."""
+    try:
+        GitRepoURL.parse(value)
+    except RepoError:
+        parsed = False
+    else:
+        parsed = True
+    return parsed
diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py
index f58be55e19..cfb809acb5 100644
--- a/src/dstack/_internal/core/compatibility/runs.py
+++ b/src/dstack/_internal/core/compatibility/runs.py
@@ -136,6 +136,7 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType:
configuration_excludes["schedule"] = True
if profile is not None and profile.schedule is None:
profile_excludes.add("schedule")
+ configuration_excludes["repos"] = True
if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index 4b92d6f82b..a72792da24 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -1,12 +1,13 @@
import re
+import string
from collections import Counter
from enum import Enum
from pathlib import PurePosixPath
-from typing import Any, Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union
import orjson
from pydantic import Field, ValidationError, conint, constr, root_validator, validator
-from typing_extensions import Annotated, Literal
+from typing_extensions import Self
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth
@@ -83,6 +84,72 @@ def parse(cls, v: str) -> "PortMapping":
return PortMapping(local_port=local_port, container_port=int(container_port))
+class RepoSpec(CoreModel):
+    # A single item of the `repos` configuration property: a Git repo specified either
+    # by a local path or by a URL (exactly one of the two, see `validate_local_path_or_url`).
+    local_path: Annotated[
+        Optional[str],
+        Field(
+            description=(
+                "The path to the Git repo on the user's machine. Relative paths are resolved"
+                " relative to the parent directory of the configuration file."
+                " Mutually exclusive with `url`"
+            )
+        ),
+    ] = None
+    url: Annotated[
+        Optional[str],
+        Field(description="The Git repo URL. Mutually exclusive with `local_path`"),
+    ] = None
+    branch: Annotated[
+        Optional[str],
+        Field(
+            description=(
+                "The repo branch. Defaults to the active branch for local paths"
+                " and the default branch for URLs"
+            )
+        ),
+    ] = None
+    hash: Annotated[
+        Optional[str],
+        Field(description="The commit hash"),
+    ] = None
+    # Not implemented, has no effect, hidden in the docs
+    path: str = DEFAULT_REPO_DIR
+
+    @classmethod
+    def parse(cls, v: str) -> Self:
+        """
+        Parses the short string form of a repo.
+
+        Accepted forms are `<local_path>`, `<url>`, `<local_path>:<path>`, and `<url>:<path>`.
+        The colon after a URL scheme, after an SCP-like `user@host` prefix, or after
+        a Windows drive letter is not treated as the `:<path>` separator.
+
+        Raises:
+            ValueError: If more than one separator colon is left.
+        """
+        # NOTE(review): a URL with an explicit port, e.g., `https://host:8080/repo`, is split
+        # at the port colon (url=`https://host`, path=`8080/repo`) -- confirm if such URLs
+        # should be supported in the short form.
+        is_url = False
+        parts = v.split(":")
+        if len(parts) > 1:
+            # Git repo, git@github.com:dstackai/dstack.git or https://github.com/dstackai/dstack
+            if "@" in parts[0] or parts[1].startswith("//"):
+                parts = [f"{parts[0]}:{parts[1]}", *parts[2:]]
+                is_url = True
+            # Windows path, e.g., `C:\path\to`, 'c:/path/to'
+            elif (
+                len(parts[0]) == 1
+                and parts[0] in string.ascii_letters
+                and parts[1][:1] in ["\\", "/"]
+            ):
+                parts = [f"{parts[0]}:{parts[1]}", *parts[2:]]
+        if len(parts) == 1:
+            if is_url:
+                return cls(url=parts[0])
+            return cls(local_path=parts[0])
+        if len(parts) == 2:
+            if is_url:
+                return cls(url=parts[0], path=parts[1])
+            return cls(local_path=parts[0], path=parts[1])
+        raise ValueError(f"Invalid repo: {v}")
+
+    @root_validator
+    def validate_local_path_or_url(cls, values):
+        # `values` contains only the fields that passed validation, so a field that failed
+        # its own validation is missing here. Use `.get()` so that the original validation
+        # error is reported instead of an unhandled `KeyError`.
+        local_path = values.get("local_path")
+        url = values.get("url")
+        if local_path and url:
+            raise ValueError("`local_path` and `url` are mutually exclusive")
+        if not local_path and not url:
+            raise ValueError("Either `local_path` or `url` must be specified")
+        return values
+
+
class ScalingSpec(CoreModel):
metric: Annotated[
Literal["rps"],
@@ -392,6 +459,10 @@ class BaseRunConfiguration(CoreModel):
description="Use Docker inside the container. Mutually exclusive with `image`, `python`, and `nvcc`. Overrides `privileged`"
),
] = None
+ repos: Annotated[
+ list[RepoSpec],
+ Field(description="The list of Git repos"),
+ ] = []
files: Annotated[
list[FilePathMapping],
Field(description="The local to container file path mappings"),
@@ -447,6 +518,18 @@ def convert_files(cls, v: Union[FilePathMapping, str]) -> FilePathMapping:
return FilePathMapping.parse(v)
return v
+    @validator("repos", pre=True, each_item=True)
+    def convert_repos(cls, v: Union[RepoSpec, str]) -> RepoSpec:
+        """Converts the short string form of a repo item to `RepoSpec`."""
+        return RepoSpec.parse(v) if isinstance(v, str) else v
+
+    @validator("repos")
+    def validate_repos(cls, v: list[RepoSpec]) -> list[RepoSpec]:
+        """Rejects configurations with more than one repo; returns the list unchanged."""
+        if len(v) > 1:
+            raise ValueError("A maximum of one repo is currently supported")
+        return v
+
@validator("user")
def validate_user(cls, v) -> Optional[str]:
if v is None:
diff --git a/src/dstack/_internal/core/models/files.py b/src/dstack/_internal/core/models/files.py
index f2e4f6826d..797c64da23 100644
--- a/src/dstack/_internal/core/models/files.py
+++ b/src/dstack/_internal/core/models/files.py
@@ -28,7 +28,7 @@ class FilePathMapping(CoreModel):
Field(
description=(
"The path in the container. Relative paths are resolved relative to"
- " the repo directory (`/workflow`)"
+ " the repo directory"
)
),
]
diff --git a/src/dstack/_internal/core/services/repos.py b/src/dstack/_internal/core/services/repos.py
index 61ff9b3abb..a054519fd1 100644
--- a/src/dstack/_internal/core/services/repos.py
+++ b/src/dstack/_internal/core/services/repos.py
@@ -2,7 +2,7 @@
from pathlib import Path
from typing import Optional, Union
-import git
+import git.cmd
import requests
import yaml
from git.exc import GitCommandError
@@ -24,6 +24,8 @@
gh_config_path = os.path.expanduser("~/.config/gh/hosts.yml")
default_ssh_key = os.path.expanduser("~/.ssh/id_rsa")
+no_prompt_env = dict(GIT_TERMINAL_PROMPT="0")
+
class InvalidRepoCredentialsError(DstackError):
pass
@@ -84,7 +86,7 @@ def get_local_repo_credentials(
def check_remote_repo_credentials_https(url: GitRepoURL, oauth_token: str) -> RemoteRepoCreds:
try:
- git.cmd.Git().ls_remote(url.as_https(oauth_token), env=dict(GIT_TERMINAL_PROMPT="0")) # type: ignore[attr-defined]
+ git.cmd.Git().ls_remote(url.as_https(oauth_token), env=no_prompt_env)
except GitCommandError:
masked = len(oauth_token[:-4]) * "*" + oauth_token[-4:]
raise InvalidRepoCredentialsError(
@@ -111,7 +113,7 @@ def check_remote_repo_credentials_ssh(url: GitRepoURL, identity_file: PathLike)
private_key = f.read()
try:
- git.cmd.Git().ls_remote( # type: ignore[attr-defined]
+ git.cmd.Git().ls_remote(
url.as_ssh(), env=dict(GIT_SSH_COMMAND=make_ssh_command_for_git(identity_file))
)
except GitCommandError:
@@ -131,7 +133,7 @@ def get_default_branch(remote_url: str) -> Optional[str]:
Get the default branch of a remote Git repository.
"""
try:
- output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD") # type: ignore[attr-defined]
+ output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD", env=no_prompt_env)
for line in output.splitlines():
if line.startswith("ref:"):
return line.split()[1].split("/")[-1]
diff --git a/src/dstack/api/_public/repos.py b/src/dstack/api/_public/repos.py
index 9655e48f91..c05f80de75 100644
--- a/src/dstack/api/_public/repos.py
+++ b/src/dstack/api/_public/repos.py
@@ -68,6 +68,7 @@ def init(
"""
creds = None
if isinstance(repo, RemoteRepo):
+ assert repo.repo_url is not None
try:
creds = get_local_repo_credentials(
repo_url=repo.repo_url,
@@ -140,22 +141,40 @@ def load(
def is_initialized(
self,
repo: Repo,
+ by_user: bool = False,
) -> bool:
"""
- Checks if the remote repo is initialized in the project
+ Checks if the repo is initialized in the project
Args:
repo: The repo to check.
+ by_user: Require the remote repo to be initialized by the user, that is, to have
+ the user's credentials. Ignored for other repo types.
Returns:
Whether the repo is initialized or not.
"""
+ if isinstance(repo, RemoteRepo) and by_user:
+ return self._is_initialized_by_user(repo)
try:
- self._api_client.repos.get(self._project, repo.repo_id, include_creds=False)
+ self._api_client.repos.get(self._project, repo.repo_id)
return True
except ResourceNotExistsError:
return False
+    def _is_initialized_by_user(self, repo: RemoteRepo) -> bool:
+        """
+        Checks if the remote repo is initialized in the project with the user's credentials
+
+        Args:
+            repo: The remote repo to check.
+
+        Returns:
+            `False` if the repo does not exist in the project or no credentials
+            are returned for it, `True` otherwise.
+        """
+        try:
+            head_with_creds = self._api_client.repos.get_with_creds(self._project, repo.repo_id)
+        except ResourceNotExistsError:
+            return False
+        # This works because:
+        # - RepoCollection.init() always submits RemoteRepoCreds for remote repos, even if
+        #   the repo is public
+        # - Server returns creds only if there is RepoCredsModel for the user (or legacy
+        #   shared creds in RepoModel)
+        # TODO: add an API method with the same logic returning a bool value?
+        has_creds = head_with_creds.repo_creds is not None
+        return has_creds
+
def get_ssh_keypair(key_path: Optional[PathLike], dstack_key_path: Path) -> str:
"""Returns a path to the private key"""
diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py
index 473c462139..8a87fd879f 100644
--- a/src/dstack/api/_public/runs.py
+++ b/src/dstack/api/_public/runs.py
@@ -23,7 +23,7 @@
PortMapping,
ServiceConfiguration,
)
-from dstack._internal.core.models.files import FileArchiveMapping, FilePathMapping
+from dstack._internal.core.models.files import FileArchiveMapping
from dstack._internal.core.models.profiles import (
CreationPolicy,
Profile,
@@ -499,7 +499,6 @@ def apply_plan(
self._validate_configuration_files(configuration, run_spec.configuration_path)
for file_mapping in configuration.files:
- assert isinstance(file_mapping, FilePathMapping)
with tempfile.TemporaryFile("w+b") as fp:
try:
archive_hash = create_file_archive(file_mapping.local_path, fp)
diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py
index 1cdbd5e7cb..ce0328c2ef 100644
--- a/src/dstack/api/server/__init__.py
+++ b/src/dstack/api/server/__init__.py
@@ -27,9 +27,6 @@
from dstack.api.server._users import UsersAPIClient
from dstack.api.server._volumes import VolumesAPIClient
-logger = get_logger(__name__)
-
-
_MAX_RETRIES = 3
_RETRY_INTERVAL = 1
@@ -66,6 +63,7 @@ def __init__(self, base_url: str, token: str):
client_api_version = os.getenv("DSTACK_CLIENT_API_VERSION", version.__version__)
if client_api_version is not None:
self._s.headers.update({"X-API-VERSION": client_api_version})
+ self._logger = get_logger(__name__)
@property
def base_url(self) -> str:
@@ -73,55 +71,55 @@ def base_url(self) -> str:
@property
def users(self) -> UsersAPIClient:
- return UsersAPIClient(self._request)
+ return UsersAPIClient(self._request, self._logger)
@property
def projects(self) -> ProjectsAPIClient:
- return ProjectsAPIClient(self._request)
+ return ProjectsAPIClient(self._request, self._logger)
@property
def backends(self) -> BackendsAPIClient:
- return BackendsAPIClient(self._request)
+ return BackendsAPIClient(self._request, self._logger)
@property
def fleets(self) -> FleetsAPIClient:
- return FleetsAPIClient(self._request)
+ return FleetsAPIClient(self._request, self._logger)
@property
def repos(self) -> ReposAPIClient:
- return ReposAPIClient(self._request)
+ return ReposAPIClient(self._request, self._logger)
@property
def runs(self) -> RunsAPIClient:
- return RunsAPIClient(self._request)
+ return RunsAPIClient(self._request, self._logger)
@property
def gpus(self) -> GpusAPIClient:
- return GpusAPIClient(self._request)
+ return GpusAPIClient(self._request, self._logger)
@property
def metrics(self) -> MetricsAPIClient:
- return MetricsAPIClient(self._request)
+ return MetricsAPIClient(self._request, self._logger)
@property
def logs(self) -> LogsAPIClient:
- return LogsAPIClient(self._request)
+ return LogsAPIClient(self._request, self._logger)
@property
def secrets(self) -> SecretsAPIClient:
- return SecretsAPIClient(self._request)
+ return SecretsAPIClient(self._request, self._logger)
@property
def gateways(self) -> GatewaysAPIClient:
- return GatewaysAPIClient(self._request)
+ return GatewaysAPIClient(self._request, self._logger)
@property
def volumes(self) -> VolumesAPIClient:
- return VolumesAPIClient(self._request)
+ return VolumesAPIClient(self._request, self._logger)
@property
def files(self) -> FilesAPIClient:
- return FilesAPIClient(self._request)
+ return FilesAPIClient(self._request, self._logger)
def _request(
self,
@@ -136,20 +134,20 @@ def _request(
kwargs.setdefault("headers", {})["Content-Type"] = "application/json"
kwargs["data"] = body
- logger.debug("POST /%s", path)
+ self._logger.debug("POST /%s", path)
for _ in range(_MAX_RETRIES):
try:
# TODO: set adequate timeout here or everywhere the method is used
resp = self._s.request(method, f"{self._base_url}/{path}", **kwargs)
break
except requests.exceptions.ConnectionError as e:
- logger.debug("Could not connect to server: %s", e)
+ self._logger.debug("Could not connect to server: %s", e)
time.sleep(_RETRY_INTERVAL)
else:
raise ClientError(f"Failed to connect to dstack server {self._base_url}")
if 400 <= resp.status_code < 600:
- logger.debug(
+ self._logger.debug(
"Error requesting %s. Status: %s. Headers: %s. Body: %s",
resp.request.url,
resp.status_code,
diff --git a/src/dstack/api/server/_group.py b/src/dstack/api/server/_group.py
index b893647f71..9d3ec1918a 100644
--- a/src/dstack/api/server/_group.py
+++ b/src/dstack/api/server/_group.py
@@ -1,3 +1,4 @@
+from logging import Logger
from typing import Optional
import requests
@@ -12,10 +13,10 @@ def __call__(
raise_for_status: bool = True,
method: str = "POST",
**kwargs,
- ) -> requests.Response:
- pass
+ ) -> requests.Response: ...
class APIClientGroup:
- def __init__(self, _request: APIRequest):
+ def __init__(self, _request: APIRequest, _logger: Logger):
self._request = _request
+ self._logger = _logger
diff --git a/src/dstack/api/server/_repos.py b/src/dstack/api/server/_repos.py
index a827d7e64a..03f9eb9eab 100644
--- a/src/dstack/api/server/_repos.py
+++ b/src/dstack/api/server/_repos.py
@@ -2,7 +2,12 @@
from pydantic import parse_obj_as
-from dstack._internal.core.models.repos import AnyRepoInfo, RemoteRepoCreds, RepoHead
+from dstack._internal.core.models.repos import (
+ AnyRepoInfo,
+ RemoteRepoCreds,
+ RepoHead,
+ RepoHeadWithCreds,
+)
from dstack._internal.server.schemas.repos import (
DeleteReposRequest,
GetRepoRequest,
@@ -16,11 +21,23 @@ def list(self, project_name: str) -> List[RepoHead]:
resp = self._request(f"/api/project/{project_name}/repos/list")
return parse_obj_as(List[RepoHead.__response__], resp.json())
- def get(self, project_name: str, repo_id: str, include_creds: bool) -> RepoHead:
- body = GetRepoRequest(repo_id=repo_id, include_creds=include_creds)
+ def get(
+ self, project_name: str, repo_id: str, include_creds: Optional[bool] = None
+ ) -> RepoHead:
+ if include_creds is not None:
+ self._logger.warning(
+ "`include_creds` argument is deprecated and has no effect, `get()` always returns"
+ " the repo without creds. Use `get_with_creds()` to get the repo with creds"
+ )
+ body = GetRepoRequest(repo_id=repo_id, include_creds=False)
resp = self._request(f"/api/project/{project_name}/repos/get", body=body.json())
return parse_obj_as(RepoHead.__response__, resp.json())
+ def get_with_creds(self, project_name: str, repo_id: str) -> RepoHeadWithCreds:
+ body = GetRepoRequest(repo_id=repo_id, include_creds=True)
+ resp = self._request(f"/api/project/{project_name}/repos/get", body=body.json())
+ return parse_obj_as(RepoHeadWithCreds.__response__, resp.json())
+
def init(
self,
project_name: str,
diff --git a/src/tests/_internal/core/models/test_configurations.py b/src/tests/_internal/core/models/test_configurations.py
index 0f081f615e..a4fbe8f4de 100644
--- a/src/tests/_internal/core/models/test_configurations.py
+++ b/src/tests/_internal/core/models/test_configurations.py
@@ -4,7 +4,11 @@
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.common import RegistryAuth
-from dstack._internal.core.models.configurations import parse_run_configuration
+from dstack._internal.core.models.configurations import (
+ DEFAULT_REPO_DIR,
+ RepoSpec,
+ parse_run_configuration,
+)
from dstack._internal.core.models.resources import Range
@@ -73,6 +77,49 @@ def test_shell_invalid(self):
parse_run_configuration(conf)
+class TestRepoSpec:
+ @pytest.mark.parametrize("value", [".", "rel/path", "/abs/path/"])
+ def test_parse_local_path_no_path(self, value: str):
+ assert RepoSpec.parse(value) == RepoSpec(local_path=value, path=DEFAULT_REPO_DIR)
+
+ @pytest.mark.parametrize(
+ ["value", "expected_repo_path"],
+ [[".:/repo", "."], ["rel/path:/repo", "rel/path"], ["/abs/path/:/repo", "/abs/path/"]],
+ )
+ def test_parse_local_path_with_path(self, value: str, expected_repo_path: str):
+ assert RepoSpec.parse(value) == RepoSpec(local_path=expected_repo_path, path="/repo")
+
+ def test_parse_windows_abs_local_path_no_path(self):
+ assert RepoSpec.parse("C:\\repo") == RepoSpec(local_path="C:\\repo", path=DEFAULT_REPO_DIR)
+
+ def test_parse_windows_abs_local_path_with_path(self):
+ assert RepoSpec.parse("C:\\repo:/repo") == RepoSpec(local_path="C:\\repo", path="/repo")
+
+ def test_parse_url_no_path(self):
+ assert RepoSpec.parse("https://example.com/repo.git") == RepoSpec(
+ url="https://example.com/repo.git", path=DEFAULT_REPO_DIR
+ )
+
+ def test_parse_url_with_path(self):
+ assert RepoSpec.parse("https://example.com/repo.git:/repo") == RepoSpec(
+ url="https://example.com/repo.git", path="/repo"
+ )
+
+ def test_parse_scp_no_path(self):
+ assert RepoSpec.parse("git@example.com:repo.git") == RepoSpec(
+ url="git@example.com:repo.git", path=DEFAULT_REPO_DIR
+ )
+
+ def test_parse_scp_with_path(self):
+ assert RepoSpec.parse("git@example.com:repo.git:/repo") == RepoSpec(
+ url="git@example.com:repo.git", path="/repo"
+ )
+
+ def test_error_invalid_mapping_if_more_than_two_parts(self):
+ with pytest.raises(ValueError, match="Invalid repo"):
+ RepoSpec.parse("./foo:bar:baz")
+
+
def test_registry_auth_hashable():
"""
RegistryAuth instances should be hashable
diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py
index 945e039495..4438b61dce 100644
--- a/src/tests/_internal/server/routers/test_runs.py
+++ b/src/tests/_internal/server/routers/test_runs.py
@@ -154,6 +154,7 @@ def get_dev_env_run_plan_dict(
"shm_size": None,
},
"volumes": [json.loads(v.json()) for v in volumes],
+ "repos": [],
"files": [],
"backends": ["local", "aws", "azure", "gcp", "lambda", "runpod"],
"regions": ["us"],
@@ -358,6 +359,7 @@ def get_dev_env_run_dict(
"shm_size": None,
},
"volumes": [],
+ "repos": [],
"files": [],
"backends": ["local", "aws", "azure", "gcp", "lambda"],
"regions": ["us"],