Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ jobs:
run: uv sync
- name: Generate json schema
run: |
uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json
uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json
uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json())" > configuration.json
uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json())" > profiles.json
- name: Upload json schema to S3
run: |
VERSION=$((${{ github.run_number }} + ${{ env.BUILD_INCREMENT }}))
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ jobs:
run: uv sync
- name: Generate json schema
run: |
uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json
uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json
uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json())" > configuration.json
uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json())" > profiles.json
- name: Upload json schema to S3
run: |
VERSION=${GITHUB_REF#refs/tags/}
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"gpuhunt==0.1.6",
"argcomplete>=3.5.0",
"ignore-python>=0.2.0",
"orjson",
]

[project.urls]
Expand Down
40 changes: 38 additions & 2 deletions src/dstack/_internal/core/models/common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import re
from enum import Enum
from typing import Union
from typing import Any, Callable, Optional, Union

import orjson
from pydantic import Field
from pydantic_duality import DualBaseModel
from typing_extensions import Annotated

from dstack._internal.utils.json_utils import pydantic_orjson_dumps

IncludeExcludeFieldType = Union[int, str]
IncludeExcludeSetType = set[IncludeExcludeFieldType]
IncludeExcludeDictType = dict[
Expand All @@ -20,7 +23,40 @@
# This allows to use the same model both for a strict parsing of the user input and
# for a permissive parsing of the server responses.
class CoreModel(DualBaseModel):
pass
class Config:
    # Use orjson for model (de)serialization instead of pydantic v1's
    # stdlib-json default; `pydantic_orjson_dumps` is the project wrapper
    # that adapts orjson.dumps (returns bytes) to pydantic's str contract.
    # NOTE(review): assumes pydantic v1 `Config.json_loads`/`json_dumps`
    # hooks — confirm these are honored under pydantic-duality.
    json_loads = orjson.loads
    json_dumps = pydantic_orjson_dumps

def json(
    self,
    *,
    include: Optional[IncludeExcludeType] = None,
    exclude: Optional[IncludeExcludeType] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,  # accepted but ignored — deprecated in pydantic v1
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    encoder: Optional[Callable[[Any], Any]] = None,
    models_as_dict: bool = True,  # accepted but ignored — does not seem to be needed by dstack or dependencies
    **dumps_kwargs: Any,
) -> str:
    """
    Override `json()` method so that it calls `dict()`.
    Allows changing how models are serialized by overriding `dict()` only.
    By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place.

    The signature mirrors pydantic v1 `BaseModel.json()` so existing callers
    keep working; filtering kwargs are forwarded to `self.dict()`, and the
    resulting plain data is dumped with `Config.json_dumps` (orjson-backed),
    passing `encoder` through as the `default=` fallback serializer.
    """
    data = self.dict(
        by_alias=by_alias,
        include=include,
        exclude=exclude,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
    )
    # NOTE(review): `__custom_root_type__` and `__config__` are pydantic v1
    # internals; unwrapping "__root__" here matches what BaseModel.json()
    # does for custom-root models — confirm on pydantic upgrades.
    if self.__custom_root_type__:
        data = data["__root__"]
    return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)


class Duration(int):
Expand Down
7 changes: 7 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import PurePosixPath
from typing import Any, Dict, List, Optional, Union

import orjson
from pydantic import Field, ValidationError, conint, constr, root_validator, validator
from typing_extensions import Annotated, Literal

Expand All @@ -18,6 +19,9 @@
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
from dstack._internal.utils.json_utils import (
pydantic_orjson_dumps_with_indent,
)

CommandsList = List[str]
ValidPort = conint(gt=0, le=65536)
Expand Down Expand Up @@ -573,6 +577,9 @@ class DstackConfiguration(CoreModel):
]

class Config:
json_loads = orjson.loads
json_dumps = pydantic_orjson_dumps_with_indent

@staticmethod
def schema_extra(schema: Dict[str, Any]):
schema["$schema"] = "http://json-schema.org/draft-07/schema#"
Expand Down
5 changes: 5 additions & 0 deletions src/dstack/_internal/core/models/profiles.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union, overload

import orjson
from pydantic import Field, root_validator, validator
from typing_extensions import Annotated, Literal

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel, Duration
from dstack._internal.utils.common import list_enum_values_for_annotation
from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent
from dstack._internal.utils.tags import tags_validator

DEFAULT_RETRY_DURATION = 3600
Expand Down Expand Up @@ -343,6 +345,9 @@ class ProfilesConfig(CoreModel):
profiles: List[Profile]

