Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 1 addition & 59 deletions src/dstack/_internal/cli/commands/apply.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
from pathlib import Path

from argcomplete import FilesCompleter # type: ignore[attr-defined]

Expand All @@ -9,12 +8,7 @@
get_apply_configurator_class,
load_apply_configuration,
)
from dstack._internal.cli.services.repos import (
init_default_virtual_repo,
init_repo,
register_init_repo_args,
)
from dstack._internal.cli.utils.common import console, warn
from dstack._internal.cli.utils.common import console
from dstack._internal.core.errors import CLIError
from dstack._internal.core.models.configurations import ApplyConfigurationType

Expand Down Expand Up @@ -66,37 +60,6 @@ def _register(self):
help="Exit immediately after submitting configuration",
action="store_true",
)
self._parser.add_argument(
"--ssh-identity",
metavar="SSH_PRIVATE_KEY",
help="The private SSH key path for SSH tunneling",
type=Path,
dest="ssh_identity_file",
)
repo_group = self._parser.add_argument_group("Repo Options")
repo_group.add_argument(
"-P",
"--repo",
help=("The repo to use for the run. Can be a local path or a Git repo URL."),
dest="repo",
)
repo_group.add_argument(
"--repo-branch",
help="The repo branch to use for the run",
dest="repo_branch",
)
repo_group.add_argument(
"--repo-hash",
help="The hash of the repo commit to use for the run",
dest="repo_hash",
)
repo_group.add_argument(
"--no-repo",
help="Do not use any repo for the run",
dest="no_repo",
action="store_true",
)
register_init_repo_args(repo_group)

def _command(self, args: argparse.Namespace):
try:
Expand All @@ -117,26 +80,6 @@ def _command(self, args: argparse.Namespace):
super()._command(args)
if not args.yes and args.configuration_file == APPLY_STDIN_NAME:
raise CLIError("Cannot read configuration from stdin if -y/--yes is not specified")
if args.repo and args.no_repo:
raise CLIError("Either --repo or --no-repo can be specified")
if args.local:
warn(
"Local repos are deprecated since 0.19.25 and will be removed soon."
" Consider using `files` instead: https://dstack.ai/docs/concepts/tasks/#files"
)
repo = None
if args.repo:
repo = init_repo(
api=self.api,
repo_path=args.repo,
repo_branch=args.repo_branch,
repo_hash=args.repo_hash,
local=args.local,
git_identity_file=args.git_identity_file,
oauth_token=args.gh_token,
)
elif args.no_repo:
repo = init_default_virtual_repo(api=self.api)
configuration_path, configuration = load_apply_configuration(args.configuration_file)
configurator_class = get_apply_configurator_class(configuration.type)
configurator = configurator_class(api_client=self.api)
Expand All @@ -148,7 +91,6 @@ def _command(self, args: argparse.Namespace):
command_args=args,
configurator_args=known,
unknown_args=unknown,
repo=repo,
)
except KeyboardInterrupt:
console.print("\nOperation interrupted by user. Exiting...")
Expand Down
80 changes: 56 additions & 24 deletions src/dstack/_internal/cli/commands/init.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import argparse
import os
from pathlib import Path
from typing import Optional

from dstack._internal.cli.commands import BaseCommand
from dstack._internal.cli.services.repos import init_repo, register_init_repo_args
from dstack._internal.cli.services.repos import (
get_repo_from_dir,
get_repo_from_url,
is_git_repo_url,
register_init_repo_args,
)
from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.repos.base import RepoType
from dstack._internal.core.services.configs import ConfigManager
from dstack.api import Client

