From c84686ed5101bc873b4b196f6d25e94c47e62779 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Thu, 4 Sep 2025 14:02:11 +0000 Subject: [PATCH] [CLI] Handle unrecognized arguments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * All commands now reject unrecognized arguments * Undocumented `${{ run.args }}` for tasks and services is still supported but requires `--` pseudo-argument: > If you have positional arguments that must begin with `-` > and don’t look like negative numbers, you can insert > the pseudo-argument `'--'` which tells `parse_args()` > that everything after that is a positional argument ``` dstack apply --reuse -- --some=arg --some-option ^^ ``` Fixes: https://github.com/dstackai/dstack/issues/3073 --- src/dstack/_internal/cli/commands/__init__.py | 19 +++-- src/dstack/_internal/cli/commands/apply.py | 9 ++- .../_internal/cli/commands/completion.py | 4 +- src/dstack/_internal/cli/commands/config.py | 1 + src/dstack/_internal/cli/commands/init.py | 4 +- src/dstack/_internal/cli/commands/offer.py | 2 +- src/dstack/_internal/cli/commands/project.py | 1 + src/dstack/_internal/cli/commands/server.py | 4 +- src/dstack/_internal/cli/main.py | 2 +- .../cli/services/configurators/base.py | 6 +- .../cli/services/configurators/fleet.py | 9 +-- .../cli/services/configurators/gateway.py | 8 +- .../cli/services/configurators/run.py | 78 ++++++++++++------- .../cli/services/configurators/volume.py | 8 +- .../cli/services/configurators/test_fleet.py | 4 +- .../cli/services/configurators/test_run.py | 6 +- 16 files changed, 96 insertions(+), 69 deletions(-) diff --git a/src/dstack/_internal/cli/commands/__init__.py b/src/dstack/_internal/cli/commands/__init__.py index 48ee1148e2..da89156e50 100644 --- a/src/dstack/_internal/cli/commands/__init__.py +++ b/src/dstack/_internal/cli/commands/__init__.py @@ -1,20 +1,22 @@ import argparse import os +import shlex from abc import ABC, abstractmethod -from typing import List, Optional +from typing import ClassVar, Optional from rich_argparse import RichHelpFormatter from dstack._internal.cli.services.completion import ProjectNameCompleter -from dstack._internal.cli.utils.common import configure_logging +from dstack._internal.core.errors import CLIError from dstack.api import Client class BaseCommand(ABC): - NAME: str = "name the command" - DESCRIPTION: str = "describe the command" - DEFAULT_HELP: bool = True - ALIASES: Optional[List[str]] = None + NAME: ClassVar[str] = "name the command" + DESCRIPTION: ClassVar[str] = "describe the command" + DEFAULT_HELP: ClassVar[bool] = True + ALIASES: ClassVar[Optional[list[str]]] = None + ACCEPT_EXTRA_ARGS: ClassVar[bool] = False def __init__(self, parser: argparse.ArgumentParser): self._parser = parser @@ -50,7 +52,8 @@ def _register(self): @abstractmethod def _command(self, args: argparse.Namespace): - pass + if not self.ACCEPT_EXTRA_ARGS and args.extra_args: + raise CLIError(f"Unrecognized arguments: {shlex.join(args.extra_args)}") class APIBaseCommand(BaseCommand): @@ -65,5 +68,5 @@ def _register(self): ).completer = ProjectNameCompleter() # type: ignore[attr-defined] def _command(self, args: argparse.Namespace): - configure_logging() + super()._command(args) self.api = Client.from_config(project_name=args.project) diff --git a/src/dstack/_internal/cli/commands/apply.py b/src/dstack/_internal/cli/commands/apply.py index dba99f08ff..ab73ae2f2b 100644 --- a/src/dstack/_internal/cli/commands/apply.py +++ b/src/dstack/_internal/cli/commands/apply.py @@ -1,4 +1,5 @@ import argparse +import shlex from argcomplete import FilesCompleter # type: ignore[attr-defined] @@ -19,6 +20,7 @@ class ApplyCommand(APIBaseCommand): NAME = "apply" DESCRIPTION = "Apply a configuration" DEFAULT_HELP = False + ACCEPT_EXTRA_ARGS = True def _register(self): super()._register() @@ -84,13 +86,14 @@ def _command(self, args: argparse.Namespace): configurator_class = get_apply_configurator_class(configuration.type) configurator = configurator_class(api_client=self.api) configurator_parser = configurator.get_parser() - known, unknown = configurator_parser.parse_known_args(args.unknown) + configurator_args, unknown_args = configurator_parser.parse_known_args(args.extra_args) + if unknown_args: + raise CLIError(f"Unrecognized arguments: {shlex.join(unknown_args)}") configurator.apply_configuration( conf=configuration, configuration_path=configuration_path, command_args=args, - configurator_args=known, - unknown_args=unknown, + configurator_args=configurator_args, ) except KeyboardInterrupt: console.print("\nOperation interrupted by user. Exiting...") diff --git a/src/dstack/_internal/cli/commands/completion.py b/src/dstack/_internal/cli/commands/completion.py index 588a4ce091..3bcdfbcfec 100644 --- a/src/dstack/_internal/cli/commands/completion.py +++ b/src/dstack/_internal/cli/commands/completion.py @@ -1,3 +1,5 @@ +import argparse + import argcomplete from dstack._internal.cli.commands import BaseCommand @@ -15,6 +17,6 @@ def _register(self): choices=["bash", "zsh"], ) - def _command(self, args): + def _command(self, args: argparse.Namespace): super()._command(args) print(argcomplete.shellcode(["dstack"], shell=args.shell)) # type: ignore[attr-defined] diff --git a/src/dstack/_internal/cli/commands/config.py b/src/dstack/_internal/cli/commands/config.py index d9200bace0..adff5c8709 100644 --- a/src/dstack/_internal/cli/commands/config.py +++ b/src/dstack/_internal/cli/commands/config.py @@ -40,6 +40,7 @@ def _register(self): ) def _command(self, args: argparse.Namespace): + super()._command(args) config_manager = ConfigManager() if args.remove: config_manager.delete_project(args.project) diff --git a/src/dstack/_internal/cli/commands/init.py b/src/dstack/_internal/cli/commands/init.py index 7df156bb7b..2bbde987b1 100644 --- a/src/dstack/_internal/cli/commands/init.py +++ b/src/dstack/_internal/cli/commands/init.py @@ -9,7 +9,7 @@ is_git_repo_url, register_init_repo_args, ) -from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn +from dstack._internal.cli.utils.common import confirm_ask, console, warn from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.repos.remote import RemoteRepo from dstack._internal.core.services.configs import ConfigManager @@ -52,7 +52,7 @@ def _register(self): ) def _command(self, args: argparse.Namespace): - configure_logging() + super()._command(args) repo_path: Optional[Path] = None repo_url: Optional[str] = None diff --git a/src/dstack/_internal/cli/commands/offer.py b/src/dstack/_internal/cli/commands/offer.py index fb23361e15..e6d02a0154 100644 --- a/src/dstack/_internal/cli/commands/offer.py +++ b/src/dstack/_internal/cli/commands/offer.py @@ -99,7 +99,7 @@ def _command(self, args: argparse.Namespace): conf = TaskConfiguration(commands=[":"]) configurator = OfferConfigurator(api_client=self.api) - configurator.apply_args(conf, args, []) + configurator.apply_args(conf, args) profile = load_profile(Path.cwd(), profile_name=args.profile) run_spec = RunSpec( diff --git a/src/dstack/_internal/cli/commands/project.py b/src/dstack/_internal/cli/commands/project.py index 2c3ea41314..edcc067097 100644 --- a/src/dstack/_internal/cli/commands/project.py +++ b/src/dstack/_internal/cli/commands/project.py @@ -67,6 +67,7 @@ def _register(self): set_default_parser.set_defaults(subfunc=self._set_default) def _command(self, args: argparse.Namespace): + super()._command(args) if not hasattr(args, "subfunc"): args.subfunc = self._list args.subfunc(args) diff --git a/src/dstack/_internal/cli/commands/server.py b/src/dstack/_internal/cli/commands/server.py index 51ece809d3..ebbc8a1bcf 100644 --- a/src/dstack/_internal/cli/commands/server.py +++ b/src/dstack/_internal/cli/commands/server.py @@ -1,5 +1,5 @@ +import argparse import os -from argparse import Namespace from dstack._internal import settings from dstack._internal.cli.commands import BaseCommand @@ -53,7 +53,7 @@ def _register(self): ) self._parser.add_argument("--token", type=str, help="The admin user token") - def _command(self, args: Namespace): + def _command(self, args: argparse.Namespace): super()._command(args) if not UVICORN_INSTALLED: diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index 81c087fcdf..2ef2979051 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -83,7 +83,7 @@ def main(): argcomplete.autocomplete(parser, always_complete_options=False) args, unknown_args = parser.parse_known_args() - args.unknown = unknown_args + args.extra_args = unknown_args try: check_for_updates() diff --git a/src/dstack/_internal/cli/services/configurators/base.py b/src/dstack/_internal/cli/services/configurators/base.py index 35414801ee..c2d88b565a 100644 --- a/src/dstack/_internal/cli/services/configurators/base.py +++ b/src/dstack/_internal/cli/services/configurators/base.py @@ -1,7 +1,7 @@ import argparse import os from abc import ABC, abstractmethod -from typing import Generic, List, TypeVar, Union, cast +from typing import ClassVar, Generic, List, TypeVar, Union, cast from dstack._internal.cli.services.args import env_var from dstack._internal.core.errors import ConfigurationError @@ -18,7 +18,7 @@ class BaseApplyConfigurator(ABC, Generic[ApplyConfigurationT]): - TYPE: ApplyConfigurationType + TYPE: ClassVar[ApplyConfigurationType] def __init__(self, api_client: Client): self.api = api_client @@ -30,7 +30,6 @@ def apply_configuration( configuration_path: str, command_args: argparse.Namespace, configurator_args: argparse.Namespace, - unknown_args: List[str], ): """ Implements `dstack apply` for a given configuration type. @@ -40,7 +39,6 @@ def apply_configuration( configuration_path: The path to the configuration file. command_args: The args parsed by `dstack apply`. configurator_args: The known args parsed by `cls.get_parser()`. - unknown_args: The unknown args after parsing by `cls.get_parser()`. """ pass diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index a250058bc3..0dfb30ef20 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -1,7 +1,7 @@ import argparse import time from pathlib import Path -from typing import List, Optional +from typing import Optional from rich.table import Table @@ -46,7 +46,7 @@ class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[FleetConfiguration]): - TYPE: ApplyConfigurationType = ApplyConfigurationType.FLEET + TYPE = ApplyConfigurationType.FLEET def apply_configuration( self, @@ -54,9 +54,8 @@ def apply_configuration( configuration_path: str, command_args: argparse.Namespace, configurator_args: argparse.Namespace, - unknown_args: List[str], ): - self.apply_args(conf, configurator_args, unknown_args) + self.apply_args(conf, configurator_args) profile = load_profile(Path.cwd(), None) spec = FleetSpec( configuration=conf, @@ -309,7 +308,7 @@ def register_args(cls, parser: argparse.ArgumentParser): ) cls.register_env_args(configuration_group) - def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace, unknown: List[str]): + def apply_args(self, conf: FleetConfiguration, args: argparse.Namespace): if args.name: conf.name = args.name self.apply_env_vars(conf.env, args) diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py index 7d26e220ba..43d4460a17 100644 --- a/src/dstack/_internal/cli/services/configurators/gateway.py +++ b/src/dstack/_internal/cli/services/configurators/gateway.py @@ -1,6 +1,5 @@ import argparse import time -from typing import List from rich.table import Table @@ -27,7 +26,7 @@ class GatewayConfigurator(BaseApplyConfigurator[GatewayConfiguration]): - TYPE: ApplyConfigurationType = ApplyConfigurationType.GATEWAY + TYPE = ApplyConfigurationType.GATEWAY def apply_configuration( self, @@ -35,9 +34,8 @@ def apply_configuration( configuration_path: str, command_args: argparse.Namespace, configurator_args: argparse.Namespace, - unknown_args: List[str], ): - self.apply_args(conf, configurator_args, unknown_args) + self.apply_args(conf, configurator_args) spec = GatewaySpec( configuration=conf, configuration_path=configuration_path, @@ -179,7 +177,7 @@ def register_args(cls, parser: argparse.ArgumentParser): help="The gateway name", ) - def apply_args(self, conf: GatewayConfiguration, args: argparse.Namespace, unknown: List[str]): + def apply_args(self, conf: GatewayConfiguration, args: argparse.Namespace): if args.name: conf.name = args.name diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index cdbd324a6b..0403a57a64 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -1,4 +1,5 @@ import argparse +import shlex import subprocess import sys import time @@ -35,6 +36,7 @@ LEGACY_REPO_DIR, AnyRunConfiguration, ApplyConfigurationType, + ConfigurationWithCommandsParams, ConfigurationWithPortsParams, DevEnvironmentConfiguration, PortMapping, @@ -80,20 +82,17 @@ class BaseRunConfigurator( ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[RunConfigurationT], ): - TYPE: ApplyConfigurationType - def apply_configuration( self, conf: RunConfigurationT, configuration_path: str, command_args: argparse.Namespace, configurator_args: argparse.Namespace, - unknown_args: List[str], ): if configurator_args.repo and configurator_args.no_repo: raise CLIError("Either --repo or --no-repo can be specified") - self.apply_args(conf, configurator_args, unknown_args) + self.apply_args(conf, configurator_args) self.validate_gpu_vendor_and_image(conf) self.validate_cpu_arch_and_image(conf) @@ -395,7 +394,7 @@ def register_args(cls, parser: argparse.ArgumentParser): ) register_init_repo_args(repo_group) - def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace, unknown: List[str]): + def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace): apply_profile_args(args, conf) if args.run_name: conf.name = args.run_name @@ -408,16 +407,6 @@ def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace, unknown: self.apply_env_vars(conf.env, args) self.interpolate_env(conf) - self.interpolate_run_args(conf.setup, unknown) - - def interpolate_run_args(self, value: List[str], unknown): - run_args = " ".join(unknown) - interpolator = VariablesInterpolator({"run": {"args": run_args}}, skip=["secrets"]) - try: - for i in range(len(value)): - value[i] = interpolator.interpolate_or_error(value[i]) - except InterpolatorError as e: - raise ConfigurationError(e.args[0]) def interpolate_env(self, conf: RunConfigurationT): env_dict = conf.env.as_dict() @@ -701,18 +690,50 @@ def apply_ports_args( conf.ports = list(_merge_ports(conf.ports, args.ports).values()) -class TaskConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigurator): +class RunWithCommandsConfiguratorMixin: + @classmethod + def register_commands_args(cls, parser: argparse.ArgumentParser): + parser.add_argument( + "run_args", + help=( + "Run arguments. Available in the configuration [code]commands[/code] as" + " [code]${{ run.args }}[/code]." + " Use [code]--[/code] to separate run options from [code]dstack[/code] options" + ), + nargs="*", + metavar="RUN_ARGS", + ) + + def apply_commands_args( + self, + conf: ConfigurationWithCommandsParams, + args: argparse.Namespace, + ): + commands = conf.commands + run_args = shlex.join(args.run_args) + interpolator = VariablesInterpolator({"run": {"args": run_args}}, skip=["secrets"]) + try: + for i, command in enumerate(commands): + commands[i] = interpolator.interpolate_or_error(command) + except InterpolatorError as e: + raise ConfigurationError(e.args[0]) + + +class TaskConfigurator( + RunWithPortsConfiguratorMixin, RunWithCommandsConfiguratorMixin, BaseRunConfigurator +): TYPE = ApplyConfigurationType.TASK @classmethod def register_args(cls, parser: argparse.ArgumentParser): super().register_args(parser) cls.register_ports_args(parser) + cls.register_commands_args(parser) - def apply_args(self, conf: TaskConfiguration, args: argparse.Namespace, unknown: List[str]): - super().apply_args(conf, args, unknown) + def apply_args(self, conf: TaskConfiguration, args: argparse.Namespace): + super().apply_args(conf, args) self.apply_ports_args(conf, args) - self.interpolate_run_args(conf.commands, unknown) + self.apply_commands_args(conf, args) class DevEnvironmentConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigurator): @@ -723,10 +744,8 @@ def register_args(cls, parser: argparse.ArgumentParser): super().register_args(parser) cls.register_ports_args(parser) - def apply_args( - self, conf: DevEnvironmentConfiguration, args: argparse.Namespace, unknown: List[str] - ): - super().apply_args(conf, args, unknown) + def apply_args(self, conf: DevEnvironmentConfiguration, args: argparse.Namespace): + super().apply_args(conf, args) self.apply_ports_args(conf, args) if conf.ide == "vscode" and conf.version is None: conf.version = _detect_vscode_version() @@ -746,12 +765,17 @@ def apply_args( ) -class ServiceConfigurator(BaseRunConfigurator): +class ServiceConfigurator(RunWithCommandsConfiguratorMixin, BaseRunConfigurator): TYPE = ApplyConfigurationType.SERVICE - def apply_args(self, conf: ServiceConfiguration, args: argparse.Namespace, unknown: List[str]): - super().apply_args(conf, args, unknown) - self.interpolate_run_args(conf.commands, unknown) + @classmethod + def register_args(cls, parser: argparse.ArgumentParser): + super().register_args(parser) + cls.register_commands_args(parser) + + def apply_args(self, conf: TaskConfiguration, args: argparse.Namespace): + super().apply_args(conf, args) + self.apply_commands_args(conf, args) def _merge_ports(conf: List[PortMapping], args: List[PortMapping]) -> Dict[int, PortMapping]: diff --git a/src/dstack/_internal/cli/services/configurators/volume.py b/src/dstack/_internal/cli/services/configurators/volume.py index b0e25e503c..624c2080c8 100644 --- a/src/dstack/_internal/cli/services/configurators/volume.py +++ b/src/dstack/_internal/cli/services/configurators/volume.py @@ -1,6 +1,5 @@ import argparse import time -from typing import List from rich.table import Table @@ -26,7 +25,7 @@ class VolumeConfigurator(BaseApplyConfigurator[VolumeConfiguration]): - TYPE: ApplyConfigurationType = ApplyConfigurationType.VOLUME + TYPE = ApplyConfigurationType.VOLUME def apply_configuration( self, @@ -34,9 +33,8 @@ def apply_configuration( configuration_path: str, command_args: argparse.Namespace, configurator_args: argparse.Namespace, - unknown_args: List[str], ): - self.apply_args(conf, configurator_args, unknown_args) + self.apply_args(conf, configurator_args) spec = VolumeSpec( configuration=conf, configuration_path=configuration_path, @@ -167,7 +165,7 @@ def register_args(cls, parser: argparse.ArgumentParser): help="The volume name", ) - def apply_args(self, conf: VolumeConfiguration, args: argparse.Namespace, unknown: List[str]): + def apply_args(self, conf: VolumeConfiguration, args: argparse.Namespace): if args.name: conf.name = args.name diff --git a/src/tests/_internal/cli/services/configurators/test_fleet.py b/src/tests/_internal/cli/services/configurators/test_fleet.py index 91d78271a8..a14b5c7ac5 100644 --- a/src/tests/_internal/cli/services/configurators/test_fleet.py +++ b/src/tests/_internal/cli/services/configurators/test_fleet.py @@ -50,6 +50,6 @@ def apply_args( configurator = FleetConfigurator(Mock()) configurator.register_args(parser) conf = conf.copy(deep=True) - configurator_args, unknown_args = parser.parse_known_args(args) - configurator.apply_args(conf, configurator_args, unknown_args) + configurator_args = parser.parse_args(args) + configurator.apply_args(conf, configurator_args) return conf, configurator_args diff --git a/src/tests/_internal/cli/services/configurators/test_run.py b/src/tests/_internal/cli/services/configurators/test_run.py index 14959f5845..eb5027671a 100644 --- a/src/tests/_internal/cli/services/configurators/test_run.py +++ b/src/tests/_internal/cli/services/configurators/test_run.py @@ -34,9 +34,9 @@ def apply_args( configurator = configurator_class(Mock()) configurator.register_args(parser) conf = conf.copy(deep=True) # to avoid modifying the original configuration - known, unknown = parser.parse_known_args(args) - configurator.apply_args(conf, known, unknown) - return conf, known + parsed_args = parser.parse_args(args) + configurator.apply_args(conf, parsed_args) + return conf, parsed_args def test_env(self): conf = TaskConfiguration(commands=["whoami"])