class Config:
    # Parse with orjson; dump with the indenting orjson wrapper so the
    # generated profiles.json schema stays human-readable (overrides the
    # compact dumper configured on CoreModel.Config).
    json_loads = orjson.loads
    json_dumps = pydantic_orjson_dumps_with_indent

    # Stamp the generated JSON schema as draft-07 for editor/IDE validation.
    schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"}

def default(self) -> Optional[Profile]:
Expand Down
23 changes: 15 additions & 8 deletions src/dstack/_internal/core/models/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,6 @@ def schema_extra(schema: Dict[str, Any]):
gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None
disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK

# TODO: Remove in 0.20. Added for backward compatibility.
@root_validator
def _post_validate(cls, values):
cpu = values.get("cpu")
if isinstance(cpu, CPUSpec) and cpu.arch in [None, gpuhunt.CPUArchitecture.X86]:
values["cpu"] = cpu.count
return values

def pretty_format(self) -> str:
# TODO: Remove in 0.20. Use self.cpu directly
cpu = parse_obj_as(CPUSpec, self.cpu)
Expand All @@ -407,3 +399,18 @@ def pretty_format(self) -> str:
resources.update(disk_size=self.disk.size)
res = pretty_resources(**resources)
return res

def dict(self, *args, **kwargs) -> Dict:
    """Serialize the resources spec, then rewrite the ``cpu`` entry
    into its backward-compatible form via ``_update_serialized_cpu``.
    """
    # pydantic-duality breaks super(), so the parent implementation is
    # invoked explicitly on CoreModel.
    serialized = CoreModel.dict(self, *args, **kwargs)
    self._update_serialized_cpu(serialized)
    return serialized

# TODO: Remove in 0.20. Added for backward compatibility.
def _update_serialized_cpu(self, values: Dict):
cpu = values["cpu"]
if cpu:
arch = cpu.get("arch")
count = cpu.get("count")
if count and arch in [None, gpuhunt.CPUArchitecture.X86.value]:
values["cpu"] = count
160 changes: 69 additions & 91 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,56 +325,45 @@ def duration(self) -> timedelta:
end_time = self.finished_at
return end_time - self.submitted_at

@root_validator
def _status_message(cls, values) -> Dict:
try:
status = values["status"]
termination_reason = values["termination_reason"]
exit_code = values["exit_status"]
except KeyError:
return values
values["status_message"] = JobSubmission._get_status_message(
status=status,
termination_reason=termination_reason,
exit_status=exit_code,
)
return values
def dict(self, *args, **kwargs) -> Dict:
    """Serialize the submission, injecting the derived presentation
    fields ``status_message`` and ``error``.
    """
    # Derived values are computed before serialization, preserving the
    # original evaluation order.
    derived = {
        "status_message": self._get_status_message(),
        "error": self._get_error(),
    }
    # super() does not work with pydantic-duality, so the base
    # implementation is called explicitly.
    serialized = CoreModel.dict(self, *args, **kwargs)
    serialized.update(derived)
    return serialized

@staticmethod
def _get_status_message(
status: JobStatus,
termination_reason: Optional[JobTerminationReason],
exit_status: Optional[int],
) -> str:
if status == JobStatus.DONE:
def _get_status_message(self) -> Optional[str]:
if self.status == JobStatus.DONE:
return "exited (0)"
elif status == JobStatus.FAILED:
if termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
return f"exited ({exit_status})"
elif termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY:
elif self.status == JobStatus.FAILED:
if self.termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
return f"exited ({self.exit_status})"
elif (
self.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
):
return "no offers"
elif termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
elif self.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
return "interrupted"
else:
return "error"
elif status == JobStatus.TERMINATED:
if termination_reason == JobTerminationReason.TERMINATED_BY_USER:
elif self.status == JobStatus.TERMINATED:
if self.termination_reason == JobTerminationReason.TERMINATED_BY_USER:
return "stopped"
elif termination_reason == JobTerminationReason.ABORTED_BY_USER:
elif self.termination_reason == JobTerminationReason.ABORTED_BY_USER:
return "aborted"
return status.value
return self.status.value

@root_validator
def _error(cls, values) -> Dict:
try:
termination_reason = values["termination_reason"]
except KeyError:
return values
values["error"] = JobSubmission._get_error(termination_reason=termination_reason)
return values
def _get_error(self) -> Optional[str]:
    """Map this submission's termination reason to a short human-readable error."""
    return JobSubmission._termination_reason_to_error(self.termination_reason)

