Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/dstack/_internal/cli/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import argparse
import os
import shlex
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import ClassVar, Optional

from rich_argparse import RichHelpFormatter

from dstack._internal.cli.services.completion import ProjectNameCompleter
from dstack._internal.cli.utils.common import configure_logging
from dstack._internal.core.errors import CLIError
from dstack.api import Client


class BaseCommand(ABC):
NAME: str = "name the command"
DESCRIPTION: str = "describe the command"
DEFAULT_HELP: bool = True
ALIASES: Optional[List[str]] = None
NAME: ClassVar[str] = "name the command"
DESCRIPTION: ClassVar[str] = "describe the command"
DEFAULT_HELP: ClassVar[bool] = True
ALIASES: ClassVar[Optional[list[str]]] = None
ACCEPT_EXTRA_ARGS: ClassVar[bool] = False

def __init__(self, parser: argparse.ArgumentParser):
self._parser = parser
Expand Down Expand Up @@ -50,7 +52,8 @@ def _register(self):

@abstractmethod
def _command(self, args: argparse.Namespace):
pass
if not self.ACCEPT_EXTRA_ARGS and args.extra_args:
raise CLIError(f"Unrecognized arguments: {shlex.join(args.extra_args)}")


class APIBaseCommand(BaseCommand):
Expand All @@ -65,5 +68,5 @@ def _register(self):
).completer = ProjectNameCompleter() # type: ignore[attr-defined]

def _command(self, args: argparse.Namespace):
configure_logging()
super()._command(args)
self.api = Client.from_config(project_name=args.project)
9 changes: 6 additions & 3 deletions src/dstack/_internal/cli/commands/apply.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import shlex

from argcomplete import FilesCompleter # type: ignore[attr-defined]

Expand All @@ -19,6 +20,7 @@ class ApplyCommand(APIBaseCommand):
NAME = "apply"
DESCRIPTION = "Apply a configuration"
DEFAULT_HELP = False
ACCEPT_EXTRA_ARGS = True

def _register(self):
super()._register()
Expand Down Expand Up @@ -84,13 +86,14 @@ def _command(self, args: argparse.Namespace):
configurator_class = get_apply_configurator_class(configuration.type)
configurator = configurator_class(api_client=self.api)
configurator_parser = configurator.get_parser()
known, unknown = configurator_parser.parse_known_args(args.unknown)
configurator_args, unknown_args = configurator_parser.parse_known_args(args.extra_args)
if unknown_args:
raise CLIError(f"Unrecognized arguments: {shlex.join(unknown_args)}")
configurator.apply_configuration(
conf=configuration,
configuration_path=configuration_path,
command_args=args,
configurator_args=known,
unknown_args=unknown,
configurator_args=configurator_args,
)
except KeyboardInterrupt:
console.print("\nOperation interrupted by user. Exiting...")
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/cli/commands/completion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import argparse

import argcomplete

from dstack._internal.cli.commands import BaseCommand
Expand All @@ -15,6 +17,6 @@ def _register(self):
choices=["bash", "zsh"],
)

def _command(self, args):
def _command(self, args: argparse.Namespace):
super()._command(args)
print(argcomplete.shellcode(["dstack"], shell=args.shell)) # type: ignore[attr-defined]
1 change: 1 addition & 0 deletions src/dstack/_internal/cli/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _register(self):
)

def _command(self, args: argparse.Namespace):
super()._command(args)
config_manager = ConfigManager()
if args.remove:
config_manager.delete_project(args.project)
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/cli/commands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
is_git_repo_url,
register_init_repo_args,
)
from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
from dstack._internal.cli.utils.common import confirm_ask, console, warn
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.repos.remote import RemoteRepo
from dstack._internal.core.services.configs import ConfigManager
Expand Down Expand Up @@ -52,7 +52,7 @@ def _register(self):
)

def _command(self, args: argparse.Namespace):
configure_logging()
super()._command(args)

repo_path: Optional[Path] = None
repo_url: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/cli/commands/offer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _command(self, args: argparse.Namespace):
conf = TaskConfiguration(commands=[":"])

configurator = OfferConfigurator(api_client=self.api)
configurator.apply_args(conf, args, [])
configurator.apply_args(conf, args)
profile = load_profile(Path.cwd(), profile_name=args.profile)