Expand All @@ -21,6 +26,15 @@ def _register(self):
help="The name of the project",
default=os.getenv("DSTACK_PROJECT"),
)
self._parser.add_argument(
"-P",
"--repo",
help=(
"The repo to initialize. Can be a local path or a Git repo URL."
" Defaults to the current working directory."
),
dest="repo",
)
register_init_repo_args(self._parser)
# Deprecated since 0.19.25, ignored
self._parser.add_argument(
Expand All @@ -30,7 +44,7 @@ def _register(self):
type=Path,
dest="ssh_identity_file",
)
# A hidden mode for transitional period only, remove it with local repos
# A hidden mode for transitional period only, remove it with repos in `config.yml`
self._parser.add_argument(
"--remove",
help=argparse.SUPPRESS,
Expand All @@ -39,44 +53,62 @@ def _register(self):

def _command(self, args: argparse.Namespace):
configure_logging()

repo_path: Optional[Path] = None
repo_url: Optional[str] = None
repo_arg: Optional[str] = args.repo
if repo_arg is not None:
if is_git_repo_url(repo_arg):
repo_url = repo_arg
else:
repo_path = Path(repo_arg).expanduser().resolve()
else:
repo_path = Path.cwd()

if args.remove:
if repo_url is not None:
raise ConfigurationError(f"Local path expected, got URL: {repo_url}")
assert repo_path is not None
config_manager = ConfigManager()
repo_path = Path.cwd()
repo_config = config_manager.get_repo_config(repo_path)
if repo_config is None:
raise ConfigurationError("The repo is not initialized, nothing to remove")
if repo_config.repo_type != RepoType.LOCAL:
raise ConfigurationError("`dstack init --remove` is for local repos only")
raise ConfigurationError("Repo record not found, nothing to remove")
console.print(
f"You are about to remove the local repo {repo_path}\n"
f"You are about to remove the repo {repo_path}\n"
"Only the record about the repo will be removed,"
" the repo files will remain intact\n"
)
if not confirm_ask("Remove the local repo?"):
if not confirm_ask("Remove the repo?"):
return
config_manager.delete_repo_config(repo_config.repo_id)
config_manager.save()
console.print("Local repo has been removed")
console.print("Repo has been removed")
return
api = Client.from_config(
project_name=args.project, ssh_identity_file=args.ssh_identity_file
)
if args.local:

local: bool = args.local
if local:
warn(
"Local repos are deprecated since 0.19.25 and will be removed soon."
" Consider using `files` instead: https://dstack.ai/docs/concepts/tasks/#files"
"Local repos are deprecated since 0.19.25 and will be removed soon. Consider"
" using [code]files[/code] instead: https://dstack.ai/docs/concepts/tasks/#files"
)
if args.ssh_identity_file:
warn(
"`--ssh-identity` in `dstack init` is deprecated and ignored since 0.19.25."
" Use this option with `dstack apply` and `dstack attach` instead"
"[code]--ssh-identity[/code] in [code]dstack init[/code] is deprecated and ignored"
" since 0.19.25. Use this option with [code]dstack apply[/code]"
" and [code]dstack attach[/code] instead"
)
init_repo(
api=api,
repo_path=Path.cwd(),
repo_branch=None,
repo_hash=None,
local=args.local,

if repo_url is not None:
# Dummy repo branch to avoid autodetection that fails on private repos.
# We don't need branch/hash for repo_id anyway.
repo = get_repo_from_url(repo_url, repo_branch="master")
elif repo_path is not None:
repo = get_repo_from_dir(repo_path, local=local)
else:
assert False, "should not reach here"
api = Client.from_config(project_name=args.project)
api.repos.init(
repo=repo,
git_identity_file=args.git_identity_file,
oauth_token=args.gh_token,
)
Expand Down
5 changes: 1 addition & 4 deletions src/dstack/_internal/cli/services/configurators/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import os
from abc import ABC, abstractmethod
from typing import Generic, List, Optional, TypeVar, Union, cast
from typing import Generic, List, TypeVar, Union, cast

from dstack._internal.cli.services.args import env_var
from dstack._internal.core.errors import ConfigurationError
Expand All @@ -10,7 +10,6 @@
ApplyConfigurationType,
)
from dstack._internal.core.models.envs import Env, EnvSentinel, EnvVarTuple
from dstack._internal.core.models.repos.base import Repo
from dstack.api._public import Client

ArgsParser = Union[argparse._ArgumentGroup, argparse.ArgumentParser]
Expand All @@ -32,7 +31,6 @@ def apply_configuration(
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
repo: Optional[Repo] = None,
):
"""
Implements `dstack apply` for a given configuration type.
Expand All @@ -43,7 +41,6 @@ def apply_configuration(
command_args: The args parsed by `dstack apply`.
configurator_args: The known args parsed by `cls.get_parser()`.
unknown_args: The unknown args after parsing by `cls.get_parser()`.
repo: The repo to use with apply.
"""
pass

Expand Down
2 changes: 0 additions & 2 deletions src/dstack/_internal/cli/services/configurators/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
InstanceGroupPlacement,
)
from dstack._internal.core.models.instances import InstanceAvailability, InstanceStatus, SSHKey
from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.services.diff import diff_models
from dstack._internal.utils.common import local_time
from dstack._internal.utils.logging import get_logger
Expand All @@ -56,7 +55,6 @@ def apply_configuration(
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
repo: Optional[Repo] = None,
):
self.apply_args(conf, configurator_args, unknown_args)
profile = load_profile(Path.cwd(), None)
Expand Down
4 changes: 1 addition & 3 deletions src/dstack/_internal/cli/services/configurators/gateway.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
import time
from typing import List, Optional
from typing import List

from rich.table import Table

Expand All @@ -21,7 +21,6 @@
GatewaySpec,
GatewayStatus,
)
from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.services.diff import diff_models
from dstack._internal.utils.common import local_time
from dstack.api._public import Client
Expand All @@ -37,7 +36,6 @@ def apply_configuration(
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
repo: Optional[Repo] = None,
):
self.apply_args(conf, configurator_args, unknown_args)
spec = GatewaySpec(
Expand Down
Loading
Loading