diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index 772da55274..6fcc6d0392 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -1,10 +1,10 @@ import re from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, Union import orjson from pydantic import Field -from pydantic_duality import DualBaseModel +from pydantic_duality import generate_dual_base_model from typing_extensions import Annotated from dstack._internal.utils.json_utils import pydantic_orjson_dumps @@ -17,46 +17,73 @@ IncludeExcludeType = Union[IncludeExcludeSetType, IncludeExcludeDictType] +class CoreConfig: + json_loads = orjson.loads + json_dumps = pydantic_orjson_dumps + + +# All dstack models inherit from pydantic-duality's DualBaseModel. # DualBaseModel creates two classes for the model: # one with extra = "forbid" (CoreModel/CoreModel.__request__), # and another with extra = "ignore" (CoreModel.__response__). -# This allows to use the same model both for a strict parsing of the user input and -# for a permissive parsing of the server responses. -class CoreModel(DualBaseModel): - class Config: - json_loads = orjson.loads - json_dumps = pydantic_orjson_dumps - - def json( - self, - *, - include: Optional[IncludeExcludeType] = None, - exclude: Optional[IncludeExcludeType] = None, - by_alias: bool = False, - skip_defaults: Optional[bool] = None, # ignore as it's deprecated - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - encoder: Optional[Callable[[Any], Any]] = None, - models_as_dict: bool = True, # does not seems to be needed by dstack or dependencies - **dumps_kwargs: Any, - ) -> str: - """ - Override `json()` method so that it calls `dict()`. - Allows changing how models are serialized by overriding `dict()` only. - By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place. - """ - data = self.dict( - by_alias=by_alias, - include=include, - exclude=exclude, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - if self.__custom_root_type__: - data = data["__root__"] - return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) +# This allows to use the same model both for strict parsing of the user input and +# for permissive parsing of the server responses. +# +# We define a func to generate CoreModel dynamically that can be used +# to define custom Config for both __request__ and __response__ models. +# Note: Defining config in the model class directly overrides +# pydantic-duality's base config, breaking __response__. +def generate_dual_core_model( + custom_config: Union[type, Mapping], +) -> "type[CoreModel]": + class CoreModel(generate_dual_base_model(custom_config)): + def json( + self, + *, + include: Optional[IncludeExcludeType] = None, + exclude: Optional[IncludeExcludeType] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, # ignore as it's deprecated + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + encoder: Optional[Callable[[Any], Any]] = None, + models_as_dict: bool = True, # does not seems to be needed by dstack or dependencies + **dumps_kwargs: Any, + ) -> str: + """ + Override `json()` method so that it calls `dict()`. + Allows changing how models are serialized by overriding `dict()` only. + By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place. + """ + data = self.dict( + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + if self.__custom_root_type__: + data = data["__root__"] + return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) + + return CoreModel + + +if TYPE_CHECKING: + + class CoreModel(generate_dual_base_model(CoreConfig)): + pass +else: + CoreModel = generate_dual_core_model(CoreConfig) + + +class FrozenConfig(CoreConfig): + frozen = True + + +FrozenCoreModel = generate_dual_core_model(FrozenConfig) class Duration(int): @@ -93,7 +120,7 @@ def parse(cls, v: Union[int, str]) -> "Duration": raise ValueError(f"Cannot parse the duration {v}") -class RegistryAuth(CoreModel): +class RegistryAuth(FrozenCoreModel): """ Credentials for pulling a private Docker image. @@ -105,9 +132,6 @@ class RegistryAuth(CoreModel): username: Annotated[str, Field(description="The username")] password: Annotated[str, Field(description="The password or access token")] - class Config(CoreModel.Config): - frozen = True - class ApplyAction(str, Enum): CREATE = "create" # resource is to be created or overridden diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 39baeb6ddf..6fe8132de9 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -10,12 +10,23 @@ from typing_extensions import Self from dstack._internal.core.errors import ConfigurationError -from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth +from dstack._internal.core.models.common import ( + CoreConfig, + CoreModel, + Duration, + RegistryAuth, + generate_dual_core_model, +) from dstack._internal.core.models.envs import Env from dstack._internal.core.models.files import FilePathMapping from dstack._internal.core.models.fleets import FleetConfiguration from dstack._internal.core.models.gateways import GatewayConfiguration -from dstack._internal.core.models.profiles import ProfileParams, parse_duration, parse_off_duration +from dstack._internal.core.models.profiles import ( + ProfileParams, + ProfileParamsConfig, + parse_duration, + parse_off_duration, +) from dstack._internal.core.models.resources import Range, ResourcesSpec from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser @@ -276,7 +287,20 @@ class HTTPHeaderSpec(CoreModel): ] -class ProbeConfig(CoreModel): +class ProbeConfigConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["timeout"], + extra_types=[{"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["interval"], + extra_types=[{"type": "string"}], + ) + + +class ProbeConfig(generate_dual_core_model(ProbeConfigConfig)): type: Literal["http"] # expect other probe types in the future, namely `exec` url: Annotated[ Optional[str], Field(description=f"The URL to request. Defaults to `{DEFAULT_PROBE_URL}`") @@ -331,18 +355,6 @@ class ProbeConfig(CoreModel): ), ] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["timeout"], - extra_types=[{"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["interval"], - extra_types=[{"type": "string"}], - ) - @validator("timeout", pre=True) def parse_timeout(cls, v: Optional[Union[int, str]]) -> Optional[int]: if v is None: @@ -381,6 +393,19 @@ def validate_body_matches_method(cls, values): return values +class BaseRunConfigurationConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["volumes"]["items"], + extra_types=[{"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["files"]["items"], + extra_types=[{"type": "string"}], + ) + + class BaseRunConfiguration(CoreModel): type: Literal["none"] name: Annotated[ @@ -484,18 +509,6 @@ class BaseRunConfiguration(CoreModel): # deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init` setup: CommandsList = [] - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["volumes"]["items"], - extra_types=[{"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["files"]["items"], - extra_types=[{"type": "string"}], - ) - @validator("python", pre=True, always=True) def convert_python(cls, v, values) -> Optional[PythonVersion]: if v is not None and values.get("image"): @@ -621,20 +634,25 @@ def parse_inactivity_duration( return None +class DevEnvironmentConfigurationConfig( + ProfileParamsConfig, + BaseRunConfigurationConfig, +): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + ProfileParamsConfig.schema_extra(schema) + BaseRunConfigurationConfig.schema_extra(schema) + + class DevEnvironmentConfiguration( ProfileParams, BaseRunConfiguration, ConfigurationWithPortsParams, DevEnvironmentConfigurationParams, + generate_dual_core_model(DevEnvironmentConfigurationConfig), ): type: Literal["dev-environment"] = "dev-environment" - class Config(ProfileParams.Config, BaseRunConfiguration.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - ProfileParams.Config.schema_extra(schema) - BaseRunConfiguration.Config.schema_extra(schema) - @validator("entrypoint") def validate_entrypoint(cls, v: Optional[str]) -> Optional[str]: if v is not None: @@ -646,20 +664,38 @@ class TaskConfigurationParams(CoreModel): nodes: Annotated[int, Field(description="Number of nodes", ge=1)] = 1 +class TaskConfigurationConfig( + ProfileParamsConfig, + BaseRunConfigurationConfig, +): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + ProfileParamsConfig.schema_extra(schema) + BaseRunConfigurationConfig.schema_extra(schema) + + class TaskConfiguration( ProfileParams, BaseRunConfiguration, ConfigurationWithCommandsParams, ConfigurationWithPortsParams, TaskConfigurationParams, + generate_dual_core_model(TaskConfigurationConfig), ): type: Literal["task"] = "task" - class Config(ProfileParams.Config, BaseRunConfiguration.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - ProfileParams.Config.schema_extra(schema) - BaseRunConfiguration.Config.schema_extra(schema) + +class ServiceConfigurationParamsConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["replicas"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["model"], + extra_types=[{"type": "string"}], + ) class ServiceConfigurationParams(CoreModel): @@ -719,18 +755,6 @@ class ServiceConfigurationParams(CoreModel): Field(description="List of probes used to determine job health"), ] = [] - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["replicas"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["model"], - extra_types=[{"type": "string"}], - ) - @validator("port") def convert_port(cls, v) -> PortMapping: if isinstance(v, int): @@ -797,25 +821,27 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: return v +class ServiceConfigurationConfig( + ProfileParamsConfig, + BaseRunConfigurationConfig, + ServiceConfigurationParamsConfig, +): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + ProfileParamsConfig.schema_extra(schema) + BaseRunConfigurationConfig.schema_extra(schema) + ServiceConfigurationParamsConfig.schema_extra(schema) + + class ServiceConfiguration( ProfileParams, BaseRunConfiguration, ConfigurationWithCommandsParams, ServiceConfigurationParams, + generate_dual_core_model(ServiceConfigurationConfig), ): type: Literal["service"] = "service" - class Config( - ProfileParams.Config, - BaseRunConfiguration.Config, - ServiceConfigurationParams.Config, - ): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - ProfileParams.Config.schema_extra(schema) - BaseRunConfiguration.Config.schema_extra(schema) - ServiceConfigurationParams.Config.schema_extra(schema) - AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration] @@ -876,7 +902,7 @@ class DstackConfiguration(CoreModel): Field(discriminator="type"), ] - class Config(CoreModel.Config): + class Config(CoreConfig): json_loads = orjson.loads json_dumps = pydantic_orjson_dumps_with_indent diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index 32a6761bb4..596d8ec821 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -2,13 +2,18 @@ import uuid from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Union from pydantic import Field, root_validator, validator from typing_extensions import Annotated, Literal from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import ApplyAction, CoreModel +from dstack._internal.core.models.common import ( + ApplyAction, + CoreConfig, + CoreModel, + generate_dual_core_model, +) from dstack._internal.core.models.envs import Env from dstack._internal.core.models.instances import Instance, InstanceOfferWithAvailability, SSHKey from dstack._internal.core.models.profiles import ( @@ -202,6 +207,21 @@ def _post_validate_ranges(cls, values): return values +class InstanceGroupParamsConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + del schema["properties"]["termination_policy"] + del schema["properties"]["termination_idle_time"] + add_extra_schema_types( + schema["properties"]["nodes"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["idle_duration"], + extra_types=[{"type": "string"}], + ) + + class InstanceGroupParams(CoreModel): env: Annotated[ Env, @@ -297,20 +317,6 @@ class InstanceGroupParams(CoreModel): termination_policy: Annotated[Optional[TerminationPolicy], Field(exclude=True)] = None termination_idle_time: Annotated[Optional[Union[str, int]], Field(exclude=True)] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any], model: Type): - del schema["properties"]["termination_policy"] - del schema["properties"]["termination_idle_time"] - add_extra_schema_types( - schema["properties"]["nodes"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["idle_duration"], - extra_types=[{"type": "string"}], - ) - @validator("nodes", pre=True) def parse_nodes(cls, v: Optional[Union[dict, str]]) -> Optional[dict]: if isinstance(v, str) and ".." in v: @@ -331,7 +337,17 @@ class FleetProps(CoreModel): name: Annotated[Optional[str], Field(description="The fleet name")] = None -class FleetConfiguration(InstanceGroupParams, FleetProps): +class FleetConfigurationConfig(InstanceGroupParamsConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + InstanceGroupParamsConfig.schema_extra(schema) + + +class FleetConfiguration( + InstanceGroupParams, + FleetProps, + generate_dual_core_model(FleetConfigurationConfig), +): tags: Annotated[ Optional[Dict[str, str]], Field( @@ -346,7 +362,14 @@ class FleetConfiguration(InstanceGroupParams, FleetProps): _validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator) -class FleetSpec(CoreModel): +class FleetSpecConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + prop = schema.get("properties", {}) + prop.pop("merged_profile", None) + + +class FleetSpec(generate_dual_core_model(FleetSpecConfig)): configuration: FleetConfiguration configuration_path: Optional[str] = None profile: Profile @@ -356,12 +379,6 @@ class FleetSpec(CoreModel): # TODO: make merged_profile a computed field after migrating to pydanticV2 merged_profile: Annotated[Profile, Field(exclude=True)] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any], model: Type) -> None: - prop = schema.get("properties", {}) - prop.pop("merged_profile", None) - @root_validator def _merged_profile(cls, values) -> Dict: try: diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index 81537edced..f1f802d54f 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -7,7 +7,10 @@ from pydantic import root_validator from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.common import ( + CoreModel, + FrozenCoreModel, +) from dstack._internal.core.models.envs import Env from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.volumes import Volume @@ -117,14 +120,11 @@ class InstanceType(CoreModel): resources: Resources -class SSHConnectionParams(CoreModel): +class SSHConnectionParams(FrozenCoreModel): hostname: str username: str port: int - class Config(CoreModel.Config): - frozen = True - class SSHKey(CoreModel): public: str diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index e63a67557e..9097abed4d 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -6,7 +6,12 @@ from typing_extensions import Annotated, Literal from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import CoreModel, Duration +from dstack._internal.core.models.common import ( + CoreConfig, + CoreModel, + Duration, + generate_dual_core_model, +) from dstack._internal.utils.common import list_enum_values_for_annotation from dstack._internal.utils.cron import validate_cron from dstack._internal.utils.json_schema import add_extra_schema_types @@ -112,7 +117,16 @@ class RetryEvent(str, Enum): ERROR = "error" -class ProfileRetry(CoreModel): +class ProfileRetryConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["duration"], + extra_types=[{"type": "string"}], + ) + + +class ProfileRetry(generate_dual_core_model(ProfileRetryConfig)): on_events: Annotated[ Optional[List[RetryEvent]], Field( @@ -128,14 +142,6 @@ class ProfileRetry(CoreModel): Field(description="The maximum period of retrying the run, e.g., `4h` or `1d`"), ] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["duration"], - extra_types=[{"type": "string"}], - ) - _validate_duration = validator("duration", pre=True, allow_reuse=True)(parse_duration) @root_validator @@ -146,7 +152,16 @@ def _validate_fields(cls, values): return values -class UtilizationPolicy(CoreModel): +class UtilizationPolicyConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["time_window"], + extra_types=[{"type": "string"}], + ) + + +class UtilizationPolicy(generate_dual_core_model(UtilizationPolicyConfig)): _min_time_window = "5m" min_gpu_utilization: Annotated[ @@ -171,14 +186,6 @@ class UtilizationPolicy(CoreModel): ), ] - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["time_window"], - extra_types=[{"type": "string"}], - ) - @validator("time_window", pre=True) def validate_time_window(cls, v: Union[int, str]) -> int: v = parse_duration(v) @@ -219,6 +226,28 @@ def crons(self) -> List[str]: return self.cron +class ProfileParamsConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + del schema["properties"]["pool_name"] + del schema["properties"]["instance_name"] + del schema["properties"]["retry_policy"] + del schema["properties"]["termination_policy"] + del schema["properties"]["termination_idle_time"] + add_extra_schema_types( + schema["properties"]["max_duration"], + extra_types=[{"type": "boolean"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["stop_duration"], + extra_types=[{"type": "boolean"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["idle_duration"], + extra_types=[{"type": "string"}], + ) + + class ProfileParams(CoreModel): backends: Annotated[ Optional[List[BackendType]], @@ -358,27 +387,6 @@ class ProfileParams(CoreModel): termination_policy: Annotated[Optional[TerminationPolicy], Field(exclude=True)] = None termination_idle_time: Annotated[Optional[Union[str, int]], Field(exclude=True)] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]) -> None: - del schema["properties"]["pool_name"] - del schema["properties"]["instance_name"] - del schema["properties"]["retry_policy"] - del schema["properties"]["termination_policy"] - del schema["properties"]["termination_idle_time"] - add_extra_schema_types( - schema["properties"]["max_duration"], - extra_types=[{"type": "boolean"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["stop_duration"], - extra_types=[{"type": "boolean"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["idle_duration"], - extra_types=[{"type": "string"}], - ) - _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration ) @@ -403,17 +411,28 @@ class ProfileProps(CoreModel): ] = False -class Profile(ProfileProps, ProfileParams): +class ProfileConfig(ProfileParamsConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + ProfileParamsConfig.schema_extra(schema) + + +class Profile( + ProfileProps, + ProfileParams, + generate_dual_core_model(ProfileConfig), +): pass -class ProfilesConfig(CoreModel): - profiles: List[Profile] +class ProfilesConfigConfig(CoreConfig): + json_loads = orjson.loads + json_dumps = pydantic_orjson_dumps_with_indent + schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"} - class Config(CoreModel.Config): - json_loads = orjson.loads - json_dumps = pydantic_orjson_dumps_with_indent - schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"} + +class ProfilesConfig(generate_dual_core_model(ProfilesConfigConfig)): + profiles: List[Profile] def default(self) -> Optional[Profile]: for p in self.profiles: diff --git a/src/dstack/_internal/core/models/repos/remote.py b/src/dstack/_internal/core/models/repos/remote.py index 366767fe79..54002f8c6d 100644 --- a/src/dstack/_internal/core/models/repos/remote.py +++ b/src/dstack/_internal/core/models/repos/remote.py @@ -11,7 +11,7 @@ from typing_extensions import Literal from dstack._internal.core.errors import DstackError -from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.common import CoreConfig, generate_dual_core_model from dstack._internal.core.models.repos.base import BaseRepoInfo, Repo from dstack._internal.utils.hash import get_sha256, slugify from dstack._internal.utils.path import PathLike @@ -24,21 +24,33 @@ class RepoError(DstackError): pass -class RemoteRepoCreds(CoreModel): +class RemoteRepoCredsConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + del schema["properties"]["protocol"] + + +class RemoteRepoCreds(generate_dual_core_model(RemoteRepoCredsConfig)): clone_url: str - private_key: Optional[str] - oauth_token: Optional[str] + private_key: Optional[str] = None + oauth_token: Optional[str] = None # TODO: remove in 0.20. Left for compatibility with CLI <=0.18.44 protocol: Annotated[Optional[str], Field(exclude=True)] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]) -> None: - del schema["properties"]["protocol"] +class RemoteRepoInfoConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + del schema["properties"]["repo_host_name"] + del schema["properties"]["repo_port"] + del schema["properties"]["repo_user_name"] -class RemoteRepoInfo(BaseRepoInfo): + +class RemoteRepoInfo( + BaseRepoInfo, + generate_dual_core_model(RemoteRepoInfoConfig), +): repo_type: Literal["remote"] = "remote" repo_name: str @@ -47,13 +59,6 @@ class RemoteRepoInfo(BaseRepoInfo): repo_port: Annotated[Optional[int], Field(exclude=True)] = None repo_user_name: Annotated[Optional[str], Field(exclude=True)] = None - class Config(BaseRepoInfo.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]) -> None: - del schema["properties"]["repo_host_name"] - del schema["properties"]["repo_port"] - del schema["properties"]["repo_user_name"] - class RemoteRunRepoData(RemoteRepoInfo): repo_branch: Optional[str] = None diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py index 1aedb85c27..20b4f3aa55 100644 --- a/src/dstack/_internal/core/models/resources.py +++ b/src/dstack/_internal/core/models/resources.py @@ -7,7 +7,7 @@ from pydantic.generics import GenericModel from typing_extensions import Annotated -from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.common import CoreConfig, CoreModel, generate_dual_core_model from dstack._internal.utils.common import pretty_resources from dstack._internal.utils.json_schema import add_extra_schema_types from dstack._internal.utils.logging import get_logger @@ -129,21 +129,22 @@ def __str__(self): DEFAULT_GPU_COUNT = Range[int](min=1) -class CPUSpec(CoreModel): +class CPUSpecConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["count"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + + +class CPUSpec(generate_dual_core_model(CPUSpecConfig)): arch: Annotated[ Optional[gpuhunt.CPUArchitecture], Field(description="The CPU architecture, one of: `x86`, `arm`"), ] = None count: Annotated[Range[int], Field(description="The number of CPU cores")] = DEFAULT_CPU_COUNT - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["count"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - @classmethod def __get_validators__(cls): yield cls.parse @@ -190,7 +191,28 @@ def _validate_arch(cls, v: Any) -> Any: return v -class GPUSpec(CoreModel): +class GPUSpecConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["count"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["name"], + extra_types=[{"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["memory"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["total_memory"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + + +class GPUSpec(generate_dual_core_model(GPUSpecConfig)): vendor: Annotated[ Optional[gpuhunt.AcceleratorVendor], Field( @@ -218,26 +240,6 @@ class GPUSpec(CoreModel): Field(description="The minimum compute capability of the GPU (e.g., `7.5`)"), ] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["count"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["name"], - extra_types=[{"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["memory"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["total_memory"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - @classmethod def __get_validators__(cls): yield cls.parse @@ -317,16 +319,17 @@ def _vendor_from_string(cls, v: str) -> gpuhunt.AcceleratorVendor: return gpuhunt.AcceleratorVendor.cast(v) -class DiskSpec(CoreModel): - size: Annotated[Range[Memory], Field(description="Disk size")] +class DiskSpecConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["size"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["size"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) + +class DiskSpec(generate_dual_core_model(DiskSpecConfig)): + size: Annotated[Range[Memory], Field(description="Disk size")] @classmethod def __get_validators__(cls): @@ -343,7 +346,32 @@ def _parse(cls, v: Any) -> Any: DEFAULT_DISK = DiskSpec(size=Range[Memory](min=Memory.parse("100GB"), max=None)) -class ResourcesSpec(CoreModel): +class ResourcesSpecConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["cpu"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["memory"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["shm_size"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["gpu"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["disk"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + + +class ResourcesSpec(generate_dual_core_model(ResourcesSpecConfig)): # TODO: Remove Range[int] in 0.20. Range[int] for backward compatibility only. cpu: Annotated[Union[CPUSpec, Range[int]], Field(description="The CPU requirements")] = ( CPUSpec() @@ -362,30 +390,6 @@ class ResourcesSpec(CoreModel): gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["cpu"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["memory"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["shm_size"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["gpu"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["disk"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - def pretty_format(self) -> str: # TODO: Remove in 0.20. Use self.cpu directly cpu = parse_obj_as(CPUSpec, self.cpu) diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 0f0a87c13a..969b336b9d 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -1,13 +1,20 @@ from datetime import datetime, timedelta from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Type +from typing import Any, Dict, List, Literal, Optional from urllib.parse import urlparse from pydantic import UUID4, Field, root_validator from typing_extensions import Annotated from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import ApplyAction, CoreModel, NetworkMode, RegistryAuth +from dstack._internal.core.models.common import ( + ApplyAction, + CoreConfig, + CoreModel, + NetworkMode, + RegistryAuth, + generate_dual_core_model, +) from dstack._internal.core.models.configurations import ( DEFAULT_PROBE_METHOD, LEGACY_REPO_DIR, @@ -385,7 +392,14 @@ class Job(CoreModel): job_submissions: List[JobSubmission] -class RunSpec(CoreModel): +class RunSpecConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + prop = schema.get("properties", {}) + prop.pop("merged_profile", None) + + +class RunSpec(generate_dual_core_model(RunSpecConfig)): # TODO: run_name, working_dir are redundant here since they already passed in configuration run_name: Annotated[ Optional[str], @@ -458,12 +472,6 @@ class RunSpec(CoreModel): # TODO: make merged_profile a computed field after migrating to pydanticV2 merged_profile: Annotated[Profile, Field(exclude=True)] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any], model: Type) -> None: - prop = schema.get("properties", {}) - prop.pop("merged_profile", None) - @root_validator def _merged_profile(cls, values) -> Dict: if values.get("profile") is None: diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 1243a4ae1d..b21ba81a4d 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -24,7 +24,7 @@ from dstack._internal.core.errors import DstackError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.common import CoreConfig, generate_dual_core_model from dstack._internal.core.models.fleets import FleetStatus from dstack._internal.core.models.gateways import GatewayStatus from dstack._internal.core.models.health import HealthStatus @@ -71,7 +71,11 @@ def process_result_value(self, value, dialect): return value.replace(tzinfo=timezone.utc) -class DecryptedString(CoreModel): +class DecryptedStringConfig(CoreConfig): + arbitrary_types_allowed = True + + +class DecryptedString(generate_dual_core_model(DecryptedStringConfig)): """ A type for representing plaintext strings encrypted with `EncryptedString`. Besides the string, stores information if the decryption was successful. @@ -84,9 +88,6 @@ class DecryptedString(CoreModel): decrypted: bool = True exc: Optional[Exception] = None - class Config(CoreModel.Config): - arbitrary_types_allowed = True - def get_plaintext_or_error(self) -> str: if self.decrypted and self.plaintext is not None: return self.plaintext diff --git a/src/dstack/_internal/server/schemas/gateways.py b/src/dstack/_internal/server/schemas/gateways.py index c4d7ebcb77..ffad6e78ab 100644 --- a/src/dstack/_internal/server/schemas/gateways.py +++ b/src/dstack/_internal/server/schemas/gateways.py @@ -3,24 +3,25 @@ from pydantic import Field from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.common import CoreConfig, CoreModel, generate_dual_core_model from dstack._internal.core.models.gateways import GatewayConfiguration -class CreateGatewayRequest(CoreModel): +class CreateGatewayRequestConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + del schema["properties"]["name"] + del schema["properties"]["backend_type"] + del schema["properties"]["region"] + + +class CreateGatewayRequest(generate_dual_core_model(CreateGatewayRequestConfig)): configuration: GatewayConfiguration # Deprecated and unused. Left for compatibility with 0.18 clients. name: Annotated[Optional[str], Field(exclude=True)] = None backend_type: Annotated[Optional[BackendType], Field(exclude=True)] = None region: Annotated[Optional[str], Field(exclude=True)] = None - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]) -> None: - del schema["properties"]["name"] - del schema["properties"]["backend_type"] - del schema["properties"]["region"] - class GetGatewayRequest(CoreModel): name: str diff --git a/src/dstack/_internal/server/services/docker.py b/src/dstack/_internal/server/services/docker.py index 7181edc7d3..74d9c391b1 100644 --- a/src/dstack/_internal/server/services/docker.py +++ b/src/dstack/_internal/server/services/docker.py @@ -9,7 +9,11 @@ from typing_extensions import Annotated from dstack._internal.core.errors import DockerRegistryError -from dstack._internal.core.models.common import CoreModel, RegistryAuth +from dstack._internal.core.models.common import ( + CoreModel, + FrozenCoreModel, + RegistryAuth, +) from dstack._internal.server.utils.common import join_byte_stream_checked from dstack._internal.utils.dxf import PatchedDXF @@ -31,15 +35,12 @@ def __call__(self, dxf: DXF, response: requests.Response) -> None: ) -class DockerImage(CoreModel): +class DockerImage(FrozenCoreModel): image: str - registry: Optional[str] + registry: Optional[str] = None repo: str tag: str - digest: Optional[str] - - class Config(CoreModel.Config): - frozen = True + digest: Optional[str] = None class ImageConfig(CoreModel):