run_spec = RunSpec(
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/cli/commands/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _register(self):
set_default_parser.set_defaults(subfunc=self._set_default)

def _command(self, args: argparse.Namespace):
super()._command(args)
if not hasattr(args, "subfunc"):
args.subfunc = self._list
args.subfunc(args)
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/cli/commands/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import argparse
import os
from argparse import Namespace

from dstack._internal import settings
from dstack._internal.cli.commands import BaseCommand
Expand Down Expand Up @@ -53,7 +53,7 @@ def _register(self):
)
self._parser.add_argument("--token", type=str, help="The admin user token")

def _command(self, args: Namespace):
def _command(self, args: argparse.Namespace):
super()._command(args)

if not UVICORN_INSTALLED:
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def main():
argcomplete.autocomplete(parser, always_complete_options=False)

args, unknown_args = parser.parse_known_args()
args.unknown = unknown_args
args.extra_args = unknown_args

try:
check_for_updates()
Expand Down
6 changes: 2 additions & 4 deletions src/dstack/_internal/cli/services/configurators/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import os
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar, Union, cast
from typing import ClassVar, Generic, List, TypeVar, Union, cast

from dstack._internal.cli.services.args import env_var
from dstack._internal.core.errors import ConfigurationError
Expand All @@ -18,7 +18,7 @@


class BaseApplyConfigurator(ABC, Generic[ApplyConfigurationT]):
TYPE: ApplyConfigurationType
TYPE: ClassVar[ApplyConfigurationType]

def __init__(self, api_client: Client):
self.api = api_client
Expand All @@ -30,7 +30,6 @@ def apply_configuration(
configuration_path: str,
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
):
"""
Implements `dstack apply` for a given configuration type.
Expand All @@ -40,7 +39,6 @@ def apply_configuration(
configuration_path: The path to the configuration file.
command_args: The args parsed by `dstack apply`.
configurator_args: The known args parsed by `cls.get_parser()`.
unknown_args: The unknown args after parsing by `cls.get_parser()`.
"""
pass

Expand Down
9 changes: 4 additions & 5 deletions src/dstack/_internal/cli/services/configurators/fleet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import time
from pathlib import Path
from typing import List, Optional
from typing import Optional

from rich.table import Table

Expand Down Expand Up @@ -46,17 +46,16 @@


class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[FleetConfiguration]):
TYPE: ApplyConfigurationType = ApplyConfigurationType.FLEET
TYPE = ApplyConfigurationType.FLEET

def apply_configuration(
self,
conf: FleetConfiguration,
configuration_path: str,
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
):
self.apply_args(conf, configurator_args, unknown_args)
self.apply_args(conf, configurator_args)
profile = load_profile(Path.cwd(), None)
spec = FleetSpec(
configuration=conf,
Expand Down Expand Up @@ -309,7 +308,7 @@ def register_args(cls, parser: argparse.ArgumentParser):
)
cls.register_env_args(configuration_group)

def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace, unknown: List[str]):
def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace):
if args.name:
conf.name = args.name
self.apply_env_vars(conf.env, args)
Expand Down
8 changes: 3 additions & 5 deletions src/dstack/_internal/cli/services/configurators/gateway.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import time
from typing import List

from rich.table import Table

Expand All @@ -27,17 +26,16 @@


class GatewayConfigurator(BaseApplyConfigurator[GatewayConfiguration]):
TYPE: ApplyConfigurationType = ApplyConfigurationType.GATEWAY
TYPE = ApplyConfigurationType.GATEWAY

def apply_configuration(
self,
conf: GatewayConfiguration,
configuration_path: str,
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
unknown_args: List[str],
):
self.apply_args(conf, configurator_args, unknown_args)
self.apply_args(conf, configurator_args)
spec = GatewaySpec(
configuration=conf,
configuration_path=configuration_path,
Expand Down Expand Up @@ -179,7 +177,7 @@ def register_args(cls, parser: argparse.ArgumentParser):
help="The gateway name",
)

def apply_args(self, conf: GatewayConfiguration, args: argparse.Namespace, unknown: List[str]):
def apply_args(self, conf: GatewayConfiguration, args: argparse.Namespace):
if args.name:
conf.name = args.name

Expand Down
Loading
Loading