Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 67 additions & 43 deletions src/dstack/_internal/core/models/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import re
from enum import Enum
from typing import Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, Union

import orjson
from pydantic import Field
from pydantic_duality import DualBaseModel
from pydantic_duality import generate_dual_base_model
from typing_extensions import Annotated

from dstack._internal.utils.json_utils import pydantic_orjson_dumps
Expand All @@ -17,46 +17,73 @@
IncludeExcludeType = Union[IncludeExcludeSetType, IncludeExcludeDictType]


class CoreConfig:
json_loads = orjson.loads
json_dumps = pydantic_orjson_dumps


# All dstack models inherit from pydantic-duality's DualBaseModel.
# DualBaseModel creates two classes for the model:
# one with extra = "forbid" (CoreModel/CoreModel.__request__),
# and another with extra = "ignore" (CoreModel.__response__).
# This allows to use the same model both for a strict parsing of the user input and
# for a permissive parsing of the server responses.
class CoreModel(DualBaseModel):
class Config:
json_loads = orjson.loads
json_dumps = pydantic_orjson_dumps

def json(
self,
*,
include: Optional[IncludeExcludeType] = None,
exclude: Optional[IncludeExcludeType] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None, # ignore as it's deprecated
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True, # does not seems to be needed by dstack or dependencies
**dumps_kwargs: Any,
) -> str:
"""
Override `json()` method so that it calls `dict()`.
Allows changing how models are serialized by overriding `dict()` only.
By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place.
"""
data = self.dict(
by_alias=by_alias,
include=include,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
if self.__custom_root_type__:
data = data["__root__"]
return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)
# This allows to use the same model both for strict parsing of the user input and
# for permissive parsing of the server responses.
#
# We define a func to generate CoreModel dynamically that can be used
# to define custom Config for both __request__ and __response__ models.
# Note: Defining config in the model class directly overrides
# pydantic-duality's base config, breaking __response__.
def generate_dual_core_model(
custom_config: Union[type, Mapping],
) -> "type[CoreModel]":
class CoreModel(generate_dual_base_model(custom_config)):
def json(
self,
*,
include: Optional[IncludeExcludeType] = None,
exclude: Optional[IncludeExcludeType] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None, # ignore as it's deprecated
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True, # does not seems to be needed by dstack or dependencies
**dumps_kwargs: Any,
) -> str:
"""
Override `json()` method so that it calls `dict()`.
Allows changing how models are serialized by overriding `dict()` only.
By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place.
"""
data = self.dict(
by_alias=by_alias,
include=include,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
if self.__custom_root_type__:
data = data["__root__"]
return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)

return CoreModel


if TYPE_CHECKING:

class CoreModel(generate_dual_base_model(CoreConfig)):
pass
else:
CoreModel = generate_dual_core_model(CoreConfig)


class FrozenConfig(CoreConfig):
frozen = True


FrozenCoreModel = generate_dual_core_model(FrozenConfig)


class Duration(int):
Expand Down Expand Up @@ -93,7 +120,7 @@ def parse(cls, v: Union[int, str]) -> "Duration":
raise ValueError(f"Cannot parse the duration {v}")


