diff --git a/.github/workflows/build-artifacts.yml b/.github/workflows/build-artifacts.yml
index ecb686bd66..12c2d1884c 100644
--- a/.github/workflows/build-artifacts.yml
+++ b/.github/workflows/build-artifacts.yml
@@ -73,6 +73,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: uv sync --all-extras
+ - name: Run pyright
+ uses: jakebailey/pyright-action@v2
+ with:
+ pylance-version: latest-release
- name: Download frontend build
uses: actions/download-artifact@v4
with:
diff --git a/contributing/DEVELOPMENT.md b/contributing/DEVELOPMENT.md
index 775a0691e4..4428f811b3 100644
--- a/contributing/DEVELOPMENT.md
+++ b/contributing/DEVELOPMENT.md
@@ -25,12 +25,26 @@ uv sync --all-extras
Alternatively, if you want to manage virtual environments by yourself, you can install `dstack` into the activated virtual environment with `uv sync --all-extras --active`.
-## 4. (Recommended) Install pre-commits:
+## 4. (Recommended) Install pre-commit hooks:
+
+Code formatting and linting can be done automatically on each commit with `pre-commit` hooks:
```shell
uv run pre-commit install
```
-## 5. Frontend
+## 5. (Recommended) Use pyright:
+
+The CI runs `pyright` for type checking `dstack` Python code.
+So we recommend you configure your IDE to use `pyright`/`pylance` with `standard` type checking mode.
+
+You can also install `pyright` and run it from the CLI:
+
+```shell
+uv tool install pyright
+pyright -p .
+```
+
+## 6. Frontend
See [FRONTEND.md](FRONTEND.md) for the details on how to build and develop the frontend.
diff --git a/pyproject.toml b/pyproject.toml
index a14ac0cbbf..6d184cbfe4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,6 +78,17 @@ pattern = '\s*|]*>\s*|\s*|\
replacement = ''
ignore-case = true
+[tool.pyright]
+include = [
+ "src/dstack/plugins",
+ "src/dstack/_internal/server",
+ "src/dstack/_internal/core/services",
+ "src/dstack/_internal/cli/services/configurators",
+]
+ignore = [
+ "src/dstack/_internal/server/migrations/versions",
+]
+
[dependency-groups]
dev = [
"httpx>=0.28.1",
diff --git a/src/dstack/_internal/cli/commands/offer.py b/src/dstack/_internal/cli/commands/offer.py
index 0201ddc21a..fb23361e15 100644
--- a/src/dstack/_internal/cli/commands/offer.py
+++ b/src/dstack/_internal/cli/commands/offer.py
@@ -3,7 +3,11 @@
from typing import List
from dstack._internal.cli.commands import APIBaseCommand
-from dstack._internal.cli.services.configurators.run import BaseRunConfigurator
+from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec
+from dstack._internal.cli.services.configurators.run import (
+ BaseRunConfigurator,
+)
+from dstack._internal.cli.services.profile import register_profile_args
from dstack._internal.cli.utils.common import console
from dstack._internal.cli.utils.gpu import print_gpu_json, print_gpu_table
from dstack._internal.cli.utils.run import print_offers_json, print_run_plan
@@ -18,11 +22,8 @@ class OfferConfigurator(BaseRunConfigurator):
TYPE = ApplyConfigurationType.TASK
@classmethod
- def register_args(
- cls,
- parser: argparse.ArgumentParser,
- ):
- super().register_args(parser, default_max_offers=50)
+ def register_args(cls, parser: argparse.ArgumentParser):
+ configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options")
parser.add_argument(
"--group-by",
action="append",
@@ -33,6 +34,43 @@ def register_args(
"Can be repeated or comma-separated (e.g. [code]--group-by gpu,backend[/code])."
),
)
+ configuration_group.add_argument(
+ "-n",
+ "--name",
+ dest="run_name",
+ help="The name of the run. If not specified, a random name is assigned",
+ )
+ configuration_group.add_argument(
+ "--max-offers",
+ help="Number of offers to show in the run plan",
+ type=int,
+ default=50,
+ )
+ cls.register_env_args(configuration_group)
+ configuration_group.add_argument(
+ "--cpu",
+ type=cpu_spec,
+ help="Request CPU for the run. "
+ "The format is [code]ARCH[/]:[code]COUNT[/] (all parts are optional)",
+ dest="cpu_spec",
+ metavar="SPEC",
+ )
+ configuration_group.add_argument(
+ "--gpu",
+ type=gpu_spec,
+ help="Request GPU for the run. "
+ "The format is [code]NAME[/]:[code]COUNT[/]:[code]MEMORY[/] (all parts are optional)",
+ dest="gpu_spec",
+ metavar="SPEC",
+ )
+ configuration_group.add_argument(
+ "--disk",
+ type=disk_spec,
+ help="Request the size range of disk for the run. Example [code]--disk 100GB..[/].",
+ metavar="RANGE",
+ dest="disk_spec",
+ )
+ register_profile_args(parser)
class OfferCommand(APIBaseCommand):
@@ -117,7 +155,7 @@ def _process_group_by_args(self, group_by_args: List[str]) -> List[str]:
return processed
- def _list_gpus(self, args: List[str], run_spec: RunSpec) -> List[GpuGroup]:
+ def _list_gpus(self, args: argparse.Namespace, run_spec: RunSpec) -> List[GpuGroup]:
group_by = [g for g in args.group_by if g != "gpu"] or None
return self.api.client.gpus.list_gpus(
self.api.project,
diff --git a/src/dstack/_internal/cli/services/configurators/__init__.py b/src/dstack/_internal/cli/services/configurators/__init__.py
index cba23ee31a..91768bdcd3 100644
--- a/src/dstack/_internal/cli/services/configurators/__init__.py
+++ b/src/dstack/_internal/cli/services/configurators/__init__.py
@@ -24,7 +24,9 @@
APPLY_STDIN_NAME = "-"
-apply_configurators_mapping: Dict[ApplyConfigurationType, Type[BaseApplyConfigurator]] = {
+apply_configurators_mapping: Dict[
+ ApplyConfigurationType, Type[BaseApplyConfigurator[AnyApplyConfiguration]]
+] = {
cls.TYPE: cls
for cls in [
DevEnvironmentConfigurator,
@@ -47,7 +49,9 @@
}
-def get_apply_configurator_class(configurator_type: str) -> Type[BaseApplyConfigurator]:
+def get_apply_configurator_class(
+ configurator_type: str,
+) -> Type[BaseApplyConfigurator[AnyApplyConfiguration]]:
return apply_configurators_mapping[ApplyConfigurationType(configurator_type)]
diff --git a/src/dstack/_internal/cli/services/configurators/base.py b/src/dstack/_internal/cli/services/configurators/base.py
index 39e34693eb..440a31d6c2 100644
--- a/src/dstack/_internal/cli/services/configurators/base.py
+++ b/src/dstack/_internal/cli/services/configurators/base.py
@@ -1,7 +1,7 @@
import argparse
import os
from abc import ABC, abstractmethod
-from typing import List, Optional, Union, cast
+from typing import Generic, List, Optional, TypeVar, Union, cast
from dstack._internal.cli.services.args import env_var
from dstack._internal.core.errors import ConfigurationError
@@ -15,8 +15,10 @@
ArgsParser = Union[argparse._ArgumentGroup, argparse.ArgumentParser]
+ApplyConfigurationT = TypeVar("ApplyConfigurationT", bound=AnyApplyConfiguration)
-class BaseApplyConfigurator(ABC):
+
+class BaseApplyConfigurator(ABC, Generic[ApplyConfigurationT]):
TYPE: ApplyConfigurationType
def __init__(self, api_client: Client):
@@ -25,7 +27,7 @@ def __init__(self, api_client: Client):
@abstractmethod
def apply_configuration(
self,
- conf: AnyApplyConfiguration,
+ conf: ApplyConfigurationT,
configuration_path: str,
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
@@ -48,7 +50,7 @@ def apply_configuration(
@abstractmethod
def delete_configuration(
self,
- conf: AnyApplyConfiguration,
+ conf: ApplyConfigurationT,
configuration_path: str,
command_args: argparse.Namespace,
):
diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py
index cb9d7a2b87..6718f4f0f2 100644
--- a/src/dstack/_internal/cli/services/configurators/fleet.py
+++ b/src/dstack/_internal/cli/services/configurators/fleet.py
@@ -46,7 +46,7 @@
logger = get_logger(__name__)
-class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
+class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[FleetConfiguration]):
TYPE: ApplyConfigurationType = ApplyConfigurationType.FLEET
def apply_configuration(
diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py
index 8651c79ce8..8a22277b17 100644
--- a/src/dstack/_internal/cli/services/configurators/gateway.py
+++ b/src/dstack/_internal/cli/services/configurators/gateway.py
@@ -27,7 +27,7 @@
from dstack.api._public import Client
-class GatewayConfigurator(BaseApplyConfigurator):
+class GatewayConfigurator(BaseApplyConfigurator[GatewayConfiguration]):
TYPE: ApplyConfigurationType = ApplyConfigurationType.GATEWAY
def apply_configuration(
diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py
index 185ab65688..02f02f7b42 100644
--- a/src/dstack/_internal/cli/services/configurators/run.py
+++ b/src/dstack/_internal/cli/services/configurators/run.py
@@ -3,7 +3,7 @@
import sys
import time
from pathlib import Path
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, TypeVar
import gpuhunt
from pydantic import parse_obj_as
@@ -33,8 +33,7 @@
from dstack._internal.core.models.configurations import (
AnyRunConfiguration,
ApplyConfigurationType,
- BaseRunConfiguration,
- BaseRunConfigurationWithPorts,
+ ConfigurationWithPortsParams,
DevEnvironmentConfiguration,
PortMapping,
RunConfigurationType,
@@ -63,13 +62,18 @@
logger = get_logger(__name__)
+RunConfigurationT = TypeVar("RunConfigurationT", bound=AnyRunConfiguration)
-class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator):
+
+class BaseRunConfigurator(
+ ApplyEnvVarsConfiguratorMixin,
+ BaseApplyConfigurator[RunConfigurationT],
+):
TYPE: ApplyConfigurationType
def apply_configuration(
self,
- conf: BaseRunConfiguration,
+ conf: RunConfigurationT,
configuration_path: str,
command_args: argparse.Namespace,
configurator_args: argparse.Namespace,
@@ -267,7 +271,7 @@ def apply_configuration(
def delete_configuration(
self,
- conf: AnyRunConfiguration,
+ conf: RunConfigurationT,
configuration_path: str,
command_args: argparse.Namespace,
):
@@ -293,7 +297,7 @@ def delete_configuration(
console.print(f"Run [code]{conf.name}[/] deleted")
@classmethod
- def register_args(cls, parser: argparse.ArgumentParser, default_max_offers: int = 3):
+ def register_args(cls, parser: argparse.ArgumentParser):
configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options")
configuration_group.add_argument(
"-n",
@@ -305,7 +309,7 @@ def register_args(cls, parser: argparse.ArgumentParser, default_max_offers: int
"--max-offers",
help="Number of offers to show in the run plan",
type=int,
- default=default_max_offers,
+ default=3,
)
cls.register_env_args(configuration_group)
configuration_group.add_argument(
@@ -333,7 +337,7 @@ def register_args(cls, parser: argparse.ArgumentParser, default_max_offers: int
)
register_profile_args(parser)
- def apply_args(self, conf: BaseRunConfiguration, args: argparse.Namespace, unknown: List[str]):
+ def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace, unknown: List[str]):
apply_profile_args(args, conf)
if args.run_name:
conf.name = args.run_name
@@ -357,7 +361,7 @@ def interpolate_run_args(self, value: List[str], unknown):
except InterpolatorError as e:
raise ConfigurationError(e.args[0])
- def interpolate_env(self, conf: BaseRunConfiguration):
+ def interpolate_env(self, conf: RunConfigurationT):
env_dict = conf.env.as_dict()
interpolator = VariablesInterpolator({"env": env_dict}, skip=["secrets"])
try:
@@ -377,7 +381,7 @@ def interpolate_env(self, conf: BaseRunConfiguration):
except InterpolatorError as e:
raise ConfigurationError(e.args[0])
- def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
+ def validate_gpu_vendor_and_image(self, conf: RunConfigurationT) -> None:
"""
Infers and sets `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD.
"""
@@ -438,7 +442,7 @@ def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
"`image` is required if `resources.gpu.vendor` is `tenstorrent`"
)
- def validate_cpu_arch_and_image(self, conf: BaseRunConfiguration) -> None:
+ def validate_cpu_arch_and_image(self, conf: RunConfigurationT) -> None:
"""
Infers `resources.cpu.arch` if not set, requires `image` if the architecture is ARM.
"""
@@ -462,10 +466,9 @@ def validate_cpu_arch_and_image(self, conf: BaseRunConfiguration) -> None:
raise ConfigurationError("`image` is required if `resources.cpu.arch` is `arm`")
-class RunWithPortsConfigurator(BaseRunConfigurator):
+class RunWithPortsConfiguratorMixin:
@classmethod
- def register_args(cls, parser: argparse.ArgumentParser):
- super().register_args(parser)
+ def register_ports_args(cls, parser: argparse.ArgumentParser):
parser.add_argument(
"-p",
"--port",
@@ -482,29 +485,42 @@ def register_args(cls, parser: argparse.ArgumentParser):
metavar="HOST",
)
- def apply_args(
- self, conf: BaseRunConfigurationWithPorts, args: argparse.Namespace, unknown: List[str]
+ def apply_ports_args(
+ self,
+ conf: ConfigurationWithPortsParams,
+ args: argparse.Namespace,
):
- super().apply_args(conf, args, unknown)
if args.ports:
conf.ports = list(_merge_ports(conf.ports, args.ports).values())
-class TaskConfigurator(RunWithPortsConfigurator):
+class TaskConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigurator):
TYPE = ApplyConfigurationType.TASK
+ @classmethod
+ def register_args(cls, parser: argparse.ArgumentParser):
+ super().register_args(parser)
+ cls.register_ports_args(parser)
+
def apply_args(self, conf: TaskConfiguration, args: argparse.Namespace, unknown: List[str]):
super().apply_args(conf, args, unknown)
+ self.apply_ports_args(conf, args)
self.interpolate_run_args(conf.commands, unknown)
-class DevEnvironmentConfigurator(RunWithPortsConfigurator):
+class DevEnvironmentConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigurator):
TYPE = ApplyConfigurationType.DEV_ENVIRONMENT
+ @classmethod
+ def register_args(cls, parser: argparse.ArgumentParser):
+ super().register_args(parser)
+ cls.register_ports_args(parser)
+
def apply_args(
self, conf: DevEnvironmentConfiguration, args: argparse.Namespace, unknown: List[str]
):
super().apply_args(conf, args, unknown)
+ self.apply_ports_args(conf, args)
if conf.ide == "vscode" and conf.version is None:
conf.version = _detect_vscode_version()
if conf.version is None:
@@ -674,6 +690,8 @@ def render_run_spec_diff(old_spec: RunSpec, new_spec: RunSpec) -> Optional[str]:
if type(old_spec.profile) is not type(new_spec.profile):
item = NestedListItem("Profile")
else:
+ assert old_spec.profile is not None
+ assert new_spec.profile is not None
item = NestedListItem(
"Profile properties:",
children=[
diff --git a/src/dstack/_internal/cli/services/configurators/volume.py b/src/dstack/_internal/cli/services/configurators/volume.py
index 2a085477ed..72b21e5bb4 100644
--- a/src/dstack/_internal/cli/services/configurators/volume.py
+++ b/src/dstack/_internal/cli/services/configurators/volume.py
@@ -26,7 +26,7 @@
from dstack.api._public import Client
-class VolumeConfigurator(BaseApplyConfigurator):
+class VolumeConfigurator(BaseApplyConfigurator[VolumeConfiguration]):
TYPE: ApplyConfigurationType = ApplyConfigurationType.VOLUME
def apply_configuration(
diff --git a/src/dstack/_internal/cli/services/profile.py b/src/dstack/_internal/cli/services/profile.py
index d57ea2e130..6340719bd2 100644
--- a/src/dstack/_internal/cli/services/profile.py
+++ b/src/dstack/_internal/cli/services/profile.py
@@ -159,7 +159,7 @@ def apply_profile_args(
if args.idle_duration is not None:
profile_settings.idle_duration = args.idle_duration
elif args.dont_destroy:
- profile_settings.idle_duration = "off"
+ profile_settings.idle_duration = -1
if args.creation_policy_reuse:
profile_settings.creation_policy = CreationPolicy.REUSE
diff --git a/src/dstack/_internal/core/backends/base/configurator.py b/src/dstack/_internal/core/backends/base/configurator.py
index 994266c438..f31e978a31 100644
--- a/src/dstack/_internal/core/backends/base/configurator.py
+++ b/src/dstack/_internal/core/backends/base/configurator.py
@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
-from typing import Any, ClassVar, List, Optional
+from typing import Any, ClassVar, List, Literal, Optional, overload
from uuid import UUID
from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.models import (
AnyBackendConfig,
AnyBackendConfigWithCreds,
+ AnyBackendConfigWithoutCreds,
)
from dstack._internal.core.errors import BackendInvalidCredentialsError
from dstack._internal.core.models.backends.base import BackendType
@@ -77,6 +78,18 @@ def create_backend(
"""
pass
+ @overload
+ def get_backend_config(
+ self, record: StoredBackendRecord, include_creds: Literal[False]
+ ) -> AnyBackendConfigWithoutCreds:
+ pass
+
+ @overload
+ def get_backend_config(
+ self, record: StoredBackendRecord, include_creds: Literal[True]
+ ) -> AnyBackendConfigWithCreds:
+ pass
+
@abstractmethod
def get_backend_config(
self, record: StoredBackendRecord, include_creds: bool
diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py
index 4c4e45fd09..772da55274 100644
--- a/src/dstack/_internal/core/models/common.py
+++ b/src/dstack/_internal/core/models/common.py
@@ -102,12 +102,12 @@ class RegistryAuth(CoreModel):
password (str): The password or access token
"""
- class Config(CoreModel.Config):
- frozen = True
-
username: Annotated[str, Field(description="The username")]
password: Annotated[str, Field(description="The password or access token")]
+ class Config(CoreModel.Config):
+ frozen = True
+
class ApplyAction(str, Enum):
CREATE = "create" # resource is to be created or overridden
diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index ee0ec61b5f..4b92d6f82b 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -221,7 +221,7 @@ class ProbeConfig(CoreModel):
),
] = None
timeout: Annotated[
- Optional[Union[int, str]],
+ Optional[int],
Field(
description=(
f"Maximum amount of time the HTTP request is allowed to take. Defaults to `{DEFAULT_PROBE_TIMEOUT}s`"
@@ -229,7 +229,7 @@ class ProbeConfig(CoreModel):
),
] = None
interval: Annotated[
- Optional[Union[int, str]],
+ Optional[int],
Field(
description=(
"Minimum amount of time between the end of one probe execution"
@@ -249,7 +249,19 @@ class ProbeConfig(CoreModel):
),
] = None
- @validator("timeout")
+ class Config(CoreModel.Config):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ add_extra_schema_types(
+ schema["properties"]["timeout"],
+ extra_types=[{"type": "string"}],
+ )
+ add_extra_schema_types(
+ schema["properties"]["interval"],
+ extra_types=[{"type": "string"}],
+ )
+
+ @validator("timeout", pre=True)
def parse_timeout(cls, v: Optional[Union[int, str]]) -> Optional[int]:
if v is None:
return v
@@ -258,7 +270,7 @@ def parse_timeout(cls, v: Optional[Union[int, str]]) -> Optional[int]:
raise ValueError(f"Probe timeout cannot be shorter than {MIN_PROBE_TIMEOUT}s")
return parsed
- @validator("interval")
+ @validator("interval", pre=True)
def parse_interval(cls, v: Optional[Union[int, str]]) -> Optional[int]:
if v is None:
return v
@@ -373,9 +385,7 @@ class BaseRunConfiguration(CoreModel):
),
),
] = None
- volumes: Annotated[
- List[Union[MountPoint, str]], Field(description="The volumes mount points")
- ] = []
+ volumes: Annotated[List[MountPoint], Field(description="The volumes mount points")] = []
docker: Annotated[
Optional[bool],
Field(
@@ -383,12 +393,24 @@ class BaseRunConfiguration(CoreModel):
),
] = None
files: Annotated[
- list[Union[FilePathMapping, str]],
+ list[FilePathMapping],
Field(description="The local to container file path mappings"),
] = []
# deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init`
setup: CommandsList = []
+ class Config(CoreModel.Config):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ add_extra_schema_types(
+ schema["properties"]["volumes"]["items"],
+ extra_types=[{"type": "string"}],
+ )
+ add_extra_schema_types(
+ schema["properties"]["files"]["items"],
+ extra_types=[{"type": "string"}],
+ )
+
@validator("python", pre=True, always=True)
def convert_python(cls, v, values) -> Optional[PythonVersion]:
if v is not None and values.get("image"):
@@ -413,14 +435,14 @@ def _docker(cls, v, values) -> Optional[bool]:
# but it's not possible to do so without breaking backwards compatibility.
return v
- @validator("volumes", each_item=True)
- def convert_volumes(cls, v) -> MountPoint:
+ @validator("volumes", each_item=True, pre=True)
+ def convert_volumes(cls, v: Union[MountPoint, str]) -> MountPoint:
if isinstance(v, str):
return parse_mount_point(v)
return v
- @validator("files", each_item=True)
- def convert_files(cls, v) -> FilePathMapping:
+ @validator("files", each_item=True, pre=True)
+ def convert_files(cls, v: Union[FilePathMapping, str]) -> FilePathMapping:
if isinstance(v, str):
return FilePathMapping.parse(v)
return v
@@ -444,7 +466,7 @@ def validate_shell(cls, v) -> Optional[str]:
raise ValueError("The value must be `sh`, `bash`, or an absolute path")
-class BaseRunConfigurationWithPorts(BaseRunConfiguration):
+class ConfigurationWithPortsParams(CoreModel):
ports: Annotated[
List[Union[ValidPort, constr(regex=r"^(?:[0-9]+|\*):[0-9]+$"), PortMapping]],
Field(description="Port numbers/mapping to expose"),
@@ -459,7 +481,7 @@ def convert_ports(cls, v) -> PortMapping:
return v
-class BaseRunConfigurationWithCommands(BaseRunConfiguration):
+class ConfigurationWithCommandsParams(CoreModel):
commands: Annotated[CommandsList, Field(description="The shell commands to run")] = []
@root_validator
@@ -503,10 +525,25 @@ def parse_inactivity_duration(
class DevEnvironmentConfiguration(
- ProfileParams, BaseRunConfigurationWithPorts, DevEnvironmentConfigurationParams
+ ProfileParams,
+ BaseRunConfiguration,
+ ConfigurationWithPortsParams,
+ DevEnvironmentConfigurationParams,
):
type: Literal["dev-environment"] = "dev-environment"
+ class Config(ProfileParams.Config, BaseRunConfiguration.Config):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ ProfileParams.Config.schema_extra(schema)
+ BaseRunConfiguration.Config.schema_extra(schema)
+
+ @validator("entrypoint")
+ def validate_entrypoint(cls, v: Optional[str]) -> Optional[str]:
+ if v is not None:
+ raise ValueError("entrypoint is not supported for dev-environment")
+ return v
+
class TaskConfigurationParams(CoreModel):
nodes: Annotated[int, Field(description="Number of nodes", ge=1)] = 1
@@ -514,12 +551,19 @@ class TaskConfigurationParams(CoreModel):
class TaskConfiguration(
ProfileParams,
- BaseRunConfigurationWithCommands,
- BaseRunConfigurationWithPorts,
+ BaseRunConfiguration,
+ ConfigurationWithCommandsParams,
+ ConfigurationWithPortsParams,
TaskConfigurationParams,
):
type: Literal["task"] = "task"
+ class Config(ProfileParams.Config, BaseRunConfiguration.Config):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ ProfileParams.Config.schema_extra(schema)
+ BaseRunConfiguration.Config.schema_extra(schema)
+
class ServiceConfigurationParams(CoreModel):
port: Annotated[
@@ -547,7 +591,7 @@ class ServiceConfigurationParams(CoreModel):
),
] = STRIP_PREFIX_DEFAULT
model: Annotated[
- Optional[Union[AnyModel, str]],
+ Optional[AnyModel],
Field(
description=(
"Mapping of the model for the OpenAI-compatible endpoint provided by `dstack`."
@@ -578,6 +622,18 @@ class ServiceConfigurationParams(CoreModel):
Field(description="List of probes used to determine job health"),
] = []
+ class Config(CoreModel.Config):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ add_extra_schema_types(
+ schema["properties"]["replicas"],
+ extra_types=[{"type": "integer"}, {"type": "string"}],
+ )
+ add_extra_schema_types(
+ schema["properties"]["model"],
+ extra_types=[{"type": "string"}],
+ )
+
@validator("port")
def convert_port(cls, v) -> PortMapping:
if isinstance(v, int):
@@ -586,7 +642,7 @@ def convert_port(cls, v) -> PortMapping:
return PortMapping.parse(v)
return v
- @validator("model")
+ @validator("model", pre=True)
def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]:
if isinstance(v, str):
return OpenAIChatModel(type="chat", name=v, format="openai")
@@ -645,17 +701,23 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
class ServiceConfiguration(
- ProfileParams, BaseRunConfigurationWithCommands, ServiceConfigurationParams
+ ProfileParams,
+ BaseRunConfiguration,
+ ConfigurationWithCommandsParams,
+ ServiceConfigurationParams,
):
type: Literal["service"] = "service"
- class Config(CoreModel.Config):
+ class Config(
+ ProfileParams.Config,
+ BaseRunConfiguration.Config,
+ ServiceConfigurationParams.Config,
+ ):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
- add_extra_schema_types(
- schema["properties"]["replicas"],
- extra_types=[{"type": "integer"}, {"type": "string"}],
- )
+ ProfileParams.Config.schema_extra(schema)
+ BaseRunConfiguration.Config.schema_extra(schema)
+ ServiceConfigurationParams.Config.schema_extra(schema)
AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]
diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py
index 8aaf0d18ee..357f9b5b0c 100644
--- a/src/dstack/_internal/core/models/fleets.py
+++ b/src/dstack/_internal/core/models/fleets.py
@@ -224,7 +224,7 @@ class InstanceGroupParams(CoreModel):
Field(description="The maximum instance price per hour, in dollars", gt=0.0),
] = None
idle_duration: Annotated[
- Optional[Union[Literal["off"], str, int]],
+ Optional[int],
Field(
description="Time to wait before terminating idle instances. Defaults to `5m` for runs and `3d` for fleets. Use `off` for unlimited duration"
),
@@ -243,6 +243,10 @@ def schema_extra(schema: Dict[str, Any], model: Type):
schema["properties"]["nodes"],
extra_types=[{"type": "integer"}, {"type": "string"}],
)
+ add_extra_schema_types(
+ schema["properties"]["idle_duration"],
+ extra_types=[{"type": "string"}],
+ )
_validate_idle_duration = validator("idle_duration", pre=True, allow_reuse=True)(
parse_idle_duration
diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py
index 5572ae25dd..79da7e41cb 100644
--- a/src/dstack/_internal/core/models/profiles.py
+++ b/src/dstack/_internal/core/models/profiles.py
@@ -9,6 +9,7 @@
from dstack._internal.core.models.common import CoreModel, Duration
from dstack._internal.utils.common import list_enum_values_for_annotation
from dstack._internal.utils.cron import validate_cron
+from dstack._internal.utils.json_schema import add_extra_schema_types
from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent
from dstack._internal.utils.tags import tags_validator
@@ -61,15 +62,17 @@ def parse_duration(v: Optional[Union[int, str]]) -> Optional[int]:
return Duration.parse(v)
-def parse_max_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str, int]]:
+def parse_max_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]:
return parse_off_duration(v)
-def parse_stop_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str, int]]:
+def parse_stop_duration(
+ v: Optional[Union[int, str, bool]],
+) -> Optional[Union[Literal["off"], int]]:
return parse_off_duration(v)
-def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str, int]]:
+def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]:
if v == "off" or v is False:
return "off"
if v is True:
@@ -77,7 +80,7 @@ def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str
return parse_duration(v)
-def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[Union[str, int]]:
+def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[int]:
if v == "off" or v == -1:
return -1
return parse_duration(v)
@@ -121,10 +124,18 @@ class ProfileRetry(CoreModel):
),
] = None
duration: Annotated[
- Optional[Union[int, str]],
+ Optional[int],
Field(description="The maximum period of retrying the run, e.g., `4h` or `1d`"),
] = None
+ class Config(CoreModel.Config):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ add_extra_schema_types(
+ schema["properties"]["duration"],
+ extra_types=[{"type": "string"}],
+ )
+
_validate_duration = validator("duration", pre=True, allow_reuse=True)(parse_duration)
@root_validator
@@ -151,7 +162,7 @@ class UtilizationPolicy(CoreModel):
),
]
time_window: Annotated[
- Union[int, str],
+ int,
Field(
description=(
"The time window of metric samples taking into account to measure utilization"
@@ -160,6 +171,14 @@ class UtilizationPolicy(CoreModel):
),
]
+ class Config(CoreModel.Config):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ add_extra_schema_types(
+ schema["properties"]["time_window"],
+ extra_types=[{"type": "string"}],
+ )
+
@validator("time_window", pre=True)
def validate_time_window(cls, v: Union[int, str]) -> int:
v = parse_duration(v)
@@ -247,7 +266,7 @@ class ProfileParams(CoreModel):
Field(description="The policy for resubmitting the run. Defaults to `false`"),
] = None
max_duration: Annotated[
- Optional[Union[Literal["off"], str, int, bool]],
+ Optional[Union[Literal["off"], int]],
Field(
description=(
"The maximum duration of a run (e.g., `2h`, `1d`, etc)."
@@ -257,7 +276,7 @@ class ProfileParams(CoreModel):
),
] = None
stop_duration: Annotated[
- Optional[Union[Literal["off"], str, int, bool]],
+ Optional[Union[Literal["off"], int]],
Field(
description=(
"The maximum duration of a run graceful stopping."
@@ -282,7 +301,7 @@ class ProfileParams(CoreModel):
),
] = None
idle_duration: Annotated[
- Optional[Union[Literal["off"], str, int]],
+ Optional[int],
Field(
description=(
"Time to wait before terminating idle instances."
@@ -347,6 +366,18 @@ def schema_extra(schema: Dict[str, Any]) -> None:
del schema["properties"]["retry_policy"]
del schema["properties"]["termination_policy"]
del schema["properties"]["termination_idle_time"]
+ add_extra_schema_types(
+ schema["properties"]["max_duration"],
+ extra_types=[{"type": "boolean"}, {"type": "string"}],
+ )
+ add_extra_schema_types(
+ schema["properties"]["stop_duration"],
+ extra_types=[{"type": "boolean"}, {"type": "string"}],
+ )
+ add_extra_schema_types(
+ schema["properties"]["idle_duration"],
+ extra_types=[{"type": "string"}],
+ )
_validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)(
parse_max_duration
@@ -382,7 +413,6 @@ class ProfilesConfig(CoreModel):
class Config(CoreModel.Config):
json_loads = orjson.loads
json_dumps = pydantic_orjson_dumps_with_indent
-
schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"}
def default(self) -> Optional[Profile]:
diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py
index 13d5dcf2a9..c959d4b059 100644
--- a/src/dstack/_internal/core/models/resources.py
+++ b/src/dstack/_internal/core/models/resources.py
@@ -130,6 +130,12 @@ def __str__(self):
class CPUSpec(CoreModel):
+ arch: Annotated[
+ Optional[gpuhunt.CPUArchitecture],
+ Field(description="The CPU architecture, one of: `x86`, `arm`"),
+ ] = None
+ count: Annotated[Range[int], Field(description="The number of CPU cores")] = DEFAULT_CPU_COUNT
+
class Config(CoreModel.Config):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
@@ -138,12 +144,6 @@ def schema_extra(schema: Dict[str, Any]):
extra_types=[{"type": "integer"}, {"type": "string"}],
)
- arch: Annotated[
- Optional[gpuhunt.CPUArchitecture],
- Field(description="The CPU architecture, one of: `x86`, `arm`"),
- ] = None
- count: Annotated[Range[int], Field(description="The number of CPU cores")] = DEFAULT_CPU_COUNT
-
@classmethod
def __get_validators__(cls):
yield cls.parse
@@ -191,22 +191,6 @@ def _validate_arch(cls, v: Any) -> Any:
class GPUSpec(CoreModel):
- class Config(CoreModel.Config):
- @staticmethod
- def schema_extra(schema: Dict[str, Any]):
- add_extra_schema_types(
- schema["properties"]["count"],
- extra_types=[{"type": "integer"}, {"type": "string"}],
- )
- add_extra_schema_types(
- schema["properties"]["memory"],
- extra_types=[{"type": "integer"}, {"type": "string"}],
- )
- add_extra_schema_types(
- schema["properties"]["total_memory"],
- extra_types=[{"type": "integer"}, {"type": "string"}],
- )
-
vendor: Annotated[
Optional[gpuhunt.AcceleratorVendor],
Field(
@@ -234,6 +218,22 @@ def schema_extra(schema: Dict[str, Any]):
Field(description="The minimum compute capability of the GPU (e.g., `7.5`)"),
] = None
+ class Config(CoreModel.Config):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ add_extra_schema_types(
+ schema["properties"]["count"],
+ extra_types=[{"type": "integer"}, {"type": "string"}],
+ )
+ add_extra_schema_types(
+ schema["properties"]["memory"],
+ extra_types=[{"type": "integer"}, {"type": "string"}],
+ )
+ add_extra_schema_types(
+ schema["properties"]["total_memory"],
+ extra_types=[{"type": "integer"}, {"type": "string"}],
+ )
+
@classmethod
def __get_validators__(cls):
yield cls.parse
@@ -314,6 +314,8 @@ def _vendor_from_string(cls, v: str) -> gpuhunt.AcceleratorVendor:
class DiskSpec(CoreModel):
+ size: Annotated[Range[Memory], Field(description="Disk size")]
+
class Config(CoreModel.Config):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
@@ -322,8 +324,6 @@ def schema_extra(schema: Dict[str, Any]):
extra_types=[{"type": "integer"}, {"type": "string"}],
)
- size: Annotated[Range[Memory], Field(description="Disk size")]
-
@classmethod
def __get_validators__(cls):
yield cls._parse
@@ -340,6 +340,24 @@ def _parse(cls, v: Any) -> Any:
class ResourcesSpec(CoreModel):
+ # TODO: Remove Range[int] in 0.20. Range[int] for backward compatibility only.
+ cpu: Annotated[Union[CPUSpec, Range[int]], Field(description="The CPU requirements")] = (
+ CPUSpec()
+ )
+ memory: Annotated[Range[Memory], Field(description="The RAM size (e.g., `8GB`)")] = (
+ DEFAULT_MEMORY_SIZE
+ )
+ shm_size: Annotated[
+ Optional[Memory],
+ Field(
+ description="The size of shared memory (e.g., `8GB`). "
+ "If you are using parallel communicating processes (e.g., dataloaders in PyTorch), "
+ "you may need to configure this"
+ ),
+ ] = None
+ gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None
+ disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK
+
class Config(CoreModel.Config):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
@@ -364,24 +382,6 @@ def schema_extra(schema: Dict[str, Any]):
extra_types=[{"type": "integer"}, {"type": "string"}],
)
- # TODO: Remove Range[int] in 0.20. Range[int] for backward compatibility only.
- cpu: Annotated[Union[CPUSpec, Range[int]], Field(description="The CPU requirements")] = (
- CPUSpec()
- )
- memory: Annotated[Range[Memory], Field(description="The RAM size (e.g., `8GB`)")] = (
- DEFAULT_MEMORY_SIZE
- )
- shm_size: Annotated[
- Optional[Memory],
- Field(
- description="The size of shared memory (e.g., `8GB`). "
- "If you are using parallel communicating processes (e.g., dataloaders in PyTorch), "
- "you may need to configure this"
- ),
- ] = None
- gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None
- disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK
-
def pretty_format(self) -> str:
# TODO: Remove in 0.20. Use self.cpu directly
cpu = parse_obj_as(CPUSpec, self.cpu)
diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py
index 87a274a0c3..75f3b6b829 100644
--- a/src/dstack/_internal/core/models/runs.py
+++ b/src/dstack/_internal/core/models/runs.py
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Type
+from urllib.parse import urlparse
from pydantic import UUID4, Field, root_validator
from typing_extensions import Annotated
@@ -483,6 +484,9 @@ class ServiceSpec(CoreModel):
model: Optional[ServiceModelSpec] = None
options: Dict[str, Any] = {}
+ def get_domain(self) -> Optional[str]:
+ return urlparse(self.url).hostname
+
class RunStatus(str, Enum):
PENDING = "pending"
diff --git a/src/dstack/_internal/core/services/profiles.py b/src/dstack/_internal/core/services/profiles.py
index cd268aeac0..71ed2e520e 100644
--- a/src/dstack/_internal/core/services/profiles.py
+++ b/src/dstack/_internal/core/services/profiles.py
@@ -37,10 +37,10 @@ def get_termination(
) -> Tuple[TerminationPolicy, int]:
termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE
termination_idle_time = default_termination_idle_time
- if profile.idle_duration is not None and int(profile.idle_duration) < 0:
+ if profile.idle_duration is not None and profile.idle_duration < 0:
termination_policy = TerminationPolicy.DONT_DESTROY
elif profile.idle_duration is not None:
termination_idle_time = profile.idle_duration
if termination_policy == TerminationPolicy.DONT_DESTROY:
termination_idle_time = -1
- return termination_policy, int(termination_idle_time)
+ return termination_policy, termination_idle_time
diff --git a/src/dstack/_internal/core/services/repos.py b/src/dstack/_internal/core/services/repos.py
index bd6026d11b..61ff9b3abb 100644
--- a/src/dstack/_internal/core/services/repos.py
+++ b/src/dstack/_internal/core/services/repos.py
@@ -84,7 +84,7 @@ def get_local_repo_credentials(
def check_remote_repo_credentials_https(url: GitRepoURL, oauth_token: str) -> RemoteRepoCreds:
try:
- git.cmd.Git().ls_remote(url.as_https(oauth_token), env=dict(GIT_TERMINAL_PROMPT="0"))
+ git.cmd.Git().ls_remote(url.as_https(oauth_token), env=dict(GIT_TERMINAL_PROMPT="0")) # type: ignore[attr-defined]
except GitCommandError:
masked = len(oauth_token[:-4]) * "*" + oauth_token[-4:]
raise InvalidRepoCredentialsError(
@@ -111,7 +111,7 @@ def check_remote_repo_credentials_ssh(url: GitRepoURL, identity_file: PathLike)
private_key = f.read()
try:
- git.cmd.Git().ls_remote(
+ git.cmd.Git().ls_remote( # type: ignore[attr-defined]
url.as_ssh(), env=dict(GIT_SSH_COMMAND=make_ssh_command_for_git(identity_file))
)
except GitCommandError:
@@ -131,7 +131,7 @@ def get_default_branch(remote_url: str) -> Optional[str]:
Get the default branch of a remote Git repository.
"""
try:
- output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD")
+ output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD") # type: ignore[attr-defined]
for line in output.splitlines():
if line.startswith("ref:"):
return line.split()[1].split("/")[-1]
diff --git a/src/dstack/_internal/core/services/ssh/ports.py b/src/dstack/_internal/core/services/ssh/ports.py
index 1462958fe7..f0716e6158 100644
--- a/src/dstack/_internal/core/services/ssh/ports.py
+++ b/src/dstack/_internal/core/services/ssh/ports.py
@@ -74,7 +74,7 @@ def _listen(port: int) -> Optional[socket.socket]:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if IS_WINDOWS:
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) # type: ignore[attr-defined]
sock.bind(("", port))
return sock
except socket.error as e:
diff --git a/src/dstack/_internal/proxy/lib/deps.py b/src/dstack/_internal/proxy/lib/deps.py
index ae10be7abe..21528899ce 100644
--- a/src/dstack/_internal/proxy/lib/deps.py
+++ b/src/dstack/_internal/proxy/lib/deps.py
@@ -21,12 +21,16 @@ class ProxyDependencyInjector(ABC):
def __init__(self) -> None:
self._service_conn_pool = ServiceConnectionPool()
+ # Abstract methods returning AsyncGenerator must not be declared `async def`:
+ # with no `yield` in the body, type checkers would infer a different type.
+ # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
+
@abstractmethod
- async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]:
+ def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]:
pass
@abstractmethod
- async def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]:
+ def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]:
pass
async def get_service_connection_pool(self) -> ServiceConnectionPool:
diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py
index bbb666ac12..eda62ba08f 100644
--- a/src/dstack/_internal/server/app.py
+++ b/src/dstack/_internal/server/app.py
@@ -110,9 +110,11 @@ async def lifespan(app: FastAPI):
_print_dstack_logo()
if not check_required_ssh_version():
logger.warning("OpenSSH 8.4+ is required. The dstack server may not work properly")
+ server_config_manager = None
+ server_config_loaded = False
if settings.SERVER_CONFIG_ENABLED:
server_config_manager = ServerConfigManager()
- config_loaded = server_config_manager.load_config()
+ server_config_loaded = server_config_manager.load_config()
# Encryption has to be configured before working with users and projects
await server_config_manager.apply_encryption()
async with get_session_ctx() as session:
@@ -126,11 +128,9 @@ async def lifespan(app: FastAPI):
session=session,
user=admin,
)
- if settings.SERVER_CONFIG_ENABLED:
- server_config_dir = str(SERVER_CONFIG_FILE_PATH).replace(
- os.path.expanduser("~"), "~", 1
- )
- if not config_loaded:
+ if server_config_manager is not None:
+ server_config_dir = _get_server_config_dir()
+ if not server_config_loaded:
logger.info("Initializing the default configuration...", {"show_path": False})
await server_config_manager.init_config(session=session)
logger.info(
@@ -153,6 +153,7 @@ async def lifespan(app: FastAPI):
)
if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None:
init_default_storage()
+ scheduler = None
if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
scheduler = start_background_tasks()
else:
@@ -167,7 +168,7 @@ async def lifespan(app: FastAPI):
for func in _ON_STARTUP_HOOKS:
await func(app)
yield
- if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
+ if scheduler is not None:
scheduler.shutdown()
PROBES_SCHEDULER.shutdown(wait=False)
await gateway_connections_pool.remove_all()
@@ -371,6 +372,18 @@ def _is_prometheus_request(request: Request) -> bool:
return request.url.path.startswith("/metrics")
+def _sentry_traces_sampler(sampling_context: SamplingContext) -> float:
+ parent_sampling_decision = sampling_context["parent_sampled"]
+ if parent_sampling_decision is not None:
+ return float(parent_sampling_decision)
+ transaction_context = sampling_context["transaction_context"]
+ name = transaction_context.get("name")
+ if name is not None:
+ if name.startswith("background."):
+ return settings.SENTRY_TRACES_BACKGROUND_SAMPLE_RATE
+ return settings.SENTRY_TRACES_SAMPLE_RATE
+
+
def _print_dstack_logo():
console.print(
"""[purple]╱╱╭╮╱╱╭╮╱╱╱╱╱╱╭╮
@@ -387,13 +400,5 @@ def _print_dstack_logo():
)
-def _sentry_traces_sampler(sampling_context: SamplingContext) -> float:
- parent_sampling_decision = sampling_context["parent_sampled"]
- if parent_sampling_decision is not None:
- return float(parent_sampling_decision)
- transaction_context = sampling_context["transaction_context"]
- name = transaction_context.get("name")
- if name is not None:
- if name.startswith("background."):
- return settings.SENTRY_TRACES_BACKGROUND_SAMPLE_RATE
- return settings.SENTRY_TRACES_SAMPLE_RATE
+def _get_server_config_dir() -> str:
+ return str(SERVER_CONFIG_FILE_PATH).replace(os.path.expanduser("~"), "~", 1)
diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/tasks/process_gateways.py
index ef6c1aebe6..a54cb9e319 100644
--- a/src/dstack/_internal/server/background/tasks/process_gateways.py
+++ b/src/dstack/_internal/server/background/tasks/process_gateways.py
@@ -49,8 +49,8 @@ async def process_gateways():
if gateway_model is None:
return
lockset.add(gateway_model.id)
+ gateway_model_id = gateway_model.id
try:
- gateway_model_id = gateway_model.id
initial_status = gateway_model.status
if initial_status == GatewayStatus.SUBMITTED:
await _process_submitted_gateway(session=session, gateway_model=gateway_model)
@@ -165,6 +165,9 @@ async def _process_provisioning_gateway(
)
gateway_model = res.unique().scalar_one()
+ # Provisioning gateways must have compute.
+ assert gateway_model.gateway_compute is not None
+
# FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway:
# - cannot delete the gateway before it is provisioned because the DB model is locked
# - connection retry counter is reset on server restart
diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py
index 5be54f21ce..8e2127cd78 100644
--- a/src/dstack/_internal/server/background/tasks/process_instances.py
+++ b/src/dstack/_internal/server/background/tasks/process_instances.py
@@ -181,8 +181,8 @@ async def _process_next_instance():
if instance is None:
return
lockset.add(instance.id)
+ instance_model_id = instance.id
try:
- instance_model_id = instance.id
await _process_instance(session=session, instance=instance)
finally:
lockset.difference_update([instance_model_id])
@@ -393,6 +393,7 @@ async def _add_remote(instance: InstanceModel) -> None:
return
region = instance.region
+ assert region is not None # always set for ssh instances
jpd = JobProvisioningData(
backend=BackendType.REMOTE,
instance_type=instance_type,
diff --git a/src/dstack/_internal/server/background/tasks/process_probes.py b/src/dstack/_internal/server/background/tasks/process_probes.py
index 5ed9375d13..bc1dc09431 100644
--- a/src/dstack/_internal/server/background/tasks/process_probes.py
+++ b/src/dstack/_internal/server/background/tasks/process_probes.py
@@ -120,7 +120,7 @@ async def _execute_probe(probe: ProbeModel, probe_spec: ProbeSpec) -> bool:
method=probe_spec.method,
url="http://dstack" + probe_spec.url,
headers=[(h.name, h.value) for h in probe_spec.headers],
- data=probe_spec.body,
+ content=probe_spec.body,
timeout=probe_spec.timeout,
follow_redirects=False,
)
diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py
index 0a98bc7fae..19cb089b11 100644
--- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py
@@ -128,9 +128,8 @@ async def _process_next_running_job():
if job_model is None:
return
lockset.add(job_model.id)
-
+ job_model_id = job_model.id
try:
- job_model_id = job_model.id
await _process_running_job(session=session, job_model=job_model)
finally:
lockset.difference_update([job_model_id])
@@ -170,6 +169,11 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
+ volumes = []
+ secrets = {}
+ cluster_info = None
+ repo_creds = None
+
initial_status = job_model.status
if initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]:
# Wait until all other jobs in the replica are provisioned
@@ -257,6 +261,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
user_ssh_key,
)
else:
+ assert cluster_info is not None
logger.debug(
"%s: process provisioning job without shim, age=%s",
fmt(job_model),
@@ -275,7 +280,6 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
repo=repo_model,
code_hash=_get_repo_code_hash(run, job),
)
-
success = await common_utils.run_async(
_submit_job_to_runner,
server_ssh_private_keys,
@@ -309,6 +313,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
else: # fails are not acceptable
if initial_status == JobStatus.PULLING:
+ assert cluster_info is not None
logger.debug(
"%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age
)
@@ -341,7 +346,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
server_ssh_private_keys,
job_provisioning_data,
)
- elif initial_status == JobStatus.RUNNING:
+ else:
logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age)
success = await common_utils.run_async(
_process_running,
@@ -632,6 +637,7 @@ def _process_pulling_with_shim(
is successful
"""
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
+ job_runtime_data = None
if shim_client.is_api_v2_supported(): # raises error if shim is down, causes retry
task = shim_client.get_task(job_model.id)
diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py
index 16a84dcb93..e9d13a5009 100644
--- a/src/dstack/_internal/server/background/tasks/process_runs.py
+++ b/src/dstack/_internal/server/background/tasks/process_runs.py
@@ -129,8 +129,8 @@ async def _process_next_run():
job_ids = [j.id for j in run_model.jobs]
run_lockset.add(run_model.id)
job_lockset.update(job_ids)
+ run_model_id = run_model.id
try:
- run_model_id = run_model.id
await _process_run(session=session, run_model=run_model)
finally:
run_lockset.difference_update([run_model_id])
diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index 9470e39b79..c85715f0e7 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -148,8 +148,8 @@ async def _process_next_submitted_job():
if job_model is None:
return
lockset.add(job_model.id)
+ job_model_id = job_model.id
try:
- job_model_id = job_model.id
await _process_submitted_job(session=session, job_model=job_model)
finally:
lockset.difference_update([job_model_id])
diff --git a/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py b/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py
index cd81765636..6a358dcd61 100644
--- a/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py
@@ -75,9 +75,9 @@ async def _process_next_terminating_job():
return
instance_lockset.add(instance_model.id)
job_lockset.add(job_model.id)
+ job_model_id = job_model.id
+ instance_model_id = job_model.used_instance_id
try:
- job_model_id = job_model.id
- instance_model_id = job_model.used_instance_id
await _process_job(
session=session,
job_model=job_model,
diff --git a/src/dstack/_internal/server/background/tasks/process_volumes.py b/src/dstack/_internal/server/background/tasks/process_volumes.py
index 4e37f6997b..534af8d48f 100644
--- a/src/dstack/_internal/server/background/tasks/process_volumes.py
+++ b/src/dstack/_internal/server/background/tasks/process_volumes.py
@@ -42,8 +42,8 @@ async def process_submitted_volumes():
if volume_model is None:
return
lockset.add(volume_model.id)
+ volume_model_id = volume_model.id
try:
- volume_model_id = volume_model.id
await _process_submitted_volume(session=session, volume_model=volume_model)
finally:
lockset.difference_update([volume_model_id])
diff --git a/src/dstack/_internal/server/db.py b/src/dstack/_internal/server/db.py
index 4e747a8e78..084630add1 100644
--- a/src/dstack/_internal/server/db.py
+++ b/src/dstack/_internal/server/db.py
@@ -4,8 +4,12 @@
from alembic import command, config
from sqlalchemy import AsyncAdaptedQueuePool, event
from sqlalchemy.engine.interfaces import DBAPIConnection
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.ext.asyncio import (
+ AsyncEngine,
+ AsyncSession,
+ async_sessionmaker,
+ create_async_engine,
+)
from sqlalchemy.pool import ConnectionPoolEntry
from dstack._internal.server import settings
@@ -26,8 +30,8 @@ def __init__(self, url: str, engine: Optional[AsyncEngine] = None):
pool_size=settings.DB_POOL_SIZE,
max_overflow=settings.DB_MAX_OVERFLOW,
)
- self.session_maker = sessionmaker(
- bind=self.engine,
+ self.session_maker = async_sessionmaker(
+ bind=self.engine, # type: ignore[assignment]
expire_on_commit=False,
class_=AsyncSession,
)
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index 915a7c7665..cd8873e73b 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -622,6 +622,7 @@ class InstanceModel(BaseModel):
backend: Mapped[Optional[BackendType]] = mapped_column(EnumAsString(BackendType, 100))
backend_data: Mapped[Optional[str]] = mapped_column(Text)
+ # Not set for cloud fleets that haven't been provisioned
offer: Mapped[Optional[str]] = mapped_column(Text)
region: Mapped[Optional[str]] = mapped_column(String(2000))
price: Mapped[Optional[float]] = mapped_column(Float)
diff --git a/src/dstack/_internal/server/routers/gpus.py b/src/dstack/_internal/server/routers/gpus.py
index 521ace1594..45f0e8bf1f 100644
--- a/src/dstack/_internal/server/routers/gpus.py
+++ b/src/dstack/_internal/server/routers/gpus.py
@@ -1,9 +1,7 @@
from typing import Tuple
from fastapi import APIRouter, Depends
-from sqlalchemy.ext.asyncio import AsyncSession
-from dstack._internal.server.db import get_session
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.schemas.gpus import ListGpusRequest, ListGpusResponse
from dstack._internal.server.security.permissions import ProjectMember
@@ -20,10 +18,7 @@
@project_router.post("/list", response_model=ListGpusResponse, response_model_exclude_none=True)
async def list_gpus(
body: ListGpusRequest,
- session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> ListGpusResponse:
_, project = user_project
- return await list_gpus_grouped(
- session=session, project=project, run_spec=body.run_spec, group_by=body.group_by
- )
+ return await list_gpus_grouped(project=project, run_spec=body.run_spec, group_by=body.group_by)
diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py
index 16d8ae7821..9711a503bf 100644
--- a/src/dstack/_internal/server/services/backends/__init__.py
+++ b/src/dstack/_internal/server/services/backends/__init__.py
@@ -17,8 +17,8 @@
)
from dstack._internal.core.backends.local.backend import LocalBackend
from dstack._internal.core.backends.models import (
- AnyBackendConfig,
AnyBackendConfigWithCreds,
+ AnyBackendConfigWithoutCreds,
)
from dstack._internal.core.errors import (
BackendError,
@@ -126,19 +126,25 @@ async def get_backend_config(
)
continue
if backend_model.type == backend_type:
- return get_backend_config_from_backend_model(
- configurator, backend_model, include_creds=True
- )
+ return get_backend_config_with_creds_from_backend_model(configurator, backend_model)
return None
-def get_backend_config_from_backend_model(
+def get_backend_config_with_creds_from_backend_model(
+ configurator: Configurator,
+ backend_model: BackendModel,
+) -> AnyBackendConfigWithCreds:
+ backend_record = get_stored_backend_record(backend_model)
+ backend_config = configurator.get_backend_config(backend_record, include_creds=True)
+ return backend_config
+
+
+def get_backend_config_without_creds_from_backend_model(
configurator: Configurator,
backend_model: BackendModel,
- include_creds: bool,
-) -> AnyBackendConfig:
+) -> AnyBackendConfigWithoutCreds:
backend_record = get_stored_backend_record(backend_model)
- backend_config = configurator.get_backend_config(backend_record, include_creds=include_creds)
+ backend_config = configurator.get_backend_config(backend_record, include_creds=False)
return backend_config
diff --git a/src/dstack/_internal/server/services/backends/handlers.py b/src/dstack/_internal/server/services/backends/handlers.py
index bcd4b857b0..77f8d9832f 100644
--- a/src/dstack/_internal/server/services/backends/handlers.py
+++ b/src/dstack/_internal/server/services/backends/handlers.py
@@ -55,7 +55,11 @@ async def _check_active_instances(
)
for fleet_model in fleet_models:
for instance in fleet_model.instances:
- if instance.status.is_active() and instance.backend in backends_types:
+ if (
+ instance.status.is_active()
+ and instance.backend is not None
+ and instance.backend in backends_types
+ ):
if error:
msg = (
f"Backend {instance.backend.value} has active instances."
@@ -83,6 +87,7 @@ async def _check_active_volumes(
if (
volume_model.status.is_active()
and volume_model.provisioning_data is not None
+ and volume_model.provisioning_data.backend is not None
and volume_model.provisioning_data.backend in backends_types
):
if error:
diff --git a/src/dstack/_internal/server/services/docker.py b/src/dstack/_internal/server/services/docker.py
index 49e8d8e857..7181edc7d3 100644
--- a/src/dstack/_internal/server/services/docker.py
+++ b/src/dstack/_internal/server/services/docker.py
@@ -32,15 +32,15 @@ def __call__(self, dxf: DXF, response: requests.Response) -> None:
class DockerImage(CoreModel):
- class Config(CoreModel.Config):
- frozen = True
-
image: str
registry: Optional[str]
repo: str
tag: str
digest: Optional[str]
+ class Config(CoreModel.Config):
+ frozen = True
+
class ImageConfig(CoreModel):
user: Annotated[Optional[str], Field(alias="User")] = None
@@ -77,7 +77,7 @@ def get_image_config(image_name: str, registry_auth: Optional[RegistryAuth]) ->
registry_client = PatchedDXF(
host=image.registry or DEFAULT_REGISTRY,
repo=image.repo,
- auth=DXFAuthAdapter(registry_auth),
+ auth=DXFAuthAdapter(registry_auth), # type: ignore[assignment]
timeout=REGISTRY_REQUEST_TIMEOUT,
)
@@ -88,7 +88,7 @@ def get_image_config(image_name: str, registry_auth: Optional[RegistryAuth]) ->
)
manifest = ImageManifest.__response__.parse_raw(manifest_resp)
config_stream = registry_client.pull_blob(manifest.config.digest)
- config_resp = join_byte_stream_checked(config_stream, MAX_CONFIG_OBJECT_SIZE)
+ config_resp = join_byte_stream_checked(config_stream, MAX_CONFIG_OBJECT_SIZE) # type: ignore[arg-type]
if config_resp is None:
raise DockerRegistryError(
f"Image config object exceeds the size limit of {MAX_CONFIG_OBJECT_SIZE} bytes"
diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py
index e02cac6589..4f2b64cc5a 100644
--- a/src/dstack/_internal/server/services/fleets.py
+++ b/src/dstack/_internal/server/services/fleets.py
@@ -504,6 +504,7 @@ async def create_fleet_ssh_instance_model(
raise ServerClientError("ssh key or user not specified")
if proxy_jump is not None:
+ assert proxy_jump.ssh_key is not None
ssh_proxy = SSHConnectionParams(
hostname=proxy_jump.hostname,
port=proxy_jump.port or 22,
diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py
index 5a7b50d021..f47b192999 100644
--- a/src/dstack/_internal/server/services/gateways/__init__.py
+++ b/src/dstack/_internal/server/services/gateways/__init__.py
@@ -93,6 +93,8 @@ async def create_gateway_compute(
backend_id: Optional[uuid.UUID] = None,
) -> GatewayComputeModel:
assert isinstance(backend_compute, ComputeWithGatewaySupport)
+ assert configuration.name is not None
+
private_bytes, public_bytes = generate_rsa_key_pair_bytes()
gateway_ssh_private_key = private_bytes.decode()
gateway_ssh_public_key = public_bytes.decode()
diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py
index aa4b4823cf..f8c0900792 100644
--- a/src/dstack/_internal/server/services/gateways/client.py
+++ b/src/dstack/_internal/server/services/gateways/client.py
@@ -7,7 +7,7 @@
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.errors import GatewayError
-from dstack._internal.core.models.configurations import RateLimit, ServiceConfiguration
+from dstack._internal.core.models.configurations import RateLimit
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import JobSpec, JobSubmission, Run, get_service_port
from dstack._internal.proxy.gateway.schemas.stats import ServiceStats
@@ -85,7 +85,7 @@ async def register_replica(
ssh_head_proxy: Optional[SSHConnectionParams],
ssh_head_proxy_private_key: Optional[str],
):
- assert isinstance(run.run_spec.configuration, ServiceConfiguration)
+ assert run.run_spec.configuration.type == "service"
payload = {
"job_id": job_submission.id.hex,
"app_port": get_service_port(job_spec, run.run_spec.configuration),
@@ -93,6 +93,9 @@ async def register_replica(
"ssh_head_proxy_private_key": ssh_head_proxy_private_key,
}
jpd = job_submission.job_provisioning_data
+ assert jpd is not None
+ assert jpd.hostname is not None
+ assert jpd.ssh_port is not None
if not jpd.dockerized:
payload.update(
{
diff --git a/src/dstack/_internal/server/services/gateways/connection.py b/src/dstack/_internal/server/services/gateways/connection.py
index 6b107c34a9..b8df322a1d 100644
--- a/src/dstack/_internal/server/services/gateways/connection.py
+++ b/src/dstack/_internal/server/services/gateways/connection.py
@@ -67,7 +67,7 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int):
# reverse_forwarded_sockets are added later in .open()
)
self.tunnel_id = uuid.uuid4()
- self._client = GatewayClient(uds=self.gateway_socket_path)
+ self._client = GatewayClient(uds=str(self.gateway_socket_path))
@staticmethod
def _init_symlink_dir(connection_dir: Path) -> Tuple[TemporaryDirectory, Path]:
diff --git a/src/dstack/_internal/server/services/gpus.py b/src/dstack/_internal/server/services/gpus.py
index 0ec347be00..c2ddcd2fd8 100644
--- a/src/dstack/_internal/server/services/gpus.py
+++ b/src/dstack/_internal/server/services/gpus.py
@@ -1,8 +1,8 @@
from typing import Dict, List, Literal, Optional, Tuple
-from sqlalchemy.ext.asyncio import AsyncSession
-
from dstack._internal.core.backends.base.backend import Backend
+from dstack._internal.core.errors import ServerClientError
+from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import InstanceOfferWithAvailability
from dstack._internal.core.models.profiles import SpotPolicy
from dstack._internal.core.models.resources import Range
@@ -15,10 +15,43 @@
ListGpusResponse,
)
from dstack._internal.server.services.offers import get_offers_by_requirements
+from dstack._internal.utils.common import get_or_error
+
+
+async def list_gpus_grouped(
+ project: ProjectModel,
+ run_spec: RunSpec,
+ group_by: Optional[List[Literal["backend", "region", "count"]]] = None,
+) -> ListGpusResponse:
+ """Retrieves available GPU specifications based on a run spec, with optional grouping."""
+ offers = await _get_gpu_offers(project=project, run_spec=run_spec)
+ backend_gpus = _process_offers_into_backend_gpus(offers)
+ group_by_set = set(group_by) if group_by else set()
+ if "region" in group_by_set and "backend" not in group_by_set:
+ raise ServerClientError("Cannot group by 'region' without also grouping by 'backend'")
+
+ # Determine grouping strategy based on combination
+ has_backend = "backend" in group_by_set
+ has_region = "region" in group_by_set
+ has_count = "count" in group_by_set
+ if has_backend and has_region and has_count:
+ gpus = _get_gpus_grouped_by_backend_region_and_count(backend_gpus)
+ elif has_backend and has_count:
+ gpus = _get_gpus_grouped_by_backend_and_count(backend_gpus)
+ elif has_backend and has_region:
+ gpus = _get_gpus_grouped_by_backend_and_region(backend_gpus)
+ elif has_backend:
+ gpus = _get_gpus_grouped_by_backend(backend_gpus)
+ elif has_count:
+ gpus = _get_gpus_grouped_by_count(backend_gpus)
+ else:
+ gpus = _get_gpus_with_no_grouping(backend_gpus)
+
+ return ListGpusResponse(gpus=gpus)
async def _get_gpu_offers(
- session: AsyncSession, project: ProjectModel, run_spec: RunSpec
+ project: ProjectModel, run_spec: RunSpec
) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
"""Fetches all available instance offers that match the run spec's GPU requirements."""
profile = run_spec.merged_profile
@@ -28,7 +61,6 @@ async def _get_gpu_offers(
spot=get_policy_map(profile.spot_policy, default=SpotPolicy.AUTO),
reservation=profile.reservation,
)
-
return await get_offers_by_requirements(
project=project,
profile=profile,
@@ -45,7 +77,7 @@ def _process_offers_into_backend_gpus(
offers: List[Tuple[Backend, InstanceOfferWithAvailability]],
) -> List[BackendGpus]:
"""Transforms raw offers into a structured list of BackendGpus, aggregating GPU info."""
- backend_data: Dict[str, Dict] = {}
+ backend_data: Dict[BackendType, Dict] = {}
for backend, offer in offers:
backend_type = backend.TYPE
@@ -111,7 +143,7 @@ def _process_offers_into_backend_gpus(
return backend_gpus_list
-def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: str):
+def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: BackendType):
"""Updates an existing GpuGroup with new data from another GPU offer."""
spot_type: Literal["spot", "on-demand"] = "spot" if gpu.spot else "on-demand"
@@ -122,6 +154,12 @@ def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: str):
if row.backends and backend_type not in row.backends:
row.backends.append(backend_type)
+ # FIXME: Consider using non-optional range
+ assert row.count.min is not None
+ assert row.count.max is not None
+ assert row.price.min is not None
+ assert row.price.max is not None
+
row.count.min = min(row.count.min, gpu.count)
row.count.max = max(row.count.max, gpu.count)
per_gpu_price = gpu.price / gpu.count
@@ -194,7 +232,7 @@ def _get_gpus_grouped_by_backend(backend_gpus: List[BackendGpus]) -> List[GpuGro
not any(av.is_available() for av in g.availability),
g.price.min,
g.price.max,
- g.backend.value,
+ get_or_error(g.backend).value,
g.name,
g.memory_mib,
),
@@ -229,7 +267,7 @@ def _get_gpus_grouped_by_backend_and_region(backend_gpus: List[BackendGpus]) ->
not any(av.is_available() for av in g.availability),
g.price.min,
g.price.max,
- g.backend.value,
+ get_or_error(g.backend).value,
g.region,
g.name,
g.memory_mib,
@@ -299,7 +337,7 @@ def _get_gpus_grouped_by_backend_and_count(backend_gpus: List[BackendGpus]) -> L
not any(av.is_available() for av in g.availability),
g.price.min,
g.price.max,
- g.backend.value,
+ get_or_error(g.backend).value,
g.count.min,
g.name,
g.memory_mib,
@@ -344,47 +382,10 @@ def _get_gpus_grouped_by_backend_region_and_count(
not any(av.is_available() for av in g.availability),
g.price.min,
g.price.max,
- g.backend.value,
+ get_or_error(g.backend).value,
g.region,
g.count.min,
g.name,
g.memory_mib,
),
)
-
-
-async def list_gpus_grouped(
- session: AsyncSession,
- project: ProjectModel,
- run_spec: RunSpec,
- group_by: Optional[List[Literal["backend", "region", "count"]]] = None,
-) -> ListGpusResponse:
- """Retrieves available GPU specifications based on a run spec, with optional grouping."""
- offers = await _get_gpu_offers(session, project, run_spec)
- backend_gpus = _process_offers_into_backend_gpus(offers)
-
- group_by_set = set(group_by) if group_by else set()
-
- if "region" in group_by_set and "backend" not in group_by_set:
- from dstack._internal.core.errors import ServerClientError
-
- raise ServerClientError("Cannot group by 'region' without also grouping by 'backend'")
-
- # Determine grouping strategy based on combination
- has_backend = "backend" in group_by_set
- has_region = "region" in group_by_set
- has_count = "count" in group_by_set
- if has_backend and has_region and has_count:
- gpus = _get_gpus_grouped_by_backend_region_and_count(backend_gpus)
- elif has_backend and has_count:
- gpus = _get_gpus_grouped_by_backend_and_count(backend_gpus)
- elif has_backend and has_region:
- gpus = _get_gpus_grouped_by_backend_and_region(backend_gpus)
- elif has_backend:
- gpus = _get_gpus_grouped_by_backend(backend_gpus)
- elif has_count:
- gpus = _get_gpus_grouped_by_count(backend_gpus)
- else:
- gpus = _get_gpus_with_no_grouping(backend_gpus)
-
- return ListGpusResponse(gpus=gpus)
diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py
index 1a67ad3cf7..ee9bb05f6e 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/base.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/base.py
@@ -3,7 +3,7 @@
import threading
from abc import ABC, abstractmethod
from pathlib import PurePosixPath
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional
from cachetools import TTLCache, cached
@@ -179,6 +179,7 @@ def _shell(self) -> str:
async def _commands(self) -> List[str]:
if self.run_spec.configuration.entrypoint is not None: # docker-like format
+ assert self.run_spec.configuration.type != "dev-environment"
entrypoint = shlex.split(self.run_spec.configuration.entrypoint)
commands = self.run_spec.configuration.commands
elif shell_commands := self._shell_commands():
@@ -258,19 +259,17 @@ def _single_branch(self) -> bool:
return self.run_spec.configuration.single_branch
def _max_duration(self) -> Optional[int]:
- if self.run_spec.merged_profile.max_duration in [None, True]:
+ if self.run_spec.merged_profile.max_duration is None:
return self._default_max_duration()
- if self.run_spec.merged_profile.max_duration in ["off", False]:
+ if self.run_spec.merged_profile.max_duration == "off":
return None
- # pydantic validator ensures this is int
return self.run_spec.merged_profile.max_duration
def _stop_duration(self) -> Optional[int]:
- if self.run_spec.merged_profile.stop_duration in [None, True]:
+ if self.run_spec.merged_profile.stop_duration is None:
return DEFAULT_STOP_DURATION
- if self.run_spec.merged_profile.stop_duration in ["off", False]:
+ if self.run_spec.merged_profile.stop_duration == "off":
return None
- # pydantic validator ensures this is int
return self.run_spec.merged_profile.stop_duration
def _utilization_policy(self) -> Optional[UtilizationPolicy]:
@@ -328,7 +327,7 @@ def _probes(self) -> list[ProbeSpec]:
def interpolate_job_volumes(
- run_volumes: List[Union[MountPoint, str]],
+ run_volumes: List[MountPoint],
job_num: int,
) -> List[MountPoint]:
if len(run_volumes) == 0:
@@ -343,9 +342,6 @@ def interpolate_job_volumes(
)
job_volumes = []
for mount_point in run_volumes:
- if isinstance(mount_point, str):
- # pydantic validator ensures strings are converted to MountPoint
- continue
if not isinstance(mount_point, VolumeMountPoint):
job_volumes.append(mount_point.copy())
continue
diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py
index a10922ef79..20aad1f232 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/dev.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py
@@ -18,6 +18,8 @@ class DevEnvironmentJobConfigurator(JobConfigurator):
TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT
def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]):
+ assert run_spec.configuration.type == "dev-environment"
+
if run_spec.configuration.ide == "vscode":
__class = VSCodeDesktop
elif run_spec.configuration.ide == "cursor":
@@ -32,6 +34,8 @@ def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]):
super().__init__(run_spec=run_spec, secrets=secrets)
def _shell_commands(self) -> List[str]:
+ assert self.run_spec.configuration.type == "dev-environment"
+
commands = self.ide.get_install_commands()
commands.append(INSTALL_IPYKERNEL)
commands += self.run_spec.configuration.setup
@@ -56,4 +60,5 @@ def _spot_policy(self) -> SpotPolicy:
return self.run_spec.merged_profile.spot_policy or SpotPolicy.ONDEMAND
def _ports(self) -> List[PortMapping]:
+ assert self.run_spec.configuration.type == "dev-environment"
return self.run_spec.configuration.ports
diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py
index d0c819d8da..9c5e68d96e 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from dstack._internal.core.models.configurations import DEFAULT_REPO_DIR
@@ -6,8 +6,8 @@
class CursorDesktop:
def __init__(
self,
- run_name: str,
- version: str,
+ run_name: Optional[str],
+ version: Optional[str],
extensions: List[str],
):
self.run_name = run_name
diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py
index f1a2534de0..a10b254d02 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from dstack._internal.core.models.configurations import DEFAULT_REPO_DIR
@@ -6,8 +6,8 @@
class VSCodeDesktop:
def __init__(
self,
- run_name: str,
- version: str,
+ run_name: Optional[str],
+ version: Optional[str],
extensions: List[str],
):
self.run_name = run_name
diff --git a/src/dstack/_internal/server/services/jobs/configurators/service.py b/src/dstack/_internal/server/services/jobs/configurators/service.py
index 7cd36f178a..a00216a6d4 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/service.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/service.py
@@ -9,6 +9,7 @@ class ServiceJobConfigurator(JobConfigurator):
TYPE: RunConfigurationType = RunConfigurationType.SERVICE
def _shell_commands(self) -> List[str]:
+ assert self.run_spec.configuration.type == "service"
return self.run_spec.configuration.commands
def _default_single_branch(self) -> bool:
diff --git a/src/dstack/_internal/server/services/jobs/configurators/task.py b/src/dstack/_internal/server/services/jobs/configurators/task.py
index 4b1c93ce05..6a0da9f003 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/task.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/task.py
@@ -10,6 +10,7 @@ class TaskJobConfigurator(JobConfigurator):
TYPE: RunConfigurationType = RunConfigurationType.TASK
async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
+ assert self.run_spec.configuration.type == "task"
job_specs = []
for job_num in range(self.run_spec.configuration.nodes):
job_spec = await self._get_job_spec(
@@ -21,6 +22,7 @@ async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
return job_specs
def _shell_commands(self) -> List[str]:
+ assert self.run_spec.configuration.type == "task"
return self.run_spec.configuration.commands
def _default_single_branch(self) -> bool:
@@ -33,6 +35,7 @@ def _spot_policy(self) -> SpotPolicy:
return self.run_spec.merged_profile.spot_policy or SpotPolicy.ONDEMAND
def _ports(self) -> List[PortMapping]:
+ assert self.run_spec.configuration.type == "task"
return self.run_spec.configuration.ports
def _working_dir(self) -> Optional[str]:
diff --git a/src/dstack/_internal/server/services/locking.py b/src/dstack/_internal/server/services/locking.py
index 4c3b7f938a..71a4aa7bfe 100644
--- a/src/dstack/_internal/server/services/locking.py
+++ b/src/dstack/_internal/server/services/locking.py
@@ -23,13 +23,13 @@ async def __aexit__(self, exc_type, exc, tb): ...
class Lockset(Protocol[T]):
- def __contains__(self, item: T) -> bool: ...
+ def __contains__(self, item: T, /) -> bool: ...
def __iter__(self) -> Iterator[T]: ...
def __len__(self) -> int: ...
- def add(self, item: T) -> None: ...
- def discard(self, item: T) -> None: ...
- def update(self, other: Iterable[T]) -> None: ...
- def difference_update(self, other: Iterable[T]) -> None: ...
+ def add(self, item: T, /) -> None: ...
+ def discard(self, item: T, /) -> None: ...
+ def update(self, other: Iterable[T], /) -> None: ...
+ def difference_update(self, other: Iterable[T], /) -> None: ...
class ResourceLocker:
diff --git a/src/dstack/_internal/server/services/logs/__init__.py b/src/dstack/_internal/server/services/logs/__init__.py
index b38264980d..5b06ff4ad2 100644
--- a/src/dstack/_internal/server/services/logs/__init__.py
+++ b/src/dstack/_internal/server/services/logs/__init__.py
@@ -7,14 +7,14 @@
from dstack._internal.server.models import ProjectModel
from dstack._internal.server.schemas.logs import PollLogsRequest
from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
-from dstack._internal.server.services.logs.aws import BOTO_AVAILABLE, CloudWatchLogStorage
+from dstack._internal.server.services.logs import aws as aws_logs
+from dstack._internal.server.services.logs import gcp as gcp_logs
from dstack._internal.server.services.logs.base import (
LogStorage,
LogStorageError,
b64encode_raw_message,
)
from dstack._internal.server.services.logs.filelog import FileLogStorage
-from dstack._internal.server.services.logs.gcp import GCP_LOGGING_AVAILABLE, GCPLogStorage
from dstack._internal.utils.common import run_async
from dstack._internal.utils.logging import get_logger
@@ -29,9 +29,9 @@ def get_log_storage() -> LogStorage:
if _log_storage is not None:
return _log_storage
if settings.SERVER_CLOUDWATCH_LOG_GROUP:
- if BOTO_AVAILABLE:
+ if aws_logs.BOTO_AVAILABLE:
try:
- _log_storage = CloudWatchLogStorage(
+ _log_storage = aws_logs.CloudWatchLogStorage(
group=settings.SERVER_CLOUDWATCH_LOG_GROUP,
region=settings.SERVER_CLOUDWATCH_LOG_REGION,
)
@@ -44,9 +44,11 @@ def get_log_storage() -> LogStorage:
else:
logger.error("Cannot use CloudWatch Logs storage: boto3 is not installed")
elif settings.SERVER_GCP_LOGGING_PROJECT:
- if GCP_LOGGING_AVAILABLE:
+ if gcp_logs.GCP_LOGGING_AVAILABLE:
try:
- _log_storage = GCPLogStorage(project_id=settings.SERVER_GCP_LOGGING_PROJECT)
+ _log_storage = gcp_logs.GCPLogStorage(
+ project_id=settings.SERVER_GCP_LOGGING_PROJECT
+ )
except LogStorageError as e:
logger.error("Failed to initialize GCP Logs storage: %s", e)
except Exception:
diff --git a/src/dstack/_internal/server/services/logs/aws.py b/src/dstack/_internal/server/services/logs/aws.py
index 692ae1348e..4e56f0865d 100644
--- a/src/dstack/_internal/server/services/logs/aws.py
+++ b/src/dstack/_internal/server/services/logs/aws.py
@@ -24,347 +24,350 @@
)
from dstack._internal.utils.logging import get_logger
+logger = get_logger(__name__)
+
+
BOTO_AVAILABLE = True
try:
import boto3
import botocore.exceptions
except ImportError:
BOTO_AVAILABLE = False
-
-logger = get_logger(__name__)
-
-
-class _CloudWatchLogEvent(TypedDict):
- timestamp: int # unix time in milliseconds
- message: str
-
-
-class CloudWatchLogStorage(LogStorage):
- # "The maximum number of log events in a batch is 10,000".
- EVENT_MAX_COUNT_IN_BATCH = 10000
- # "The maximum batch size is 1,048,576 bytes" — exactly 1 MiB. "This size is calculated
- # as the sum of all event messages in UTF-8, plus 26 bytes for each log event".
- BATCH_MAX_SIZE = 1048576
- # "Each log event can be no larger than 256 KB" — KB means KiB; includes MESSAGE_OVERHEAD_SIZE.
- MESSAGE_MAX_SIZE = 262144
- # Message size in bytes = len(message.encode("utf-8")) + MESSAGE_OVERHEAD_SIZE.
- MESSAGE_OVERHEAD_SIZE = 26
- # "A batch of log events in a single request cannot span more than 24 hours".
- BATCH_MAX_SPAN = int(timedelta(hours=24).total_seconds()) * 1000
- # Decrease allowed deltas by possible clock drift between dstack and CloudWatch.
- CLOCK_DRIFT = int(timedelta(minutes=10).total_seconds()) * 1000
- # "None of the log events in the batch can be more than 14 days in the past."
- PAST_EVENT_MAX_DELTA = int((timedelta(days=14)).total_seconds()) * 1000 - CLOCK_DRIFT
- # "None of the log events in the batch can be more than 2 hours in the future."
- FUTURE_EVENT_MAX_DELTA = int((timedelta(hours=2)).total_seconds()) * 1000 - CLOCK_DRIFT
- # Maximum number of retries when polling for log events to skip empty pages.
- MAX_RETRIES = 10
-
- def __init__(self, *, group: str, region: Optional[str] = None) -> None:
- with self._wrap_boto_errors():
- session = boto3.Session(region_name=region)
- self._client = session.client("logs")
- self._check_group_exists(group)
- self._group = group
- self._region = self._client.meta.region_name
- # Stores names of already created streams.
- # XXX: This set acts as an unbound cache. If this becomes a problem (in case of _very_ long
- # running server and/or lots of jobs, consider replacing it with an LRU cache, e.g.,
- # a simple OrderedDict-based implementation should be OK.
- self._streams: Set[str] = set()
-
- def close(self) -> None:
- self._client.close()
-
- def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
- log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
- stream = self._get_stream_name(
- project.name, request.run_name, request.job_submission_id, log_producer
- )
- cw_events: List[_CloudWatchLogEvent]
- next_token: Optional[str] = None
- with self._wrap_boto_errors():
- try:
- cw_events, next_token = self._get_log_events_with_retry(stream, request)
- except botocore.exceptions.ClientError as e:
- if not self._is_resource_not_found_exception(e):
- raise
- # Check if the group exists to distinguish between group not found vs stream not found
- try:
- self._check_group_exists(self._group)
- # Group exists, so the error must be due to missing stream
- logger.debug("Stream %s not found, returning dummy response", stream)
- cw_events = []
- except LogStorageError:
- # Group doesn't exist, re-raise the LogStorageError
- raise
- logs = [
- LogEvent(
- timestamp=unix_time_ms_to_datetime(cw_event["timestamp"]),
- log_source=LogEventSource.STDOUT,
- message=cw_event["message"],
+else:
+
+ class _CloudWatchLogEvent(TypedDict):
+ timestamp: int # unix time in milliseconds
+ message: str
+
+ class CloudWatchLogStorage(LogStorage):
+ # "The maximum number of log events in a batch is 10,000".
+ EVENT_MAX_COUNT_IN_BATCH = 10000
+ # "The maximum batch size is 1,048,576 bytes" — exactly 1 MiB. "This size is calculated
+ # as the sum of all event messages in UTF-8, plus 26 bytes for each log event".
+ BATCH_MAX_SIZE = 1048576
+ # "Each log event can be no larger than 256 KB" — KB means KiB; includes MESSAGE_OVERHEAD_SIZE.
+ MESSAGE_MAX_SIZE = 262144
+ # Message size in bytes = len(message.encode("utf-8")) + MESSAGE_OVERHEAD_SIZE.
+ MESSAGE_OVERHEAD_SIZE = 26
+ # "A batch of log events in a single request cannot span more than 24 hours".
+ BATCH_MAX_SPAN = int(timedelta(hours=24).total_seconds()) * 1000
+ # Decrease allowed deltas by possible clock drift between dstack and CloudWatch.
+ CLOCK_DRIFT = int(timedelta(minutes=10).total_seconds()) * 1000
+ # "None of the log events in the batch can be more than 14 days in the past."
+ PAST_EVENT_MAX_DELTA = int((timedelta(days=14)).total_seconds()) * 1000 - CLOCK_DRIFT
+ # "None of the log events in the batch can be more than 2 hours in the future."
+ FUTURE_EVENT_MAX_DELTA = int((timedelta(hours=2)).total_seconds()) * 1000 - CLOCK_DRIFT
+ # Maximum number of retries when polling for log events to skip empty pages.
+ MAX_RETRIES = 10
+
+ def __init__(self, *, group: str, region: Optional[str] = None) -> None:
+ with self._wrap_boto_errors():
+ session = boto3.Session(region_name=region)
+ self._client = session.client("logs")
+ self._check_group_exists(group)
+ self._group = group
+ self._region = self._client.meta.region_name
+ # Stores names of already created streams.
+ # XXX: This set acts as an unbound cache. If this becomes a problem (in case of _very_ long
+ # running server and/or lots of jobs, consider replacing it with an LRU cache, e.g.,
+ # a simple OrderedDict-based implementation should be OK.
+ self._streams: Set[str] = set()
+
+ def close(self) -> None:
+ self._client.close()
+
+ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
+ log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
+ stream = self._get_stream_name(
+ project.name, request.run_name, request.job_submission_id, log_producer
)
- for cw_event in cw_events
- ]
- return JobSubmissionLogs(
- logs=logs,
- external_url=self._get_stream_external_url(stream),
- next_token=next_token,
- )
-
- def _get_log_events_with_retry(
- self, stream: str, request: PollLogsRequest
- ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]:
- current_request = request
- previous_next_token = request.next_token
-
- for attempt in range(self.MAX_RETRIES):
- cw_events, next_token = self._get_log_events(stream, current_request)
-
- if cw_events:
- return cw_events, next_token
-
- if not next_token or next_token == previous_next_token:
- return [], None
-
- previous_next_token = next_token
- current_request = PollLogsRequest(
- run_name=request.run_name,
- job_submission_id=request.job_submission_id,
- start_time=request.start_time,
- end_time=request.end_time,
- descending=request.descending,
+ cw_events: List[_CloudWatchLogEvent]
+ next_token: Optional[str] = None
+ with self._wrap_boto_errors():
+ try:
+ cw_events, next_token = self._get_log_events_with_retry(stream, request)
+ except botocore.exceptions.ClientError as e:
+ if not self._is_resource_not_found_exception(e):
+ raise
+ # Check if the group exists to distinguish between group not found vs stream not found
+ try:
+ self._check_group_exists(self._group)
+ # Group exists, so the error must be due to missing stream
+ logger.debug("Stream %s not found, returning dummy response", stream)
+ cw_events = []
+ except LogStorageError:
+ # Group doesn't exist, re-raise the LogStorageError
+ raise
+ logs = [
+ LogEvent(
+ timestamp=unix_time_ms_to_datetime(cw_event["timestamp"]),
+ log_source=LogEventSource.STDOUT,
+ message=cw_event["message"],
+ )
+ for cw_event in cw_events
+ ]
+ return JobSubmissionLogs(
+ logs=logs,
+ external_url=self._get_stream_external_url(stream),
next_token=next_token,
- limit=request.limit,
- diagnose=request.diagnose,
)
- if not request.descending:
- logger.debug(
- "Stream %s: exhausted %d retries without finding logs, returning empty response",
- stream,
- self.MAX_RETRIES,
- )
- # Only return the next token after exhausting retries if going descending—
- # AWS CloudWatch guarantees more logs in that case. In ascending mode,
- # next token is always returned, even if no logs remain.
- # So descending works reliably; ascending has limits if gaps are too large.
- # In the future, UI/CLI should handle retries, and we can return next token for ascending too.
- return [], next_token if request.descending else None
-
- def _get_log_events(
- self, stream: str, request: PollLogsRequest
- ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]:
- start_from_head = not request.descending
- parameters = {
- "logGroupName": self._group,
- "logStreamName": stream,
- "limit": request.limit,
- "startFromHead": start_from_head,
- }
-
- if request.start_time:
- parameters["startTime"] = datetime_to_unix_time_ms(request.start_time)
-
- if request.end_time:
- parameters["endTime"] = datetime_to_unix_time_ms(request.end_time)
- elif start_from_head:
- # When startFromHead=true and no endTime is provided, set endTime to "now"
- # to prevent infinite pagination as new logs arrive faster than we can read them
- parameters["endTime"] = datetime_to_unix_time_ms(datetime.now(timezone.utc))
-
- if request.next_token:
- parameters["nextToken"] = request.next_token
-
- response = self._client.get_log_events(**parameters)
-
- events = response.get("events", [])
- next_token_key = "nextForwardToken" if start_from_head else "nextBackwardToken"
- next_token = response.get(next_token_key)
-
- # TODO: The code below is not going to be used until we migrate from base64-encoded logs to plain text logs.
- if request.descending:
- events = list(reversed(events))
-
- return events, next_token
-
- def _get_stream_external_url(self, stream: str) -> str:
- quoted_group = urllib.parse.quote(self._group, safe="")
- quoted_stream = urllib.parse.quote(stream, safe="")
- return f"https://console.aws.amazon.com/cloudwatch/home?region={self._region}#logsV2:log-groups/log-group/{quoted_group}/log-events/{quoted_stream}"
-
- def write_logs(
- self,
- project: ProjectModel,
- run_name: str,
- job_submission_id: UUID,
- runner_logs: List[RunnerLogEvent],
- job_logs: List[RunnerLogEvent],
- ):
- if len(runner_logs) > 0:
- runner_stream = self._get_stream_name(
- project.name, run_name, job_submission_id, LogProducer.RUNNER
- )
- self._write_logs(
- stream=runner_stream,
- log_events=runner_logs,
- )
- if len(job_logs) > 0:
- jog_stream = self._get_stream_name(
- project.name, run_name, job_submission_id, LogProducer.JOB
- )
- self._write_logs(
- stream=jog_stream,
- log_events=job_logs,
- )
+ def _get_log_events_with_retry(
+ self, stream: str, request: PollLogsRequest
+ ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]:
+ current_request = request
+ previous_next_token = request.next_token
+ next_token = None
+
+ for _ in range(self.MAX_RETRIES):
+ cw_events, next_token = self._get_log_events(stream, current_request)
+
+ if cw_events:
+ return cw_events, next_token
+
+ if not next_token or next_token == previous_next_token:
+ return [], None
+
+ previous_next_token = next_token
+ current_request = PollLogsRequest(
+ run_name=request.run_name,
+ job_submission_id=request.job_submission_id,
+ start_time=request.start_time,
+ end_time=request.end_time,
+ descending=request.descending,
+ next_token=next_token,
+ limit=request.limit,
+ diagnose=request.diagnose,
+ )
- def _write_logs(self, stream: str, log_events: List[RunnerLogEvent]) -> None:
- with self._wrap_boto_errors():
- self._ensure_stream_exists(stream)
- try:
+ if not request.descending:
+ logger.debug(
+ "Stream %s: exhausted %d retries without finding logs, returning empty response",
+ stream,
+ self.MAX_RETRIES,
+ )
+ # Only return the next token after exhausting retries if going descending—
+ # AWS CloudWatch guarantees more logs in that case. In ascending mode,
+ # next token is always returned, even if no logs remain.
+ # So descending works reliably; ascending has limits if gaps are too large.
+ # In the future, UI/CLI should handle retries, and we can return next token for ascending too.
+ return [], next_token if request.descending else None
+
+ def _get_log_events(
+ self, stream: str, request: PollLogsRequest
+ ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]:
+ start_from_head = not request.descending
+ parameters = {
+ "logGroupName": self._group,
+ "logStreamName": stream,
+ "limit": request.limit,
+ "startFromHead": start_from_head,
+ }
+
+ if request.start_time:
+ parameters["startTime"] = datetime_to_unix_time_ms(request.start_time)
+
+ if request.end_time:
+ parameters["endTime"] = datetime_to_unix_time_ms(request.end_time)
+ elif start_from_head:
+ # When startFromHead=true and no endTime is provided, set endTime to "now"
+ # to prevent infinite pagination as new logs arrive faster than we can read them
+ parameters["endTime"] = datetime_to_unix_time_ms(datetime.now(timezone.utc))
+
+ if request.next_token:
+ parameters["nextToken"] = request.next_token
+
+ response = self._client.get_log_events(**parameters)
+
+ events = response.get("events", [])
+ next_token_key = "nextForwardToken" if start_from_head else "nextBackwardToken"
+ next_token = response.get(next_token_key)
+
+ # TODO: The code below is not going to be used until we migrate from base64-encoded logs to plain text logs.
+ if request.descending:
+ events = list(reversed(events))
+
+ return events, next_token
+
+ def _get_stream_external_url(self, stream: str) -> str:
+ quoted_group = urllib.parse.quote(self._group, safe="")
+ quoted_stream = urllib.parse.quote(stream, safe="")
+ return f"https://console.aws.amazon.com/cloudwatch/home?region={self._region}#logsV2:log-groups/log-group/{quoted_group}/log-events/{quoted_stream}"
+
+ def write_logs(
+ self,
+ project: ProjectModel,
+ run_name: str,
+ job_submission_id: UUID,
+ runner_logs: List[RunnerLogEvent],
+ job_logs: List[RunnerLogEvent],
+ ):
+ if len(runner_logs) > 0:
+ runner_stream = self._get_stream_name(
+ project.name, run_name, job_submission_id, LogProducer.RUNNER
+ )
+ self._write_logs(
+ stream=runner_stream,
+ log_events=runner_logs,
+ )
+ if len(job_logs) > 0:
+ jog_stream = self._get_stream_name(
+ project.name, run_name, job_submission_id, LogProducer.JOB
+ )
+ self._write_logs(
+ stream=jog_stream,
+ log_events=job_logs,
+ )
+
+ def _write_logs(self, stream: str, log_events: List[RunnerLogEvent]) -> None:
+ with self._wrap_boto_errors():
+ self._ensure_stream_exists(stream)
+ try:
+ self._put_log_events(stream, log_events)
+ return
+ except botocore.exceptions.ClientError as e:
+ if not self._is_resource_not_found_exception(e):
+ raise
+ logger.debug("Stream %s not found, recreating", stream)
+ # The stream is probably deleted due to retention policy, our cache is stale.
+ self._ensure_stream_exists(stream, force=True)
self._put_log_events(stream, log_events)
- return
- except botocore.exceptions.ClientError as e:
- if not self._is_resource_not_found_exception(e):
- raise
- logger.debug("Stream %s not found, recreating", stream)
- # The stream is probably deleted due to retention policy, our cache is stale.
- self._ensure_stream_exists(stream, force=True)
- self._put_log_events(stream, log_events)
-
- def _put_log_events(self, stream: str, log_events: List[RunnerLogEvent]) -> None:
- # Python docs: "The built-in sorted() function is guaranteed to be stable."
- sorted_log_events = sorted(log_events, key=operator.attrgetter("timestamp"))
- if tuple(map(id, log_events)) != tuple(map(id, sorted_log_events)):
- logger.error(
- "Stream %s: events are not in chronological order, something wrong with runner",
- stream,
- )
- for batch in self._get_batch_iter(stream, sorted_log_events):
- self._client.put_log_events(
- logGroupName=self._group,
- logStreamName=stream,
- logEvents=batch,
- )
- def _get_batch_iter(
- self, stream: str, log_events: List[RunnerLogEvent]
- ) -> Iterator[List[_CloudWatchLogEvent]]:
- shared_event_iter = iter(log_events)
- event_iter = shared_event_iter
- while True:
- batch, excessive_event = self._get_next_batch(stream, event_iter)
- if not batch:
- return
- yield batch
- if excessive_event is not None:
- event_iter = itertools.chain([excessive_event], shared_event_iter)
- else:
- event_iter = shared_event_iter
-
- def _get_next_batch(
- self, stream: str, event_iter: Iterator[RunnerLogEvent]
- ) -> Tuple[List[_CloudWatchLogEvent], Optional[RunnerLogEvent]]:
- now_timestamp = int(datetime.now(timezone.utc).timestamp() * 1000)
- batch: List[_CloudWatchLogEvent] = []
- total_size = 0
- event_count = 0
- first_timestamp: Optional[int] = None
- skipped_past_events = 0
- skipped_future_events = 0
- # event that doesn't fit in the current batch
- excessive_event: Optional[RunnerLogEvent] = None
- for event in event_iter:
- # Normally there should not be empty messages.
- if not event.message:
- continue
- timestamp = event.timestamp
- if first_timestamp is None:
- first_timestamp = timestamp
- elif timestamp - first_timestamp > self.BATCH_MAX_SPAN:
- excessive_event = event
- break
- if now_timestamp - timestamp > self.PAST_EVENT_MAX_DELTA:
- skipped_past_events += 1
- continue
- if timestamp - now_timestamp > self.FUTURE_EVENT_MAX_DELTA:
- skipped_future_events += 1
- continue
- cw_event = self._runner_log_event_to_cloudwatch_event(event)
- message_size = len(event.message) + self.MESSAGE_OVERHEAD_SIZE
- if message_size > self.MESSAGE_MAX_SIZE:
- # we should never hit this limit, as we use `io.Copy` to copy from pty to logs,
- # which under the hood uses 32KiB buffer, see runner/internal/executor/executor.go,
- # `execJob` -> `io.Copy(logger, ptmx)`
+ def _put_log_events(self, stream: str, log_events: List[RunnerLogEvent]) -> None:
+ # Python docs: "The built-in sorted() function is guaranteed to be stable."
+ sorted_log_events = sorted(log_events, key=operator.attrgetter("timestamp"))
+ if tuple(map(id, log_events)) != tuple(map(id, sorted_log_events)):
logger.error(
- "Stream %s: skipping event %d, message exceeds max size: %d > %d",
+ "Stream %s: events are not in chronological order, something wrong with runner",
stream,
- timestamp,
- message_size,
- self.MESSAGE_MAX_SIZE,
)
- continue
- if total_size + message_size > self.BATCH_MAX_SIZE:
- excessive_event = event
- break
- batch.append(cw_event)
- total_size += message_size
- event_count += 1
- if event_count >= self.EVENT_MAX_COUNT_IN_BATCH:
- break
- if skipped_past_events > 0:
- logger.error("Stream %s: skipping %d past event(s)", stream, skipped_past_events)
- if skipped_future_events > 0:
- logger.error("Stream %s: skipping %d future event(s)", stream, skipped_future_events)
- return batch, excessive_event
-
- def _runner_log_event_to_cloudwatch_event(
- self, runner_log_event: RunnerLogEvent
- ) -> _CloudWatchLogEvent:
- return {
- "timestamp": runner_log_event.timestamp,
- "message": runner_log_event.message.decode(errors="replace"),
- }
-
- @contextmanager
- def _wrap_boto_errors(self) -> Iterator[None]:
- try:
- yield
- except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
- raise LogStorageError(f"CloudWatch Logs error: {type(e).__name__}: {e}") from e
-
- def _is_resource_not_found_exception(self, exc: "botocore.exceptions.ClientError") -> bool:
- try:
- return exc.response["Error"]["Code"] == "ResourceNotFoundException"
- except KeyError:
- return False
-
- def _check_group_exists(self, name: str) -> None:
- try:
- self._client.describe_log_streams(logGroupName=name, limit=1)
- except botocore.exceptions.ClientError as e:
- if self._is_resource_not_found_exception(e):
- raise LogStorageError(f"LogGroup '{name}' does not exist")
- raise
-
- def _ensure_stream_exists(self, name: str, *, force: bool = False) -> None:
- if not force and name in self._streams:
- return
- response = self._client.describe_log_streams(
- logGroupName=self._group, logStreamNamePrefix=name
- )
- for stream in response["logStreams"]:
- if stream["logStreamName"] == name:
- self._streams.add(name)
+ for batch in self._get_batch_iter(stream, sorted_log_events):
+ self._client.put_log_events(
+ logGroupName=self._group,
+ logStreamName=stream,
+ logEvents=batch,
+ )
+
+ def _get_batch_iter(
+ self, stream: str, log_events: List[RunnerLogEvent]
+ ) -> Iterator[List[_CloudWatchLogEvent]]:
+ shared_event_iter = iter(log_events)
+ event_iter = shared_event_iter
+ while True:
+ batch, excessive_event = self._get_next_batch(stream, event_iter)
+ if not batch:
+ return
+ yield batch
+ if excessive_event is not None:
+ event_iter = itertools.chain([excessive_event], shared_event_iter)
+ else:
+ event_iter = shared_event_iter
+
+ def _get_next_batch(
+ self, stream: str, event_iter: Iterator[RunnerLogEvent]
+ ) -> Tuple[List[_CloudWatchLogEvent], Optional[RunnerLogEvent]]:
+ now_timestamp = int(datetime.now(timezone.utc).timestamp() * 1000)
+ batch: List[_CloudWatchLogEvent] = []
+ total_size = 0
+ event_count = 0
+ first_timestamp: Optional[int] = None
+ skipped_past_events = 0
+ skipped_future_events = 0
+ # event that doesn't fit in the current batch
+ excessive_event: Optional[RunnerLogEvent] = None
+ for event in event_iter:
+ # Normally there should not be empty messages.
+ if not event.message:
+ continue
+ timestamp = event.timestamp
+ if first_timestamp is None:
+ first_timestamp = timestamp
+ elif timestamp - first_timestamp > self.BATCH_MAX_SPAN:
+ excessive_event = event
+ break
+ if now_timestamp - timestamp > self.PAST_EVENT_MAX_DELTA:
+ skipped_past_events += 1
+ continue
+ if timestamp - now_timestamp > self.FUTURE_EVENT_MAX_DELTA:
+ skipped_future_events += 1
+ continue
+ cw_event = self._runner_log_event_to_cloudwatch_event(event)
+ message_size = len(event.message) + self.MESSAGE_OVERHEAD_SIZE
+ if message_size > self.MESSAGE_MAX_SIZE:
+ # we should never hit this limit, as we use `io.Copy` to copy from pty to logs,
+ # which under the hood uses 32KiB buffer, see runner/internal/executor/executor.go,
+ # `execJob` -> `io.Copy(logger, ptmx)`
+ logger.error(
+ "Stream %s: skipping event %d, message exceeds max size: %d > %d",
+ stream,
+ timestamp,
+ message_size,
+ self.MESSAGE_MAX_SIZE,
+ )
+ continue
+ if total_size + message_size > self.BATCH_MAX_SIZE:
+ excessive_event = event
+ break
+ batch.append(cw_event)
+ total_size += message_size
+ event_count += 1
+ if event_count >= self.EVENT_MAX_COUNT_IN_BATCH:
+ break
+ if skipped_past_events > 0:
+ logger.error("Stream %s: skipping %d past event(s)", stream, skipped_past_events)
+ if skipped_future_events > 0:
+ logger.error(
+ "Stream %s: skipping %d future event(s)", stream, skipped_future_events
+ )
+ return batch, excessive_event
+
+ def _runner_log_event_to_cloudwatch_event(
+ self, runner_log_event: RunnerLogEvent
+ ) -> _CloudWatchLogEvent:
+ return {
+ "timestamp": runner_log_event.timestamp,
+ "message": runner_log_event.message.decode(errors="replace"),
+ }
+
+ @contextmanager
+ def _wrap_boto_errors(self) -> Iterator[None]:
+ try:
+ yield
+ except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e:
+ raise LogStorageError(f"CloudWatch Logs error: {type(e).__name__}: {e}") from e
+
+ def _is_resource_not_found_exception(self, exc: "botocore.exceptions.ClientError") -> bool:
+ try:
+ return exc.response["Error"]["Code"] == "ResourceNotFoundException"
+ except KeyError:
+ return False
+
+ def _check_group_exists(self, name: str) -> None:
+ try:
+ self._client.describe_log_streams(logGroupName=name, limit=1)
+ except botocore.exceptions.ClientError as e:
+ if self._is_resource_not_found_exception(e):
+ raise LogStorageError(f"LogGroup '{name}' does not exist")
+ raise
+
+ def _ensure_stream_exists(self, name: str, *, force: bool = False) -> None:
+ if not force and name in self._streams:
return
- self._client.create_log_stream(logGroupName=self._group, logStreamName=name)
- self._streams.add(name)
-
- def _get_stream_name(
- self,
- project_name: str,
- run_name: str,
- job_submission_id: UUID,
- producer: LogProducer,
- ) -> str:
- return f"{project_name}/{run_name}/{job_submission_id}/{producer.value}"
+ response = self._client.describe_log_streams(
+ logGroupName=self._group, logStreamNamePrefix=name
+ )
+ for stream in response["logStreams"]:
+ if stream["logStreamName"] == name:
+ self._streams.add(name)
+ return
+ self._client.create_log_stream(logGroupName=self._group, logStreamName=name)
+ self._streams.add(name)
+
+ def _get_stream_name(
+ self,
+ project_name: str,
+ run_name: str,
+ job_submission_id: UUID,
+ producer: LogProducer,
+ ) -> str:
+ return f"{project_name}/{run_name}/{job_submission_id}/{producer.value}"
diff --git a/src/dstack/_internal/server/services/logs/filelog.py b/src/dstack/_internal/server/services/logs/filelog.py
index 823222a409..e4289805c6 100644
--- a/src/dstack/_internal/server/services/logs/filelog.py
+++ b/src/dstack/_internal/server/services/logs/filelog.py
@@ -48,7 +48,7 @@ def _poll_logs_ascending(
) -> JobSubmissionLogs:
start_line = 0
if request.next_token:
- start_line = self._next_token(request)
+ start_line = self._parse_next_token(request.next_token)
logs = []
next_token = None
@@ -97,7 +97,9 @@ def _poll_logs_ascending(
def _poll_logs_descending(
self, log_file_path: Path, request: PollLogsRequest
) -> JobSubmissionLogs:
- start_offset = self._next_token(request)
+ start_offset = None
+ if request.next_token is not None:
+ start_offset = self._parse_next_token(request.next_token)
candidate_logs = []
@@ -123,12 +125,12 @@ def _poll_logs_descending(
except FileNotFoundError:
return JobSubmissionLogs(logs=[], next_token=None)
- logs = [log for log, offset in candidate_logs[: request.limit]]
+ logs = [log for log, _ in candidate_logs[: request.limit]]
next_token = None
if len(candidate_logs) > request.limit:
# We fetched one more than the limit, so there are more pages.
# The next token should point to the start of the last log we are returning.
- _last_log_event, last_log_offset = candidate_logs[request.limit - 1]
+ _, last_log_offset = candidate_logs[request.limit - 1]
next_token = str(last_log_offset)
return JobSubmissionLogs(logs=logs, next_token=next_token)
@@ -245,8 +247,7 @@ def _runner_log_event_to_log_event(self, runner_log_event: RunnerLogEvent) -> Lo
message=runner_log_event.message.decode(errors="replace"),
)
- def _next_token(self, request: PollLogsRequest) -> Optional[int]:
- next_token = request.next_token
+ def _parse_next_token(self, next_token: str) -> int:
if next_token is None:
return None
try:
diff --git a/src/dstack/_internal/server/services/logs/gcp.py b/src/dstack/_internal/server/services/logs/gcp.py
index 7faa727dc1..c1b1a75cf1 100644
--- a/src/dstack/_internal/server/services/logs/gcp.py
+++ b/src/dstack/_internal/server/services/logs/gcp.py
@@ -20,6 +20,9 @@
from dstack._internal.utils.common import batched
from dstack._internal.utils.logging import get_logger
+logger = get_logger(__name__)
+
+
GCP_LOGGING_AVAILABLE = True
try:
import google.api_core.exceptions
@@ -28,152 +31,151 @@
from google.cloud.logging_v2.types import ListLogEntriesRequest
except ImportError:
GCP_LOGGING_AVAILABLE = False
-
-
-logger = get_logger(__name__)
-
-
-class GCPLogStorage(LogStorage):
- # Max expected message size from runner is 32KB.
- # Max expected LogEntry size is 32KB + metadata < 50KB < 256KB limit.
- # With MAX_BATCH_SIZE = 100, max write request size < 5MB < 10 MB limit.
- # See: https://cloud.google.com/logging/quotas.
- MAX_RUNNER_MESSAGE_SIZE = 32 * 1024
- MAX_BATCH_SIZE = 100
-
- # Use the same log name for all run logs so that it's easy to manage all dstack-related logs.
- LOG_NAME = "dstack-run-logs"
- # Logs from different jobs belong to different "streams".
- # GCP Logging has no built-in concepts of streams, so we implement them with labels.
- # It should be fast to filter by labels since labels are indexed by default
- # (https://cloud.google.com/logging/docs/analyze/custom-index).
-
- def __init__(self, project_id: str):
- self.project_id = project_id
- try:
- self.client = logging_v2.Client(project=project_id)
- self.logger = self.client.logger(name=self.LOG_NAME)
- self.logger.list_entries(max_results=1)
- # Python client doesn't seem to support dry_run,
- # so emit an empty log to check permissions.
- self.logger.log_empty()
- except google.auth.exceptions.DefaultCredentialsError:
- raise LogStorageError("Default credentials not found")
- except google.api_core.exceptions.NotFound:
- raise LogStorageError(f"Project {project_id} not found")
- except google.api_core.exceptions.PermissionDenied:
- raise LogStorageError("Insufficient permissions")
-
- def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
- # TODO: GCP may return logs in random order when events have the same timestamp.
- producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
- stream_name = self._get_stream_name(
- project_name=project.name,
- run_name=request.run_name,
- job_submission_id=request.job_submission_id,
- producer=producer,
- )
- log_filters = [f'labels.stream = "{stream_name}"']
- if request.start_time:
- log_filters.append(f'timestamp > "{request.start_time.isoformat()}"')
- if request.end_time:
- log_filters.append(f'timestamp < "{request.end_time.isoformat()}"')
- log_filter = " AND ".join(log_filters)
-
- order_by = logging_v2.DESCENDING if request.descending else logging_v2.ASCENDING
- try:
- # Use low-level API to get access to next_page_token
- request_obj = ListLogEntriesRequest(
- resource_names=[f"projects/{self.client.project}"],
- filter=log_filter,
- order_by=order_by,
- page_size=request.limit,
- page_token=request.next_token,
- )
- response = self.client._logging_api._gapic_api.list_log_entries(request=request_obj)
-
- logs = [
- LogEvent(
- timestamp=entry.timestamp,
- message=entry.json_payload.get("message"),
- log_source=LogEventSource.STDOUT,
- )
- for entry in response.entries
- ]
- next_token = response.next_page_token or None
- except google.api_core.exceptions.ResourceExhausted as e:
- logger.warning("GCP Logging exception: %s", repr(e))
- # GCP Logging has severely low quota of 60 reads/min for entries.list
- raise ServerClientError(
- "GCP Logging read request limit exceeded."
- " It's recommended to increase default entries.list request quota from 60 per minute."
- )
- return JobSubmissionLogs(
- logs=logs,
- external_url=self._get_stream_extrnal_url(stream_name),
- next_token=next_token if len(logs) > 0 else None,
- )
-
- def write_logs(
- self,
- project: ProjectModel,
- run_name: str,
- job_submission_id: UUID,
- runner_logs: List[RunnerLogEvent],
- job_logs: List[RunnerLogEvent],
- ):
- producers_with_logs = [(LogProducer.RUNNER, runner_logs), (LogProducer.JOB, job_logs)]
- for producer, producer_logs in producers_with_logs:
+else:
+
+ class GCPLogStorage(LogStorage):
+ # Max expected message size from runner is 32KB.
+ # Max expected LogEntry size is 32KB + metadata < 50KB < 256KB limit.
+ # With MAX_BATCH_SIZE = 100, max write request size < 5MB < 10 MB limit.
+ # See: https://cloud.google.com/logging/quotas.
+ MAX_RUNNER_MESSAGE_SIZE = 32 * 1024
+ MAX_BATCH_SIZE = 100
+
+ # Use the same log name for all run logs so that it's easy to manage all dstack-related logs.
+ LOG_NAME = "dstack-run-logs"
+ # Logs from different jobs belong to different "streams".
+ # GCP Logging has no built-in concepts of streams, so we implement them with labels.
+ # It should be fast to filter by labels since labels are indexed by default
+ # (https://cloud.google.com/logging/docs/analyze/custom-index).
+
+ def __init__(self, project_id: str):
+ self.project_id = project_id
+ try:
+ self.client = logging_v2.Client(project=project_id)
+ self.logger = self.client.logger(name=self.LOG_NAME)
+ self.logger.list_entries(max_results=1)
+ # Python client doesn't seem to support dry_run,
+ # so emit an empty log to check permissions.
+ self.logger.log_empty()
+ except google.auth.exceptions.DefaultCredentialsError:
+ raise LogStorageError("Default credentials not found")
+ except google.api_core.exceptions.NotFound:
+ raise LogStorageError(f"Project {project_id} not found")
+ except google.api_core.exceptions.PermissionDenied:
+ raise LogStorageError("Insufficient permissions")
+
+ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
+ # TODO: GCP may return logs in random order when events have the same timestamp.
+ producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
stream_name = self._get_stream_name(
project_name=project.name,
- run_name=run_name,
- job_submission_id=job_submission_id,
+ run_name=request.run_name,
+ job_submission_id=request.job_submission_id,
producer=producer,
)
- self._write_logs_to_stream(
- stream_name=stream_name,
- logs=producer_logs,
+ log_filters = [f'labels.stream = "{stream_name}"']
+ if request.start_time:
+ log_filters.append(f'timestamp > "{request.start_time.isoformat()}"')
+ if request.end_time:
+ log_filters.append(f'timestamp < "{request.end_time.isoformat()}"')
+ log_filter = " AND ".join(log_filters)
+
+ order_by = logging_v2.DESCENDING if request.descending else logging_v2.ASCENDING
+ try:
+ # Use low-level API to get access to next_page_token
+ request_obj = ListLogEntriesRequest(
+ resource_names=[f"projects/{self.client.project}"],
+ filter=log_filter,
+ order_by=order_by,
+ page_size=request.limit,
+ page_token=request.next_token,
+ )
+ response = self.client._logging_api._gapic_api.list_log_entries( # type: ignore[attr-defined]
+ request=request_obj
+ )
+
+ logs = [
+ LogEvent(
+ timestamp=entry.timestamp,
+ message=entry.json_payload.get("message"),
+ log_source=LogEventSource.STDOUT,
+ )
+ for entry in response.entries
+ ]
+ next_token = response.next_page_token or None
+ except google.api_core.exceptions.ResourceExhausted as e:
+ logger.warning("GCP Logging exception: %s", repr(e))
+ # GCP Logging has severely low quota of 60 reads/min for entries.list
+ raise ServerClientError(
+ "GCP Logging read request limit exceeded."
+ " It's recommended to increase default entries.list request quota from 60 per minute."
+ )
+ return JobSubmissionLogs(
+ logs=logs,
+ external_url=self._get_stream_extrnal_url(stream_name),
+ next_token=next_token if len(logs) > 0 else None,
)
- def close(self):
- self.client.close()
-
- def _write_logs_to_stream(self, stream_name: str, logs: List[RunnerLogEvent]):
- with self.logger.batch() as batcher:
- for batch in batched(logs, self.MAX_BATCH_SIZE):
- for log in batch:
- message = log.message.decode(errors="replace")
- timestamp = unix_time_ms_to_datetime(log.timestamp)
- if len(log.message) > self.MAX_RUNNER_MESSAGE_SIZE:
- logger.error(
- "Stream %s: skipping event at %s, message exceeds max size: %d > %d",
- stream_name,
- timestamp.isoformat(),
- len(log.message),
- self.MAX_RUNNER_MESSAGE_SIZE,
+ def write_logs(
+ self,
+ project: ProjectModel,
+ run_name: str,
+ job_submission_id: UUID,
+ runner_logs: List[RunnerLogEvent],
+ job_logs: List[RunnerLogEvent],
+ ):
+ producers_with_logs = [(LogProducer.RUNNER, runner_logs), (LogProducer.JOB, job_logs)]
+ for producer, producer_logs in producers_with_logs:
+ stream_name = self._get_stream_name(
+ project_name=project.name,
+ run_name=run_name,
+ job_submission_id=job_submission_id,
+ producer=producer,
+ )
+ self._write_logs_to_stream(
+ stream_name=stream_name,
+ logs=producer_logs,
+ )
+
+ def close(self):
+ self.client.close()
+
+ def _write_logs_to_stream(self, stream_name: str, logs: List[RunnerLogEvent]):
+ with self.logger.batch() as batcher:
+ for batch in batched(logs, self.MAX_BATCH_SIZE):
+ for log in batch:
+ message = log.message.decode(errors="replace")
+ timestamp = unix_time_ms_to_datetime(log.timestamp)
+ if len(log.message) > self.MAX_RUNNER_MESSAGE_SIZE:
+ logger.error(
+ "Stream %s: skipping event at %s, message exceeds max size: %d > %d",
+ stream_name,
+ timestamp.isoformat(),
+ len(log.message),
+ self.MAX_RUNNER_MESSAGE_SIZE,
+ )
+ continue
+ batcher.log_struct(
+ {
+ "message": message,
+ },
+ labels={
+ "stream": stream_name,
+ },
+ timestamp=timestamp,
)
- continue
- batcher.log_struct(
- {
- "message": message,
- },
- labels={
- "stream": stream_name,
- },
- timestamp=timestamp,
- )
- batcher.commit()
+ batcher.commit()
- def _get_stream_name(
- self, project_name: str, run_name: str, job_submission_id: UUID, producer: LogProducer
- ) -> str:
- return f"{project_name}-{run_name}-{job_submission_id}-{producer.value}"
+ def _get_stream_name(
+ self, project_name: str, run_name: str, job_submission_id: UUID, producer: LogProducer
+ ) -> str:
+ return f"{project_name}-{run_name}-{job_submission_id}-{producer.value}"
- def _get_stream_extrnal_url(self, stream_name: str) -> str:
- log_name_resource_name = self._get_log_name_resource_name()
- query = f'logName="{log_name_resource_name}" AND labels.stream="{stream_name}"'
- quoted_query = urllib.parse.quote(query, safe="")
- return f"https://console.cloud.google.com/logs/query;query={quoted_query}?project={self.project_id}"
+ def _get_stream_extrnal_url(self, stream_name: str) -> str:
+ log_name_resource_name = self._get_log_name_resource_name()
+ query = f'logName="{log_name_resource_name}" AND labels.stream="{stream_name}"'
+ quoted_query = urllib.parse.quote(query, safe="")
+ return f"https://console.cloud.google.com/logs/query;query={quoted_query}?project={self.project_id}"
- def _get_log_name_resource_name(self) -> str:
- return f"projects/{self.project_id}/logs/{self.LOG_NAME}"
+ def _get_log_name_resource_name(self) -> str:
+ return f"projects/{self.project_id}/logs/{self.LOG_NAME}"
diff --git a/src/dstack/_internal/server/services/plugins.py b/src/dstack/_internal/server/services/plugins.py
index 8acd101f9c..933ed43052 100644
--- a/src/dstack/_internal/server/services/plugins.py
+++ b/src/dstack/_internal/server/services/plugins.py
@@ -60,7 +60,7 @@ def load_plugins(enabled_plugins: list[str]):
_PLUGINS.clear()
entrypoints: dict[str, PluginEntrypoint] = {}
plugins_to_load = enabled_plugins.copy()
- for entrypoint in entry_points(group="dstack.plugins"):
+ for entrypoint in entry_points(group="dstack.plugins"): # type: ignore[call-arg]
if entrypoint.name not in enabled_plugins:
logger.info(
("Found not enabled plugin %s. Plugin will not be loaded."),
diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py
index 992e1be046..2ec37523e4 100644
--- a/src/dstack/_internal/server/services/projects.py
+++ b/src/dstack/_internal/server/services/projects.py
@@ -19,7 +19,7 @@
from dstack._internal.server.schemas.projects import MemberSetting
from dstack._internal.server.services import users
from dstack._internal.server.services.backends import (
- get_backend_config_from_backend_model,
+ get_backend_config_without_creds_from_backend_model,
)
from dstack._internal.server.services.permissions import get_default_permissions
from dstack._internal.server.settings import DEFAULT_PROJECT_NAME
@@ -313,7 +313,6 @@ async def add_project_members(
member_num=None,
commit=False,
)
- member_by_user_id[user_to_add.id] = None
await session.commit()
@@ -544,9 +543,7 @@ def project_model_to_project(
b.type.value,
)
continue
- backend_config = get_backend_config_from_backend_model(
- configurator, b, include_creds=False
- )
+ backend_config = get_backend_config_without_creds_from_backend_model(configurator, b)
if isinstance(backend_config, DstackBackendConfig):
for backend_type in backend_config.base_backends:
backends.append(
diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py
index 8e12a6daeb..ae7ea19f8d 100644
--- a/src/dstack/_internal/server/services/proxy/repo.py
+++ b/src/dstack/_internal/server/services/proxy/repo.py
@@ -74,6 +74,8 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw(
job.job_provisioning_data
)
+ assert jpd.hostname is not None
+ assert jpd.ssh_port is not None
if not jpd.dockerized:
ssh_destination = f"{jpd.username}@{jpd.hostname}"
ssh_port = jpd.ssh_port
@@ -140,7 +142,7 @@ async def list_models(self, project_name: str) -> List[ChatModel]:
model_options_obj = service_spec.options.get("openai", {}).get("model")
if model_spec is None or model_options_obj is None:
continue
- model_options = pydantic.parse_obj_as(AnyModel, model_options_obj)
+ model_options = pydantic.parse_obj_as(AnyModel, model_options_obj) # type: ignore[arg-type]
model = ChatModel(
project_name=project_name,
name=model_spec.name,
@@ -175,6 +177,8 @@ def _model_options_to_format_spec(model: AnyModel) -> AnyModelFormat:
if model.format == "openai":
return OpenAIChatModelFormat(prefix=model.prefix)
elif model.format == "tgi":
+ assert model.chat_template is not None
+ assert model.eos_token is not None
return TGIChatModelFormat(
chat_template=model.chat_template,
eos_token=model.eos_token,
diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py
index 81d34a2ae3..1e4a2d2ad6 100644
--- a/src/dstack/_internal/server/services/runs.py
+++ b/src/dstack/_internal/server/services/runs.py
@@ -529,7 +529,7 @@ async def submit_run(
initial_status = RunStatus.PENDING
initial_replicas = 0
elif run_spec.configuration.type == "service":
- initial_replicas = run_spec.configuration.replicas.min
+ initial_replicas = run_spec.configuration.replicas.min or 0
run_model = RunModel(
id=uuid.uuid4(),
diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py
index aba2698ec2..a8089a93a9 100644
--- a/src/dstack/_internal/server/services/services/__init__.py
+++ b/src/dstack/_internal/server/services/services/__init__.py
@@ -5,7 +5,6 @@
import uuid
from datetime import datetime
from typing import Optional
-from urllib.parse import urlparse
import httpx
from sqlalchemy import select
@@ -73,6 +72,8 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec:
async def _register_service_in_gateway(
session: AsyncSession, run_model: RunModel, run_spec: RunSpec, gateway: GatewayModel
) -> ServiceSpec:
+ assert run_spec.configuration.type == "service"
+
if gateway.gateway_compute is None:
raise ServerClientError("Gateway has no instance associated with it")
@@ -100,6 +101,9 @@ async def _register_service_in_gateway(
model_url=f"{gateway_protocol}://gateway.{wildcard_domain}",
)
+ domain = service_spec.get_domain()
+ assert domain is not None
+
conn = await get_or_add_gateway_connection(session, gateway.id)
try:
logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url)
@@ -107,7 +111,7 @@ async def _register_service_in_gateway(
await client.register_service(
project=run_model.project.name,
run_name=run_model.run_name,
- domain=urlparse(service_spec.url).hostname,
+ domain=domain,
service_https=service_https,
gateway_https=gateway_https,
auth=run_spec.configuration.auth,
@@ -127,6 +131,7 @@ async def _register_service_in_gateway(
def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec:
+ assert run_spec.configuration.type == "service"
if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT:
# Note: if the user sets `https: `, it will be ignored silently
# TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted
@@ -270,6 +275,7 @@ async def unregister_replica(session: AsyncSession, job_model: JobModel):
def _get_service_https(run_spec: RunSpec, configuration: GatewayConfiguration) -> bool:
+ assert run_spec.configuration.type == "service"
if not run_spec.configuration.https:
return False
if configuration.certificate is not None and configuration.certificate.type == "acm":
diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py
index 47eabaab31..cd6d06e588 100644
--- a/src/dstack/_internal/server/services/services/autoscalers.py
+++ b/src/dstack/_internal/server/services/services/autoscalers.py
@@ -120,6 +120,8 @@ def get_desired_count(
def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler:
+ assert conf.replicas.min is not None
+ assert conf.replicas.max is not None
if conf.scaling is None:
return ManualScaler(
min_replicas=conf.replicas.min,
diff --git a/src/dstack/_internal/server/services/ssh.py b/src/dstack/_internal/server/services/ssh.py
index 2ab685eadb..a7967d8031 100644
--- a/src/dstack/_internal/server/services/ssh.py
+++ b/src/dstack/_internal/server/services/ssh.py
@@ -20,10 +20,11 @@ def container_ssh_tunnel(
"""
Build SSHTunnel for connecting to the container running the specified job.
"""
-
jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw(
job.job_provisioning_data
)
+ assert jpd.hostname is not None
+ assert jpd.ssh_port is not None
if not jpd.dockerized:
ssh_destination = f"{jpd.username}@{jpd.hostname}"
ssh_port = jpd.ssh_port
diff --git a/src/dstack/_internal/server/services/storage/__init__.py b/src/dstack/_internal/server/services/storage/__init__.py
index 14b75c3477..d76a5d4bee 100644
--- a/src/dstack/_internal/server/services/storage/__init__.py
+++ b/src/dstack/_internal/server/services/storage/__init__.py
@@ -1,9 +1,8 @@
from typing import Optional
from dstack._internal.server import settings
+from dstack._internal.server.services.storage import gcs, s3
from dstack._internal.server.services.storage.base import BaseStorage
-from dstack._internal.server.services.storage.gcs import GCS_AVAILABLE, GCSStorage
-from dstack._internal.server.services.storage.s3 import BOTO_AVAILABLE, S3Storage
_default_storage = None
@@ -20,16 +19,16 @@ def init_default_storage():
)
if settings.SERVER_S3_BUCKET:
- if not BOTO_AVAILABLE:
+ if not s3.BOTO_AVAILABLE:
raise ValueError("AWS dependencies are not installed")
- _default_storage = S3Storage(
+ _default_storage = s3.S3Storage(
bucket=settings.SERVER_S3_BUCKET,
region=settings.SERVER_S3_BUCKET_REGION,
)
elif settings.SERVER_GCS_BUCKET:
- if not GCS_AVAILABLE:
+ if not gcs.GCS_AVAILABLE:
raise ValueError("GCS dependencies are not installed")
- _default_storage = GCSStorage(
+ _default_storage = gcs.GCSStorage(
bucket=settings.SERVER_GCS_BUCKET,
)
diff --git a/src/dstack/_internal/server/services/storage/gcs.py b/src/dstack/_internal/server/services/storage/gcs.py
index 6c565625e2..a0f9ac568f 100644
--- a/src/dstack/_internal/server/services/storage/gcs.py
+++ b/src/dstack/_internal/server/services/storage/gcs.py
@@ -8,59 +8,59 @@
from google.cloud.exceptions import NotFound
except ImportError:
GCS_AVAILABLE = False
+else:
+ class GCSStorage(BaseStorage):
+ def __init__(
+ self,
+ bucket: str,
+ ):
+ self._client = storage.Client()
+ self._bucket = self._client.bucket(bucket)
-class GCSStorage(BaseStorage):
- def __init__(
- self,
- bucket: str,
- ):
- self._client = storage.Client()
- self._bucket = self._client.bucket(bucket)
+ def upload_code(
+ self,
+ project_id: str,
+ repo_id: str,
+ code_hash: str,
+ blob: bytes,
+ ):
+ key = self._get_code_key(project_id, repo_id, code_hash)
+ self._upload(key, blob)
- def upload_code(
- self,
- project_id: str,
- repo_id: str,
- code_hash: str,
- blob: bytes,
- ):
- key = self._get_code_key(project_id, repo_id, code_hash)
- self._upload(key, blob)
+ def get_code(
+ self,
+ project_id: str,
+ repo_id: str,
+ code_hash: str,
+ ) -> Optional[bytes]:
+ key = self._get_code_key(project_id, repo_id, code_hash)
+ return self._get(key)
- def get_code(
- self,
- project_id: str,
- repo_id: str,
- code_hash: str,
- ) -> Optional[bytes]:
- key = self._get_code_key(project_id, repo_id, code_hash)
- return self._get(key)
+ def upload_archive(
+ self,
+ user_id: str,
+ archive_hash: str,
+ blob: bytes,
+ ):
+ key = self._get_archive_key(user_id, archive_hash)
+ self._upload(key, blob)
- def upload_archive(
- self,
- user_id: str,
- archive_hash: str,
- blob: bytes,
- ):
- key = self._get_archive_key(user_id, archive_hash)
- self._upload(key, blob)
+ def get_archive(
+ self,
+ user_id: str,
+ archive_hash: str,
+ ) -> Optional[bytes]:
+ key = self._get_archive_key(user_id, archive_hash)
+ return self._get(key)
- def get_archive(
- self,
- user_id: str,
- archive_hash: str,
- ) -> Optional[bytes]:
- key = self._get_archive_key(user_id, archive_hash)
- return self._get(key)
+ def _upload(self, key: str, blob: bytes):
+ blob_obj = self._bucket.blob(key)
+ blob_obj.upload_from_string(blob)
- def _upload(self, key: str, blob: bytes):
- blob_obj = self._bucket.blob(key)
- blob_obj.upload_from_string(blob)
-
- def _get(self, key: str) -> Optional[bytes]:
- try:
- blob = self._bucket.blob(key)
- except NotFound:
- return None
- return blob.download_as_bytes()
+ def _get(self, key: str) -> Optional[bytes]:
+ try:
+ blob = self._bucket.blob(key)
+ except NotFound:
+ return None
+ return blob.download_as_bytes()
diff --git a/src/dstack/_internal/server/services/storage/s3.py b/src/dstack/_internal/server/services/storage/s3.py
index a0b993c731..df4b652d1d 100644
--- a/src/dstack/_internal/server/services/storage/s3.py
+++ b/src/dstack/_internal/server/services/storage/s3.py
@@ -8,62 +8,62 @@
from boto3 import Session
except ImportError:
BOTO_AVAILABLE = False
+else:
+ class S3Storage(BaseStorage):
+ def __init__(
+ self,
+ bucket: str,
+ region: Optional[str] = None,
+ ):
+ self._session = Session()
+ self._client = self._session.client("s3", region_name=region)
+ self.bucket = bucket
-class S3Storage(BaseStorage):
- def __init__(
- self,
- bucket: str,
- region: Optional[str] = None,
- ):
- self._session = Session()
- self._client = self._session.client("s3", region_name=region)
- self.bucket = bucket
+ def upload_code(
+ self,
+ project_id: str,
+ repo_id: str,
+ code_hash: str,
+ blob: bytes,
+ ):
+ key = self._get_code_key(project_id, repo_id, code_hash)
+ self._upload(key, blob)
- def upload_code(
- self,
- project_id: str,
- repo_id: str,
- code_hash: str,
- blob: bytes,
- ):
- key = self._get_code_key(project_id, repo_id, code_hash)
- self._upload(key, blob)
+ def get_code(
+ self,
+ project_id: str,
+ repo_id: str,
+ code_hash: str,
+ ) -> Optional[bytes]:
+ key = self._get_code_key(project_id, repo_id, code_hash)
+ return self._get(key)
- def get_code(
- self,
- project_id: str,
- repo_id: str,
- code_hash: str,
- ) -> Optional[bytes]:
- key = self._get_code_key(project_id, repo_id, code_hash)
- return self._get(key)
+ def upload_archive(
+ self,
+ user_id: str,
+ archive_hash: str,
+ blob: bytes,
+ ):
+ key = self._get_archive_key(user_id, archive_hash)
+ self._upload(key, blob)
- def upload_archive(
- self,
- user_id: str,
- archive_hash: str,
- blob: bytes,
- ):
- key = self._get_archive_key(user_id, archive_hash)
- self._upload(key, blob)
+ def get_archive(
+ self,
+ user_id: str,
+ archive_hash: str,
+ ) -> Optional[bytes]:
+ key = self._get_archive_key(user_id, archive_hash)
+ return self._get(key)
- def get_archive(
- self,
- user_id: str,
- archive_hash: str,
- ) -> Optional[bytes]:
- key = self._get_archive_key(user_id, archive_hash)
- return self._get(key)
+ def _upload(self, key: str, blob: bytes):
+ self._client.put_object(Bucket=self.bucket, Key=key, Body=blob)
- def _upload(self, key: str, blob: bytes):
- self._client.put_object(Bucket=self.bucket, Key=key, Body=blob)
-
- def _get(self, key: str) -> Optional[bytes]:
- try:
- response = self._client.get_object(Bucket=self.bucket, Key=key)
- except botocore.exceptions.ClientError as e:
- if e.response["Error"]["Code"] == "NoSuchKey":
- return None
- raise e
- return response["Body"].read()
+ def _get(self, key: str) -> Optional[bytes]:
+ try:
+ response = self._client.get_object(Bucket=self.bucket, Key=key)
+ except botocore.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "NoSuchKey":
+ return None
+ raise e
+ return response["Body"].read()
diff --git a/src/dstack/_internal/server/utils/logging.py b/src/dstack/_internal/server/utils/logging.py
index 1ea58578bf..03d7d05cb4 100644
--- a/src/dstack/_internal/server/utils/logging.py
+++ b/src/dstack/_internal/server/utils/logging.py
@@ -31,15 +31,15 @@ def configure_logging():
rename_fields={"name": "logger", "asctime": "timestamp", "levelname": "level"},
),
}
- handlers = {
+ handlers: dict[str, logging.Handler] = {
"rich": DstackRichHandler(console=console),
"standard": logging.StreamHandler(stream=sys.stdout),
"json": logging.StreamHandler(stream=sys.stdout),
}
if settings.LOG_FORMAT not in formatters:
raise ValueError(f"Invalid settings.LOG_FORMAT: {settings.LOG_FORMAT}")
- formatter = formatters.get(settings.LOG_FORMAT)
- handler = handlers.get(settings.LOG_FORMAT)
+ formatter = formatters[settings.LOG_FORMAT]
+ handler = handlers[settings.LOG_FORMAT]
handler.setFormatter(formatter)
handler.addFilter(AsyncioCancelledErrorFilter())
root_logger = logging.getLogger(None)
diff --git a/src/dstack/_internal/server/utils/provisioning.py b/src/dstack/_internal/server/utils/provisioning.py
index 94a5347343..b77efe7db4 100644
--- a/src/dstack/_internal/server/utils/provisioning.py
+++ b/src/dstack/_internal/server/utils/provisioning.py
@@ -312,10 +312,10 @@ def get_paramiko_connection(
with proxy_ctx as proxy_client, paramiko.SSHClient() as client:
proxy_channel: Optional[paramiko.Channel] = None
if proxy_client is not None:
+ transport = proxy_client.get_transport()
+ assert transport is not None
try:
- proxy_channel = proxy_client.get_transport().open_channel(
- "direct-tcpip", (host, port), ("", 0)
- )
+ proxy_channel = transport.open_channel("direct-tcpip", (host, port), ("", 0))
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError(f"Proxy channel failed: {e}") from e
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
diff --git a/src/dstack/_internal/utils/json_schema.py b/src/dstack/_internal/utils/json_schema.py
index 73ee643179..19bcd0bc62 100644
--- a/src/dstack/_internal/utils/json_schema.py
+++ b/src/dstack/_internal/utils/json_schema.py
@@ -3,7 +3,9 @@ def add_extra_schema_types(schema_property: dict, extra_types: list[dict]):
refs = [schema_property.pop("allOf")[0]]
elif "anyOf" in schema_property:
refs = schema_property.pop("anyOf")
- else:
+ elif "type" in schema_property:
refs = [{"type": schema_property.pop("type")}]
+ else:
+ refs = [{"$ref": schema_property.pop("$ref")}]
refs.extend(extra_types)
schema_property["anyOf"] = refs
diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py
index e1992068d0..473c462139 100644
--- a/src/dstack/api/_public/runs.py
+++ b/src/dstack/api/_public/runs.py
@@ -436,7 +436,7 @@ def get_run_plan(
) -> RunPlan:
"""
Get a run plan.
- Use this method to see the run plan before applying the cofiguration.
+ Use this method to see the run plan before applying the configuration.
Args:
configuration (Union[Task, Service, DevEnvironment]): The run configuration.
@@ -691,11 +691,11 @@ def get_plan(
spot_policy=spot_policy,
retry=None,
utilization_policy=utilization_policy,
- max_duration=max_duration,
- stop_duration=stop_duration,
+ max_duration=max_duration, # type: ignore[assignment]
+ stop_duration=stop_duration, # type: ignore[assignment]
max_price=max_price,
creation_policy=creation_policy,
- idle_duration=idle_duration,
+ idle_duration=idle_duration, # type: ignore[assignment]
)
run_spec = RunSpec(
run_name=run_name,
@@ -812,7 +812,6 @@ def _validate_configuration_files(
if configuration_path is not None:
base_dir = Path(configuration_path).expanduser().resolve().parent
for file_mapping in configuration.files:
- assert isinstance(file_mapping, FilePathMapping)
path = Path(file_mapping.local_path).expanduser()
if not path.is_absolute():
if base_dir is None:
diff --git a/src/dstack/api/huggingface/__init__.py b/src/dstack/api/huggingface/__init__.py
deleted file mode 100644
index 83e5491172..0000000000
--- a/src/dstack/api/huggingface/__init__.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from typing import Dict, Optional
-
-from dstack.api._public.huggingface.finetuning.sft import FineTuningTask
-
-
-class SFTFineTuningTask(FineTuningTask):
- def __init__(
- self,
- model_name: str,
- dataset_name: str,
- env: Dict[str, str],
- new_model_name: Optional[str] = None,
- report_to: Optional[str] = None,
- per_device_train_batch_size: int = 4,
- per_device_eval_batch_size: int = 4,
- gradient_accumulation_steps: int = 1,
- learning_rate: float = 2e-4,
- max_grad_norm: float = 0.3,
- weight_decay: float = 0.001,
- lora_alpha: int = 16,
- lora_dropout: float = 0.1,
- lora_r: int = 64,
- max_seq_length: Optional[int] = None,
- use_4bit: bool = True,
- use_nested_quant: bool = True,
- bnb_4bit_compute_dtype: str = "float16",
- bnb_4bit_quant_type: str = "nf4",
- num_train_epochs: float = 1,
- fp16: bool = False,
- bf16: bool = False,
- packing: bool = False,
- gradient_checkpointing: bool = True,
- optim: str = "paged_adamw_32bit",
- lr_scheduler_type: str = "constant",
- max_steps: int = -1,
- warmup_ratio: float = 0.03,
- group_by_length: bool = True,
- save_steps: int = 0,
- logging_steps: int = 25,
- ):
- super().__init__(
- model_name,
- dataset_name,
- new_model_name,
- env,
- report_to,
- per_device_train_batch_size,
- per_device_eval_batch_size,
- gradient_accumulation_steps,
- learning_rate,
- max_grad_norm,
- weight_decay,
- lora_alpha,
- lora_dropout,
- lora_r,
- max_seq_length,
- use_4bit,
- use_nested_quant,
- bnb_4bit_compute_dtype,
- bnb_4bit_quant_type,
- num_train_epochs,
- fp16,
- bf16,
- packing,
- gradient_checkpointing,
- optim,
- lr_scheduler_type,
- max_steps,
- warmup_ratio,
- group_by_length,
- save_steps,
- logging_steps,
- )
diff --git a/src/dstack/plugins/builtin/rest_plugin/_plugin.py b/src/dstack/plugins/builtin/rest_plugin/_plugin.py
index 1a094147ec..210dd50e19 100644
--- a/src/dstack/plugins/builtin/rest_plugin/_plugin.py
+++ b/src/dstack/plugins/builtin/rest_plugin/_plugin.py
@@ -86,6 +86,7 @@ def _on_apply(
spec: ApplySpec,
excludes: Optional[Dict] = None,
) -> ApplySpec:
+ spec_json = None
try:
spec_request = request_cls(user=user, project=project, spec=spec)
spec_json = self._call_plugin_service(spec_request, endpoint, excludes)