@staticmethod
def _get_error(termination_reason: Optional[JobTerminationReason]) -> Optional[str]:
def _termination_reason_to_error(
termination_reason: Optional[JobTerminationReason],
) -> Optional[str]:
error_mapping = {
JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
Expand All @@ -395,6 +384,12 @@ class Job(CoreModel):
job_spec: JobSpec
job_submissions: List[JobSubmission]

def get_last_termination_reason(self) -> Optional[JobTerminationReason]:
for submission in reversed(self.job_submissions):
if submission.termination_reason is not None:
return submission.termination_reason
return None


class RunSpec(CoreModel):
# TODO: run_name, working_dir are redundant here since they already passed in configuration
Expand Down Expand Up @@ -525,87 +520,70 @@ class Run(CoreModel):
last_processed_at: datetime
status: RunStatus
status_message: Optional[str] = None
termination_reason: Optional[RunTerminationReason]
termination_reason: Optional[RunTerminationReason] = None
run_spec: RunSpec
jobs: List[Job]
latest_job_submission: Optional[JobSubmission]
latest_job_submission: Optional[JobSubmission] = None
cost: float = 0
service: Optional[ServiceSpec] = None
deployment_num: int = 0 # default for compatibility with pre-0.19.14 servers
# TODO: make error a computed field after migrating to pydanticV2
error: Optional[str] = None
deleted: Optional[bool] = None

@root_validator
def _error(cls, values) -> Dict:
try:
termination_reason = values["termination_reason"]
except KeyError:
return values
values["error"] = Run._get_error(termination_reason=termination_reason)
return values
def dict(self, *args, **kwargs) -> Dict:
    """Serialize the run, injecting the derived presentation fields
    ``status_message`` and ``error``.
    """
    # Derived values are computed before serialization, preserving the
    # original evaluation order.
    derived = {
        "status_message": self._get_status_message(),
        "error": self._get_error(),
    }
    # super() does not work with pydantic-duality, so the base
    # implementation is called explicitly.
    serialized = CoreModel.dict(self, *args, **kwargs)
    serialized.update(derived)
    return serialized

def _get_error(self) -> Optional[str]:
    """Map this run's termination reason to a short human-readable error."""
    return Run._termination_reason_to_error(self.termination_reason)

@staticmethod
def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[str]:
def _termination_reason_to_error(
termination_reason: Optional[RunTerminationReason],
) -> Optional[str]:
if termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED:
return "retry limit exceeded"
elif termination_reason == RunTerminationReason.SERVER_ERROR:
return "server error"
else:
return None

@root_validator
def _status_message(cls, values) -> Dict:
def _get_status_message(self) -> Optional[str]:
if len(self.jobs) == 0:
return self.status.value

last_job = self.jobs[0]
# FIXME: status_message should not require all job submissions for status calculation
# since it's very expensive and is not required for anything else.
# May return a different status if not all job submissions requested.
# TODO: Calculate status_message by looking at job models directly instead job submissions.
try:
status = values["status"]
jobs: List[Job] = values["jobs"]
retry_on_events = (
jobs[0].job_spec.retry.on_events if jobs and jobs[0].job_spec.retry else []
)
job_status = (
jobs[0].job_submissions[-1].status
if len(jobs) == 1 and jobs[0].job_submissions
else None
)
termination_reason = Run.get_last_termination_reason(jobs[0]) if jobs else None
except KeyError:
return values
values["status_message"] = Run._get_status_message(
status=status,
job_status=job_status,
retry_on_events=retry_on_events,
termination_reason=termination_reason,
)
return values
last_job_termination_reason = last_job.get_last_termination_reason()

@staticmethod
def get_last_termination_reason(job: "Job") -> Optional[JobTerminationReason]:
for submission in reversed(job.job_submissions):
if submission.termination_reason is not None:
return submission.termination_reason
return None
if len(self.jobs) == 1:
# FIXME: Clarify why show "pulling" only in case of one job
if (
last_job.job_submissions
and last_job.job_submissions[-1].status == JobStatus.PULLING
):
return "pulling"

@staticmethod
def _get_status_message(
status: RunStatus,
job_status: Optional[JobStatus],
retry_on_events: List[RetryEvent],
termination_reason: Optional[JobTerminationReason],
) -> str:
if job_status == JobStatus.PULLING:
return "pulling"
retry_on_events = last_job.job_spec.retry.on_events if last_job.job_spec.retry else []
# Currently, `retrying` is shown only for `no-capacity` events
if (
status in [RunStatus.SUBMITTED, RunStatus.PENDING]
and termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
self.status in [RunStatus.SUBMITTED, RunStatus.PENDING]
and last_job_termination_reason
== JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
and RetryEvent.NO_CAPACITY in retry_on_events
):
return "retrying"
return status.value

return self.status.value

def is_deployment_in_progress(self) -> bool:
return any(
Expand Down
Loading
Loading