class RegistryAuth(CoreModel):
class RegistryAuth(FrozenCoreModel):
"""
Credentials for pulling a private Docker image.

Expand All @@ -105,9 +132,6 @@ class RegistryAuth(CoreModel):
username: Annotated[str, Field(description="The username")]
password: Annotated[str, Field(description="The password or access token")]

class Config(CoreModel.Config):
frozen = True


class ApplyAction(str, Enum):
CREATE = "create" # resource is to be created or overridden
Expand Down
150 changes: 88 additions & 62 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,23 @@
from typing_extensions import Self

from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth
from dstack._internal.core.models.common import (
CoreConfig,
CoreModel,
Duration,
RegistryAuth,
generate_dual_core_model,
)
from dstack._internal.core.models.envs import Env
from dstack._internal.core.models.files import FilePathMapping
from dstack._internal.core.models.fleets import FleetConfiguration
from dstack._internal.core.models.gateways import GatewayConfiguration
from dstack._internal.core.models.profiles import ProfileParams, parse_duration, parse_off_duration
from dstack._internal.core.models.profiles import (
ProfileParams,
ProfileParamsConfig,
parse_duration,
parse_off_duration,
)
from dstack._internal.core.models.resources import Range, ResourcesSpec
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
Expand Down Expand Up @@ -276,7 +287,20 @@ class HTTPHeaderSpec(CoreModel):
]


class ProbeConfig(CoreModel):
class ProbeConfigConfig(CoreConfig):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
add_extra_schema_types(
schema["properties"]["timeout"],
extra_types=[{"type": "string"}],
)
add_extra_schema_types(
schema["properties"]["interval"],
extra_types=[{"type": "string"}],
)


class ProbeConfig(generate_dual_core_model(ProbeConfigConfig)):
type: Literal["http"] # expect other probe types in the future, namely `exec`
url: Annotated[
Optional[str], Field(description=f"The URL to request. Defaults to `{DEFAULT_PROBE_URL}`")
Expand Down Expand Up @@ -331,18 +355,6 @@ class ProbeConfig(CoreModel):
),
] = None

class Config(CoreModel.Config):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
add_extra_schema_types(
schema["properties"]["timeout"],
extra_types=[{"type": "string"}],
)
add_extra_schema_types(
schema["properties"]["interval"],
extra_types=[{"type": "string"}],
)

@validator("timeout", pre=True)
def parse_timeout(cls, v: Optional[Union[int, str]]) -> Optional[int]:
if v is None:
Expand Down Expand Up @@ -381,6 +393,19 @@ def validate_body_matches_method(cls, values):
return values


class BaseRunConfigurationConfig(CoreConfig):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
add_extra_schema_types(
schema["properties"]["volumes"]["items"],
extra_types=[{"type": "string"}],
)
add_extra_schema_types(
schema["properties"]["files"]["items"],
extra_types=[{"type": "string"}],
)


class BaseRunConfiguration(CoreModel):
type: Literal["none"]
name: Annotated[
Expand Down Expand Up @@ -484,18 +509,6 @@ class BaseRunConfiguration(CoreModel):
# deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init`
setup: CommandsList = []

class Config(CoreModel.Config):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
add_extra_schema_types(
schema["properties"]["volumes"]["items"],
extra_types=[{"type": "string"}],
)
add_extra_schema_types(
schema["properties"]["files"]["items"],
extra_types=[{"type": "string"}],
)

@validator("python", pre=True, always=True)
def convert_python(cls, v, values) -> Optional[PythonVersion]:
if v is not None and values.get("image"):
Expand Down Expand Up @@ -621,20 +634,25 @@ def parse_inactivity_duration(
return None


class DevEnvironmentConfigurationConfig(
ProfileParamsConfig,
BaseRunConfigurationConfig,
):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
ProfileParamsConfig.schema_extra(schema)
BaseRunConfigurationConfig.schema_extra(schema)


class DevEnvironmentConfiguration(
ProfileParams,
BaseRunConfiguration,
ConfigurationWithPortsParams,
DevEnvironmentConfigurationParams,
generate_dual_core_model(DevEnvironmentConfigurationConfig),
):
type: Literal["dev-environment"] = "dev-environment"

class Config(ProfileParams.Config, BaseRunConfiguration.Config):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
ProfileParams.Config.schema_extra(schema)
BaseRunConfiguration.Config.schema_extra(schema)

@validator("entrypoint")
def validate_entrypoint(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
Expand All @@ -646,20 +664,38 @@ class TaskConfigurationParams(CoreModel):
nodes: Annotated[int, Field(description="Number of nodes", ge=1)] = 1


class TaskConfigurationConfig(
ProfileParamsConfig,
BaseRunConfigurationConfig,
):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
ProfileParamsConfig.schema_extra(schema)
BaseRunConfigurationConfig.schema_extra(schema)


class TaskConfiguration(
ProfileParams,
BaseRunConfiguration,
ConfigurationWithCommandsParams,
ConfigurationWithPortsParams,
TaskConfigurationParams,
generate_dual_core_model(TaskConfigurationConfig),
):
type: Literal["task"] = "task"

class Config(ProfileParams.Config, BaseRunConfiguration.Config):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
ProfileParams.Config.schema_extra(schema)
BaseRunConfiguration.Config.schema_extra(schema)

class ServiceConfigurationParamsConfig(CoreConfig):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
add_extra_schema_types(
schema["properties"]["replicas"],
extra_types=[{"type": "integer"}, {"type": "string"}],
)
add_extra_schema_types(
schema["properties"]["model"],
extra_types=[{"type": "string"}],
)


class ServiceConfigurationParams(CoreModel):
Expand Down Expand Up @@ -719,18 +755,6 @@ class ServiceConfigurationParams(CoreModel):
Field(description="List of probes used to determine job health"),
] = []

class Config(CoreModel.Config):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
add_extra_schema_types(
schema["properties"]["replicas"],
extra_types=[{"type": "integer"}, {"type": "string"}],
)
add_extra_schema_types(
schema["properties"]["model"],
extra_types=[{"type": "string"}],
)

@validator("port")
def convert_port(cls, v) -> PortMapping:
if isinstance(v, int):
Expand Down Expand Up @@ -797,25 +821,27 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
return v


class ServiceConfigurationConfig(
ProfileParamsConfig,
BaseRunConfigurationConfig,
ServiceConfigurationParamsConfig,
):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
ProfileParamsConfig.schema_extra(schema)
BaseRunConfigurationConfig.schema_extra(schema)
ServiceConfigurationParamsConfig.schema_extra(schema)


class ServiceConfiguration(
ProfileParams,
BaseRunConfiguration,
ConfigurationWithCommandsParams,
ServiceConfigurationParams,
generate_dual_core_model(ServiceConfigurationConfig),
):
type: Literal["service"] = "service"

class Config(
ProfileParams.Config,
BaseRunConfiguration.Config,
ServiceConfigurationParams.Config,
):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
ProfileParams.Config.schema_extra(schema)
BaseRunConfiguration.Config.schema_extra(schema)
ServiceConfigurationParams.Config.schema_extra(schema)


AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]

Expand Down Expand Up @@ -876,7 +902,7 @@ class DstackConfiguration(CoreModel):
Field(discriminator="type"),
]

class Config(CoreModel.Config):
class Config(CoreConfig):
json_loads = orjson.loads
json_dumps = pydantic_orjson_dumps_with_indent

Expand Down
Loading
Loading