diff --git a/src/dstack/_internal/cli/commands/apply.py b/src/dstack/_internal/cli/commands/apply.py index 3ab26c5f4d..dba99f08ff 100644 --- a/src/dstack/_internal/cli/commands/apply.py +++ b/src/dstack/_internal/cli/commands/apply.py @@ -1,5 +1,4 @@ import argparse -from pathlib import Path from argcomplete import FilesCompleter # type: ignore[attr-defined] @@ -9,12 +8,7 @@ get_apply_configurator_class, load_apply_configuration, ) -from dstack._internal.cli.services.repos import ( - init_default_virtual_repo, - init_repo, - register_init_repo_args, -) -from dstack._internal.cli.utils.common import console, warn +from dstack._internal.cli.utils.common import console from dstack._internal.core.errors import CLIError from dstack._internal.core.models.configurations import ApplyConfigurationType @@ -66,37 +60,6 @@ def _register(self): help="Exit immediately after submitting configuration", action="store_true", ) - self._parser.add_argument( - "--ssh-identity", - metavar="SSH_PRIVATE_KEY", - help="The private SSH key path for SSH tunneling", - type=Path, - dest="ssh_identity_file", - ) - repo_group = self._parser.add_argument_group("Repo Options") - repo_group.add_argument( - "-P", - "--repo", - help=("The repo to use for the run. Can be a local path or a Git repo URL."), - dest="repo", - ) - repo_group.add_argument( - "--repo-branch", - help="The repo branch to use for the run", - dest="repo_branch", - ) - repo_group.add_argument( - "--repo-hash", - help="The hash of the repo commit to use for the run", - dest="repo_hash", - ) - repo_group.add_argument( - "--no-repo", - help="Do not use any repo for the run", - dest="no_repo", - action="store_true", - ) - register_init_repo_args(repo_group) def _command(self, args: argparse.Namespace): try: @@ -117,26 +80,6 @@ def _command(self, args: argparse.Namespace): super()._command(args) if not args.yes and args.configuration_file == APPLY_STDIN_NAME: raise CLIError("Cannot read configuration from stdin if -y/--yes is not specified") - if args.repo and args.no_repo: - raise CLIError("Either --repo or --no-repo can be specified") - if args.local: - warn( - "Local repos are deprecated since 0.19.25 and will be removed soon." - " Consider using `files` instead: https://dstack.ai/docs/concepts/tasks/#files" - ) - repo = None - if args.repo: - repo = init_repo( - api=self.api, - repo_path=args.repo, - repo_branch=args.repo_branch, - repo_hash=args.repo_hash, - local=args.local, - git_identity_file=args.git_identity_file, - oauth_token=args.gh_token, - ) - elif args.no_repo: - repo = init_default_virtual_repo(api=self.api) configuration_path, configuration = load_apply_configuration(args.configuration_file) configurator_class = get_apply_configurator_class(configuration.type) configurator = configurator_class(api_client=self.api) @@ -148,7 +91,6 @@ def _command(self, args: argparse.Namespace): command_args=args, configurator_args=known, unknown_args=unknown, - repo=repo, ) except KeyboardInterrupt: console.print("\nOperation interrupted by user. Exiting...") diff --git a/src/dstack/_internal/cli/commands/init.py b/src/dstack/_internal/cli/commands/init.py index 0076fc6ac1..1aab72d405 100644 --- a/src/dstack/_internal/cli/commands/init.py +++ b/src/dstack/_internal/cli/commands/init.py @@ -1,12 +1,17 @@ import argparse import os from pathlib import Path +from typing import Optional from dstack._internal.cli.commands import BaseCommand -from dstack._internal.cli.services.repos import init_repo, register_init_repo_args +from dstack._internal.cli.services.repos import ( + get_repo_from_dir, + get_repo_from_url, + is_git_repo_url, + register_init_repo_args, +) from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn from dstack._internal.core.errors import ConfigurationError -from dstack._internal.core.models.repos.base import RepoType from dstack._internal.core.services.configs import ConfigManager from dstack.api import Client @@ -21,6 +26,15 @@ def _register(self): help="The name of the project", default=os.getenv("DSTACK_PROJECT"), ) + self._parser.add_argument( + "-P", + "--repo", + help=( + "The repo to initialize. Can be a local path or a Git repo URL." + " Defaults to the current working directory." + ), + dest="repo", + ) register_init_repo_args(self._parser) # Deprecated since 0.19.25, ignored self._parser.add_argument( @@ -30,7 +44,7 @@ def _register(self): type=Path, dest="ssh_identity_file", ) - # A hidden mode for transitional period only, remove it with local repos + # A hidden mode for transitional period only, remove it with repos in `config.yml` self._parser.add_argument( "--remove", help=argparse.SUPPRESS, @@ -39,44 +53,62 @@ def _register(self): def _command(self, args: argparse.Namespace): configure_logging() + + repo_path: Optional[Path] = None + repo_url: Optional[str] = None + repo_arg: Optional[str] = args.repo + if repo_arg is not None: + if is_git_repo_url(repo_arg): + repo_url = repo_arg + else: + repo_path = Path(repo_arg).expanduser().resolve() + else: + repo_path = Path.cwd() + if args.remove: + if repo_url is not None: + raise ConfigurationError(f"Local path expected, got URL: {repo_url}") + assert repo_path is not None config_manager = ConfigManager() - repo_path = Path.cwd() repo_config = config_manager.get_repo_config(repo_path) if repo_config is None: - raise ConfigurationError("The repo is not initialized, nothing to remove") - if repo_config.repo_type != RepoType.LOCAL: - raise ConfigurationError("`dstack init --remove` is for local repos only") + raise ConfigurationError("Repo record not found, nothing to remove") console.print( - f"You are about to remove the local repo {repo_path}\n" + f"You are about to remove the repo {repo_path}\n" "Only the record about the repo will be removed," " the repo files will remain intact\n" ) - if not confirm_ask("Remove the local repo?"): + if not confirm_ask("Remove the repo?"): return config_manager.delete_repo_config(repo_config.repo_id) config_manager.save() - console.print("Local repo has been removed") + console.print("Repo has been removed") return - api = Client.from_config( - project_name=args.project, ssh_identity_file=args.ssh_identity_file - ) - if args.local: + + local: bool = args.local + if local: warn( - "Local repos are deprecated since 0.19.25 and will be removed soon." - " Consider using `files` instead: https://dstack.ai/docs/concepts/tasks/#files" + "Local repos are deprecated since 0.19.25 and will be removed soon. Consider" + " using [code]files[/code] instead: https://dstack.ai/docs/concepts/tasks/#files" ) if args.ssh_identity_file: warn( - "`--ssh-identity` in `dstack init` is deprecated and ignored since 0.19.25." - " Use this option with `dstack apply` and `dstack attach` instead" + "[code]--ssh-identity[/code] in [code]dstack init[/code] is deprecated and ignored" + " since 0.19.25. Use this option with [code]dstack apply[/code]" + " and [code]dstack attach[/code] instead" ) - init_repo( - api=api, - repo_path=Path.cwd(), - repo_branch=None, - repo_hash=None, - local=args.local, + + if repo_url is not None: + # Dummy repo branch to avoid autodetection that fails on private repos. + # We don't need branch/hash for repo_id anyway. + repo = get_repo_from_url(repo_url, repo_branch="master") + elif repo_path is not None: + repo = get_repo_from_dir(repo_path, local=local) + else: + assert False, "should not reach here" + api = Client.from_config(project_name=args.project) + api.repos.init( + repo=repo, git_identity_file=args.git_identity_file, oauth_token=args.gh_token, ) diff --git a/src/dstack/_internal/cli/services/configurators/base.py b/src/dstack/_internal/cli/services/configurators/base.py index 440a31d6c2..35414801ee 100644 --- a/src/dstack/_internal/cli/services/configurators/base.py +++ b/src/dstack/_internal/cli/services/configurators/base.py @@ -1,7 +1,7 @@ import argparse import os from abc import ABC, abstractmethod -from typing import Generic, List, Optional, TypeVar, Union, cast +from typing import Generic, List, TypeVar, Union, cast from dstack._internal.cli.services.args import env_var from dstack._internal.core.errors import ConfigurationError @@ -10,7 +10,6 @@ ApplyConfigurationType, ) from dstack._internal.core.models.envs import Env, EnvSentinel, EnvVarTuple -from dstack._internal.core.models.repos.base import Repo from dstack.api._public import Client ArgsParser = Union[argparse._ArgumentGroup, argparse.ArgumentParser] @@ -32,7 +31,6 @@ def apply_configuration( command_args: argparse.Namespace, configurator_args: argparse.Namespace, unknown_args: List[str], - repo: Optional[Repo] = None, ): """ Implements `dstack apply` for a given configuration type. @@ -43,7 +41,6 @@ def apply_configuration( command_args: The args parsed by `dstack apply`. configurator_args: The known args parsed by `cls.get_parser()`. unknown_args: The unknown args after parsing by `cls.get_parser()`. - repo: The repo to use with apply. """ pass diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index 6718f4f0f2..a250058bc3 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -35,7 +35,6 @@ InstanceGroupPlacement, ) from dstack._internal.core.models.instances import InstanceAvailability, InstanceStatus, SSHKey -from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.services.diff import diff_models from dstack._internal.utils.common import local_time from dstack._internal.utils.logging import get_logger @@ -56,7 +55,6 @@ def apply_configuration( command_args: argparse.Namespace, configurator_args: argparse.Namespace, unknown_args: List[str], - repo: Optional[Repo] = None, ): self.apply_args(conf, configurator_args, unknown_args) profile = load_profile(Path.cwd(), None) diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py index 8a22277b17..7d26e220ba 100644 --- a/src/dstack/_internal/cli/services/configurators/gateway.py +++ b/src/dstack/_internal/cli/services/configurators/gateway.py @@ -1,6 +1,6 @@ import argparse import time -from typing import List, Optional +from typing import List from rich.table import Table @@ -21,7 +21,6 @@ GatewaySpec, GatewayStatus, ) -from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.services.diff import diff_models from dstack._internal.utils.common import local_time from dstack.api._public import Client @@ -37,7 +36,6 @@ def apply_configuration( command_args: argparse.Namespace, configurator_args: argparse.Namespace, unknown_args: List[str], - repo: Optional[Repo] = None, ): self.apply_args(conf, configurator_args, unknown_args) spec = GatewaySpec( diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 02f02f7b42..4a4a4d453b 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -15,12 +15,14 @@ BaseApplyConfigurator, ) from dstack._internal.cli.services.profile import apply_profile_args, register_profile_args -from dstack._internal.cli.services.repos import init_default_virtual_repo -from dstack._internal.cli.utils.common import ( - confirm_ask, - console, - warn, +from dstack._internal.cli.services.repos import ( + get_repo_from_dir, + get_repo_from_url, + init_default_virtual_repo, + is_git_repo_url, + register_init_repo_args, ) +from dstack._internal.cli.utils.common import confirm_ask, console, warn from dstack._internal.cli.utils.rich import MultiItemStatus from dstack._internal.cli.utils.run import get_runs_table, print_run_plan from dstack._internal.core.errors import ( @@ -46,6 +48,7 @@ from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.diff import diff_models +from dstack._internal.core.services.repos import load_repo from dstack._internal.utils.common import local_time from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator from dstack._internal.utils.logging import get_logger @@ -78,42 +81,18 @@ def apply_configuration( command_args: argparse.Namespace, configurator_args: argparse.Namespace, unknown_args: List[str], - repo: Optional[Repo] = None, ): + if configurator_args.repo and configurator_args.no_repo: + raise CLIError("Either --repo or --no-repo can be specified") + self.apply_args(conf, configurator_args, unknown_args) self.validate_gpu_vendor_and_image(conf) self.validate_cpu_arch_and_image(conf) + config_manager = ConfigManager() - if repo is None: - repo_path = Path.cwd() - repo_config = config_manager.get_repo_config(repo_path) - if repo_config is None: - warn( - "Repo is not initialized. " - "Use [code]--repo [/code] or [code]--no-repo[/code] to initialize it.\n" - "Starting from 0.19.26, repos will be configured via YAML and this message won't appear." - ) - if not command_args.yes and not confirm_ask("Continue without the repo?"): - console.print("\nExiting...") - return - repo = init_default_virtual_repo(self.api) - else: - # Unlikely, but may raise ConfigurationError if the repo does not exist - # on the server side (stale entry in `config.yml`) - repo = self.api.repos.load(repo_path) - if isinstance(repo, LocalRepo): - warn( - f"{repo.repo_dir} is a local repo.\n" - "Local repos are deprecated since 0.19.25" - " and will be removed soon\n" - "There are two options:\n" - " - Migrate to `files`: https://dstack.ai/docs/concepts/tasks/#files\n" - " - Specify `--no-repo` if you don't need the repo at all\n" - "In either case, you can run `dstack init --remove` to remove the repo" - " (only the record about the repo, not its files) and this warning" - ) + repo = self.get_repo(conf, configuration_path, configurator_args, config_manager) self.api.ssh_identity_file = get_ssh_keypair( - command_args.ssh_identity_file, + configurator_args.ssh_identity_file, config_manager.dstack_key_path, ) profile = load_profile(Path.cwd(), configurator_args.profile) @@ -298,6 +277,13 @@ def delete_configuration( @classmethod def register_args(cls, parser: argparse.ArgumentParser): + parser.add_argument( + "--ssh-identity", + metavar="SSH_PRIVATE_KEY", + help="The private SSH key path for SSH tunneling", + type=Path, + dest="ssh_identity_file", + ) configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options") configuration_group.add_argument( "-n", @@ -336,6 +322,30 @@ def register_args(cls, parser: argparse.ArgumentParser): dest="disk_spec", ) register_profile_args(parser) + repo_group = parser.add_argument_group("Repo Options") + repo_group.add_argument( + "-P", + "--repo", + help=("The repo to use for the run. Can be a local path or a Git repo URL."), + dest="repo", + ) + repo_group.add_argument( + "--repo-branch", + help="The repo branch to use for the run", + dest="repo_branch", + ) + repo_group.add_argument( + "--repo-hash", + help="The hash of the repo commit to use for the run", + dest="repo_hash", + ) + repo_group.add_argument( + "--no-repo", + help="Do not use any repo for the run", + dest="no_repo", + action="store_true", + ) + register_init_repo_args(repo_group) def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace, unknown: List[str]): apply_profile_args(args, conf) @@ -465,6 +475,118 @@ def validate_cpu_arch_and_image(self, conf: RunConfigurationT) -> None: if arch == gpuhunt.CPUArchitecture.ARM and conf.image is None: raise ConfigurationError("`image` is required if `resources.cpu.arch` is `arm`") + def get_repo( + self, + conf: RunConfigurationT, + configuration_path: str, + configurator_args: argparse.Namespace, + config_manager: ConfigManager, + ) -> Repo: + if configurator_args.no_repo: + return init_default_virtual_repo(api=self.api) + + repo: Optional[Repo] = None + repo_branch: Optional[str] = configurator_args.repo_branch + repo_hash: Optional[str] = configurator_args.repo_hash + # Should we (re)initialize the repo? + # If any Git credentials provided, we reinitialize the repo, as the user may have provided + # updated credentials. + init = ( + configurator_args.git_identity_file is not None + or configurator_args.gh_token is not None + ) + + url: Optional[str] = None + local_path: Optional[Path] = None + # dummy value, safe to join with any path + root_dir = Path(".") + # True if no repo specified, but we found one in `config.yml` + legacy_local_path = False + if repo_arg := configurator_args.repo: + if is_git_repo_url(repo_arg): + url = repo_arg + else: + local_path = Path(repo_arg) + # rel paths in `--repo` are resolved relative to the current working dir + root_dir = Path.cwd() + elif conf.repos: + repo_spec = conf.repos[0] + if repo_spec.url: + url = repo_spec.url + elif repo_spec.local_path: + local_path = Path(repo_spec.local_path) + # rel paths in the conf are resolved relative to the conf's parent dir + root_dir = Path(configuration_path).resolve().parent + else: + assert False, f"should not reach here: {repo_spec}" + if repo_branch is None: + repo_branch = repo_spec.branch + if repo_hash is None: + repo_hash = repo_spec.hash + else: + local_path = Path.cwd() + legacy_local_path = True + if url: + repo = get_repo_from_url(repo_url=url, repo_branch=repo_branch, repo_hash=repo_hash) + if not self.api.repos.is_initialized(repo, by_user=True): + init = True + elif local_path: + if legacy_local_path: + if repo_config := config_manager.get_repo_config(local_path): + repo = load_repo(repo_config) + # allow users with legacy configurations use shared repo creds + if self.api.repos.is_initialized(repo, by_user=False): + warn( + "The repo is not specified but found and will be used in the run\n" + "Future versions will not load repos automatically\n" + "To prepare for future versions and get rid of this warning:\n" + "- If you need the repo in the run, either specify [code]repos[/code]" + " in the configuration or use [code]--repo .[/code]\n" + "- If you don't need the repo in the run, either run" + " [code]dstack init --remove[/code] once (it removes only the record" + " about the repo, the repo files will remain intact)" + " or use [code]--no-repo[/code]" + ) + else: + # ignore stale entries in `config.yml` + repo = None + init = False + else: + original_local_path = local_path + local_path = local_path.expanduser() + if not local_path.is_absolute(): + local_path = (root_dir / local_path).resolve() + if not local_path.exists(): + raise ConfigurationError( + f"Invalid repo path: {original_local_path} -> {local_path}" + ) + local: bool = configurator_args.local + repo = get_repo_from_dir(local_path, local=local) + if not self.api.repos.is_initialized(repo, by_user=True): + init = True + else: + assert False, "should not reach here" + + if repo is None: + return init_default_virtual_repo(api=self.api) + + if init: + self.api.repos.init( + repo=repo, + git_identity_file=configurator_args.git_identity_file, + oauth_token=configurator_args.gh_token, + ) + if isinstance(repo, LocalRepo): + warn( + f"{repo.repo_dir} is a local repo\n" + "Local repos are deprecated since 0.19.25 and will be removed soon\n" + "There are two options:\n" + "- Migrate to [code]files[/code]: https://dstack.ai/docs/concepts/tasks/#files\n" + "- Specify [code]--no-repo[/code] if you don't need the repo at all" + ) + + return repo + class RunWithPortsConfiguratorMixin: @classmethod diff --git a/src/dstack/_internal/cli/services/configurators/volume.py b/src/dstack/_internal/cli/services/configurators/volume.py index 72b21e5bb4..b0e25e503c 100644 --- a/src/dstack/_internal/cli/services/configurators/volume.py +++ b/src/dstack/_internal/cli/services/configurators/volume.py @@ -1,6 +1,6 @@ import argparse import time -from typing import List, Optional +from typing import List from rich.table import Table @@ -14,7 +14,6 @@ from dstack._internal.cli.utils.volume import get_volumes_table from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.configurations import ApplyConfigurationType -from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.volumes import ( Volume, VolumeConfiguration, @@ -36,7 +35,6 @@ def apply_configuration( command_args: argparse.Namespace, configurator_args: argparse.Namespace, unknown_args: List[str], - repo: Optional[Repo] = None, ): self.apply_args(conf, configurator_args, unknown_args) spec = VolumeSpec( diff --git a/src/dstack/_internal/cli/services/repos.py b/src/dstack/_internal/cli/services/repos.py index 5e7a10589f..0abc9d9ef4 100644 --- a/src/dstack/_internal/cli/services/repos.py +++ b/src/dstack/_internal/cli/services/repos.py @@ -1,10 +1,11 @@ import argparse -from pathlib import Path -from typing import Optional +from typing import Literal, Optional, Union, overload + +import git from dstack._internal.cli.services.configurators.base import ArgsParser from dstack._internal.core.errors import CLIError -from dstack._internal.core.models.repos.base import Repo +from dstack._internal.core.models.repos.local import LocalRepo from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError from dstack._internal.core.models.repos.virtual import VirtualRepo from dstack._internal.core.services.repos import get_default_branch @@ -36,51 +37,54 @@ def register_init_repo_args(parser: ArgsParser): ) -def init_repo( - api: Client, - repo_path: PathLike, - repo_branch: Optional[str], - repo_hash: Optional[str], - local: bool, - git_identity_file: Optional[PathLike], - oauth_token: Optional[str], -) -> Repo: - if Path(repo_path).exists(): - repo = api.repos.load( - repo_dir=repo_path, - local=local, - init=True, - git_identity_file=git_identity_file, - oauth_token=oauth_token, - ) - elif isinstance(repo_path, str): - try: - GitRepoURL.parse(repo_path) - except RepoError as e: - raise CLIError("Invalid repo path") from e - if repo_branch is None and repo_hash is None: - repo_branch = get_default_branch(repo_path) - if repo_branch is None: - raise CLIError( - "Failed to automatically detect remote repo branch." - " Specify --repo-branch or --repo-hash." - ) - repo = RemoteRepo.from_url( - repo_url=repo_path, - repo_branch=repo_branch, - repo_hash=repo_hash, - ) - api.repos.init( - repo=repo, - git_identity_file=git_identity_file, - oauth_token=oauth_token, - ) - else: - raise CLIError("Invalid repo path") - return repo - - def init_default_virtual_repo(api: Client) -> VirtualRepo: repo = VirtualRepo() api.repos.init(repo) return repo + + +def get_repo_from_url( + repo_url: str, repo_branch: Optional[str] = None, repo_hash: Optional[str] = None +) -> RemoteRepo: + if repo_branch is None and repo_hash is None: + repo_branch = get_default_branch(repo_url) + if repo_branch is None: + raise CLIError( + "Failed to automatically detect remote repo branch. Specify branch or hash." + ) + return RemoteRepo.from_url( + repo_url=repo_url, + repo_branch=repo_branch, + repo_hash=repo_hash, + ) + + +@overload +def get_repo_from_dir(repo_dir: PathLike, local: Literal[False] = False) -> RemoteRepo: ... + + +@overload +def get_repo_from_dir(repo_dir: PathLike, local: Literal[True]) -> LocalRepo: ... + + +def get_repo_from_dir(repo_dir: PathLike, local: bool = False) -> Union[RemoteRepo, LocalRepo]: + if local: + return LocalRepo.from_dir(repo_dir) + try: + return RemoteRepo.from_dir(repo_dir) + except git.InvalidGitRepositoryError: + raise CLIError( + f"Git repo not found: {repo_dir}\n" + "Use `files` to mount an arbitrary directory:" + " https://dstack.ai/docs/concepts/tasks/#files" + ) + except RepoError as e: + raise CLIError(str(e)) from e + + +def is_git_repo_url(value: str) -> bool: + try: + GitRepoURL.parse(value) + except RepoError: + return False + return True diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index f58be55e19..cfb809acb5 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -136,6 +136,7 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: configuration_excludes["schedule"] = True if profile is not None and profile.schedule is None: profile_excludes.add("schedule") + configuration_excludes["repos"] = True if configuration_excludes: spec_excludes["configuration"] = configuration_excludes diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 4b92d6f82b..a72792da24 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -1,12 +1,13 @@ import re +import string from collections import Counter from enum import Enum from pathlib import PurePosixPath -from typing import Any, Dict, List, Optional, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Union import orjson from pydantic import Field, ValidationError, conint, constr, root_validator, validator -from typing_extensions import Annotated, Literal +from typing_extensions import Self from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth @@ -83,6 +84,72 @@ def parse(cls, v: str) -> "PortMapping": return PortMapping(local_port=local_port, container_port=int(container_port)) +class RepoSpec(CoreModel): + local_path: Annotated[ + Optional[str], + Field( + description=( + "The path to the Git repo on the user's machine. Relative paths are resolved" + " relative to the parent directory of the the configuration file." + " Mutually exclusive with `url`" + ) + ), + ] = None + url: Annotated[ + Optional[str], + Field(description="The Git repo URL. Mutually exclusive with `local_path`"), + ] = None + branch: Annotated[ + Optional[str], + Field( + description=( + "The repo branch. Defaults to the active branch for local paths" + " and the default branch for URLs" + ) + ), + ] = None + hash: Annotated[ + Optional[str], + Field(description="The commit hash"), + ] = None + # Not implemented, has no effect, hidden in the docs + path: str = DEFAULT_REPO_DIR + + @classmethod + def parse(cls, v: str) -> Self: + is_url = False + parts = v.split(":") + if len(parts) > 1: + # Git repo, git@github.com:dstackai/dstack.git or https://github.com/dstackai/dstack + if "@" in parts[0] or parts[1].startswith("//"): + parts = [f"{parts[0]}:{parts[1]}", *parts[2:]] + is_url = True + # Windows path, e.g., `C:\path\to`, 'c:/path/to' + elif ( + len(parts[0]) == 1 + and parts[0] in string.ascii_letters + and parts[1][:1] in ["\\", "/"] + ): + parts = [f"{parts[0]}:{parts[1]}", *parts[2:]] + if len(parts) == 1: + if is_url: + return cls(url=parts[0]) + return cls(local_path=parts[0]) + if len(parts) == 2: + if is_url: + return cls(url=parts[0], path=parts[1]) + return cls(local_path=parts[0], path=parts[1]) + raise ValueError(f"Invalid repo: {v}") + + @root_validator + def validate_local_path_or_url(cls, values): + if values["local_path"] and values["url"]: + raise ValueError("`local_path` and `url` are mutually exclusive") + if not values["local_path"] and not values["url"]: + raise ValueError("Either `local_path` or `url` must be specified") + return values + + class ScalingSpec(CoreModel): metric: Annotated[ Literal["rps"], @@ -392,6 +459,10 @@ class BaseRunConfiguration(CoreModel): description="Use Docker inside the container. Mutually exclusive with `image`, `python`, and `nvcc`. Overrides `privileged`" ), ] = None + repos: Annotated[ + list[RepoSpec], + Field(description="The list of Git repos"), + ] = [] files: Annotated[ list[FilePathMapping], Field(description="The local to container file path mappings"), @@ -447,6 +518,18 @@ def convert_files(cls, v: Union[FilePathMapping, str]) -> FilePathMapping: return FilePathMapping.parse(v) return v + @validator("repos", pre=True, each_item=True) + def convert_repos(cls, v: Union[RepoSpec, str]) -> RepoSpec: + if isinstance(v, str): + return RepoSpec.parse(v) + return v + + @validator("repos") + def validate_repos(cls, v) -> RepoSpec: + if len(v) > 1: + raise ValueError("A maximum of one repo is currently supported") + return v + @validator("user") def validate_user(cls, v) -> Optional[str]: if v is None: diff --git a/src/dstack/_internal/core/models/files.py b/src/dstack/_internal/core/models/files.py index f2e4f6826d..797c64da23 100644 --- a/src/dstack/_internal/core/models/files.py +++ b/src/dstack/_internal/core/models/files.py @@ -28,7 +28,7 @@ class FilePathMapping(CoreModel): Field( description=( "The path in the container. Relative paths are resolved relative to" - " the repo directory (`/workflow`)" + " the repo directory" ) ), ] diff --git a/src/dstack/_internal/core/services/repos.py b/src/dstack/_internal/core/services/repos.py index 61ff9b3abb..a054519fd1 100644 --- a/src/dstack/_internal/core/services/repos.py +++ b/src/dstack/_internal/core/services/repos.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Optional, Union -import git +import git.cmd import requests import yaml from git.exc import GitCommandError @@ -24,6 +24,8 @@ gh_config_path = os.path.expanduser("~/.config/gh/hosts.yml") default_ssh_key = os.path.expanduser("~/.ssh/id_rsa") +no_prompt_env = dict(GIT_TERMINAL_PROMPT="0") + class InvalidRepoCredentialsError(DstackError): pass @@ -84,7 +86,7 @@ def get_local_repo_credentials( def check_remote_repo_credentials_https(url: GitRepoURL, oauth_token: str) -> RemoteRepoCreds: try: - git.cmd.Git().ls_remote(url.as_https(oauth_token), env=dict(GIT_TERMINAL_PROMPT="0")) # type: ignore[attr-defined] + git.cmd.Git().ls_remote(url.as_https(oauth_token), env=no_prompt_env) except GitCommandError: masked = len(oauth_token[:-4]) * "*" + oauth_token[-4:] raise InvalidRepoCredentialsError( @@ -111,7 +113,7 @@ def check_remote_repo_credentials_ssh(url: GitRepoURL, identity_file: PathLike) private_key = f.read() try: - git.cmd.Git().ls_remote( # type: ignore[attr-defined] + git.cmd.Git().ls_remote( url.as_ssh(), env=dict(GIT_SSH_COMMAND=make_ssh_command_for_git(identity_file)) ) except GitCommandError: @@ -131,7 +133,7 @@ def get_default_branch(remote_url: str) -> Optional[str]: Get the default branch of a remote Git repository. """ try: - output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD") # type: ignore[attr-defined] + output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD", env=no_prompt_env) for line in output.splitlines(): if line.startswith("ref:"): return line.split()[1].split("/")[-1] diff --git a/src/dstack/api/_public/repos.py b/src/dstack/api/_public/repos.py index 9655e48f91..c05f80de75 100644 --- a/src/dstack/api/_public/repos.py +++ b/src/dstack/api/_public/repos.py @@ -68,6 +68,7 @@ def init( """ creds = None if isinstance(repo, RemoteRepo): + assert repo.repo_url is not None try: creds = get_local_repo_credentials( repo_url=repo.repo_url, @@ -140,22 +141,40 @@ def load( def is_initialized( self, repo: Repo, + by_user: bool = False, ) -> bool: """ - Checks if the remote repo is initialized in the project + Checks if the repo is initialized in the project Args: repo: The repo to check. + by_user: Require the remote repo to be initialized by the user, that is, to have + the user's credentials. Ignored for other repo types. Returns: Whether the repo is initialized or not. """ + if isinstance(repo, RemoteRepo) and by_user: + return self._is_initialized_by_user(repo) try: - self._api_client.repos.get(self._project, repo.repo_id, include_creds=False) + self._api_client.repos.get(self._project, repo.repo_id) return True except ResourceNotExistsError: return False + def _is_initialized_by_user(self, repo: RemoteRepo) -> bool: + try: + repo_head = self._api_client.repos.get_with_creds(self._project, repo.repo_id) + except ResourceNotExistsError: + return False + # This works because: + # - RepoCollection.init() always submits RemoteRepoCreds for remote repos, even if + # the repo is public + # - Server returns creds only if there is RepoCredsModel for the user (or legacy + # shared creds in RepoModel) + # TODO: add an API method with the same logic returning a bool value? + return repo_head.repo_creds is not None + def get_ssh_keypair(key_path: Optional[PathLike], dstack_key_path: Path) -> str: """Returns a path to the private key""" diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 473c462139..8a87fd879f 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -23,7 +23,7 @@ PortMapping, ServiceConfiguration, ) -from dstack._internal.core.models.files import FileArchiveMapping, FilePathMapping +from dstack._internal.core.models.files import FileArchiveMapping from dstack._internal.core.models.profiles import ( CreationPolicy, Profile, @@ -499,7 +499,6 @@ def apply_plan( self._validate_configuration_files(configuration, run_spec.configuration_path) for file_mapping in configuration.files: - assert isinstance(file_mapping, FilePathMapping) with tempfile.TemporaryFile("w+b") as fp: try: archive_hash = create_file_archive(file_mapping.local_path, fp) diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index 1cdbd5e7cb..ce0328c2ef 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -27,9 +27,6 @@ from dstack.api.server._users import UsersAPIClient from dstack.api.server._volumes import VolumesAPIClient -logger = get_logger(__name__) - - _MAX_RETRIES = 3 _RETRY_INTERVAL = 1 @@ -66,6 +63,7 @@ def __init__(self, base_url: str, token: str): client_api_version = os.getenv("DSTACK_CLIENT_API_VERSION", version.__version__) if client_api_version is not None: self._s.headers.update({"X-API-VERSION": client_api_version}) + self._logger = get_logger(__name__) @property def base_url(self) -> str: @@ -73,55 +71,55 @@ def base_url(self) -> str: @property def users(self) -> UsersAPIClient: - return UsersAPIClient(self._request) + return UsersAPIClient(self._request, self._logger) @property def projects(self) -> ProjectsAPIClient: - return ProjectsAPIClient(self._request) + return ProjectsAPIClient(self._request, self._logger) @property def backends(self) -> BackendsAPIClient: - return BackendsAPIClient(self._request) + return BackendsAPIClient(self._request, self._logger) @property def fleets(self) -> FleetsAPIClient: - return FleetsAPIClient(self._request) + return FleetsAPIClient(self._request, self._logger) @property def repos(self) -> ReposAPIClient: - return ReposAPIClient(self._request) + return ReposAPIClient(self._request, self._logger) @property def runs(self) -> RunsAPIClient: - return RunsAPIClient(self._request) + return RunsAPIClient(self._request, self._logger) @property def gpus(self) -> GpusAPIClient: - return GpusAPIClient(self._request) + return GpusAPIClient(self._request, self._logger) @property def metrics(self) -> MetricsAPIClient: - return MetricsAPIClient(self._request) + return MetricsAPIClient(self._request, self._logger) @property def logs(self) -> LogsAPIClient: - return LogsAPIClient(self._request) + return LogsAPIClient(self._request, self._logger) @property def secrets(self) -> SecretsAPIClient: - return SecretsAPIClient(self._request) + return SecretsAPIClient(self._request, self._logger) @property def gateways(self) -> GatewaysAPIClient: - return GatewaysAPIClient(self._request) + return GatewaysAPIClient(self._request, self._logger) @property def volumes(self) -> VolumesAPIClient: - return VolumesAPIClient(self._request) + return VolumesAPIClient(self._request, self._logger) @property def files(self) -> FilesAPIClient: - return FilesAPIClient(self._request) + return FilesAPIClient(self._request, self._logger) def _request( self, @@ -136,20 +134,20 @@ def _request( kwargs.setdefault("headers", {})["Content-Type"] = "application/json" kwargs["data"] = body - logger.debug("POST /%s", path) + self._logger.debug("POST /%s", path) for _ in range(_MAX_RETRIES): try: # TODO: set adequate timeout here or everywhere the method is used resp = self._s.request(method, f"{self._base_url}/{path}", **kwargs) break except requests.exceptions.ConnectionError as e: - logger.debug("Could not connect to server: %s", e) + self._logger.debug("Could not connect to server: %s", e) time.sleep(_RETRY_INTERVAL) else: raise ClientError(f"Failed to connect to dstack server {self._base_url}") if 400 <= resp.status_code < 600: - logger.debug( + self._logger.debug( "Error requesting %s. Status: %s. Headers: %s. Body: %s", resp.request.url, resp.status_code, diff --git a/src/dstack/api/server/_group.py b/src/dstack/api/server/_group.py index b893647f71..9d3ec1918a 100644 --- a/src/dstack/api/server/_group.py +++ b/src/dstack/api/server/_group.py @@ -1,3 +1,4 @@ +from logging import Logger from typing import Optional import requests @@ -12,10 +13,10 @@ def __call__( raise_for_status: bool = True, method: str = "POST", **kwargs, - ) -> requests.Response: - pass + ) -> requests.Response: ... class APIClientGroup: - def __init__(self, _request: APIRequest): + def __init__(self, _request: APIRequest, _logger: Logger): self._request = _request + self._logger = _logger diff --git a/src/dstack/api/server/_repos.py b/src/dstack/api/server/_repos.py index a827d7e64a..03f9eb9eab 100644 --- a/src/dstack/api/server/_repos.py +++ b/src/dstack/api/server/_repos.py @@ -2,7 +2,12 @@ from pydantic import parse_obj_as -from dstack._internal.core.models.repos import AnyRepoInfo, RemoteRepoCreds, RepoHead +from dstack._internal.core.models.repos import ( + AnyRepoInfo, + RemoteRepoCreds, + RepoHead, + RepoHeadWithCreds, +) from dstack._internal.server.schemas.repos import ( DeleteReposRequest, GetRepoRequest, @@ -16,11 +21,23 @@ def list(self, project_name: str) -> List[RepoHead]: resp = self._request(f"/api/project/{project_name}/repos/list") return parse_obj_as(List[RepoHead.__response__], resp.json()) - def get(self, project_name: str, repo_id: str, include_creds: bool) -> RepoHead: - body = GetRepoRequest(repo_id=repo_id, include_creds=include_creds) + def get( + self, project_name: str, repo_id: str, include_creds: Optional[bool] = None + ) -> RepoHead: + if include_creds is not None: + self._logger.warning( + "`include_creds` argument is deprecated and has no effect, `get()` always returns" + " the repo without creds. Use `get_with_creds()` to get the repo with creds" + ) + body = GetRepoRequest(repo_id=repo_id, include_creds=False) resp = self._request(f"/api/project/{project_name}/repos/get", body=body.json()) return parse_obj_as(RepoHead.__response__, resp.json()) + def get_with_creds(self, project_name: str, repo_id: str) -> RepoHeadWithCreds: + body = GetRepoRequest(repo_id=repo_id, include_creds=True) + resp = self._request(f"/api/project/{project_name}/repos/get", body=body.json()) + return parse_obj_as(RepoHeadWithCreds.__response__, resp.json()) + def init( self, project_name: str, diff --git a/src/tests/_internal/core/models/test_configurations.py b/src/tests/_internal/core/models/test_configurations.py index 0f081f615e..a4fbe8f4de 100644 --- a/src/tests/_internal/core/models/test_configurations.py +++ b/src/tests/_internal/core/models/test_configurations.py @@ -4,7 +4,11 @@ from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.common import RegistryAuth -from dstack._internal.core.models.configurations import parse_run_configuration +from dstack._internal.core.models.configurations import ( + DEFAULT_REPO_DIR, + RepoSpec, + parse_run_configuration, +) from dstack._internal.core.models.resources import Range @@ -73,6 +77,49 @@ def test_shell_invalid(self): parse_run_configuration(conf) +class TestRepoSpec: + @pytest.mark.parametrize("value", [".", "rel/path", "/abs/path/"]) + def test_parse_local_path_no_path(self, value: str): + assert RepoSpec.parse(value) == RepoSpec(local_path=value, path=DEFAULT_REPO_DIR) + + @pytest.mark.parametrize( + ["value", "expected_repo_path"], + [[".:/repo", "."], ["rel/path:/repo", "rel/path"], ["/abs/path/:/repo", "/abs/path/"]], + ) + def test_parse_local_path_with_path(self, value: str, expected_repo_path: str): + assert RepoSpec.parse(value) == RepoSpec(local_path=expected_repo_path, path="/repo") + + def test_parse_windows_abs_local_path_no_path(self): + assert RepoSpec.parse("C:\\repo") == RepoSpec(local_path="C:\\repo", path=DEFAULT_REPO_DIR) + + def test_parse_windows_abs_local_path_with_path(self): + assert RepoSpec.parse("C:\\repo:/repo") == RepoSpec(local_path="C:\\repo", path="/repo") + + def test_parse_url_no_path(self): + assert RepoSpec.parse("https://example.com/repo.git") == RepoSpec( + url="https://example.com/repo.git", path=DEFAULT_REPO_DIR + ) + + def test_parse_url_with_path(self): + assert RepoSpec.parse("https://example.com/repo.git:/repo") == RepoSpec( + url="https://example.com/repo.git", path="/repo" + ) + + def test_parse_scp_no_path(self): + assert RepoSpec.parse("git@example.com:repo.git") == RepoSpec( + url="git@example.com:repo.git", path=DEFAULT_REPO_DIR + ) + + def test_parse_scp_with_path(self): + assert RepoSpec.parse("git@example.com:repo.git:/repo") == RepoSpec( + url="git@example.com:repo.git", path="/repo" + ) + + def test_error_invalid_mapping_if_more_than_two_parts(self): + with pytest.raises(ValueError, match="Invalid repo"): + RepoSpec.parse("./foo:bar:baz") + + def test_registry_auth_hashable(): """ RegistryAuth instances should be hashable diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 945e039495..4438b61dce 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -154,6 +154,7 @@ def get_dev_env_run_plan_dict( "shm_size": None, }, "volumes": [json.loads(v.json()) for v in volumes], + "repos": [], "files": [], "backends": ["local", "aws", "azure", "gcp", "lambda", "runpod"], "regions": ["us"], @@ -358,6 +359,7 @@ def get_dev_env_run_dict( "shm_size": None, }, "volumes": [], + "repos": [], "files": [], "backends": ["local", "aws", "azure", "gcp", "lambda"], "regions": ["us"],