From 1992632af29cfc7e6f0d2c6187baed42cc729f6c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 7 Jul 2025 10:40:21 +0500 Subject: [PATCH 01/12] Use orjson --- pyproject.toml | 1 + src/dstack/_internal/core/models/common.py | 17 +++++++-- src/dstack/_internal/server/app.py | 19 ++++++---- src/dstack/_internal/server/routers/runs.py | 36 +++++++++++-------- src/dstack/_internal/server/schemas/logs.py | 4 +-- src/dstack/_internal/server/utils/routers.py | 36 ++++++++++++++----- src/dstack/_internal/utils/json_utils.py | 17 +++++++++ .../_internal/server/services/test_logs.py | 4 +-- 8 files changed, 99 insertions(+), 35 deletions(-) create mode 100644 src/dstack/_internal/utils/json_utils.py diff --git a/pyproject.toml b/pyproject.toml index 47353886bf..736ba67682 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "gpuhunt==0.1.6", "argcomplete>=3.5.0", "ignore-python>=0.2.0", + "orjson", ] [project.urls] diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index c347cf0d32..0a40f6da67 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -1,11 +1,14 @@ import re from enum import Enum -from typing import Union +from typing import Any, Union +import orjson from pydantic import Field from pydantic_duality import DualBaseModel from typing_extensions import Annotated +from dstack._internal.utils.json_utils import get_orjson_options, orjson_default + IncludeExcludeFieldType = Union[int, str] IncludeExcludeSetType = set[IncludeExcludeFieldType] IncludeExcludeDictType = dict[ @@ -14,13 +17,23 @@ IncludeExcludeType = Union[IncludeExcludeSetType, IncludeExcludeDictType] +def _orjson_dumps(v: Any, *, default: Any) -> str: + return orjson.dumps( + v, + option=get_orjson_options(), + default=orjson_default, + ).decode() + + # DualBaseModel creates two classes for the model: # one with extra = "forbid" (CoreModel/CoreModel.__request__), # and another with extra = "ignore" (CoreModel.__response__). # This allows to use the same model both for a strict parsing of the user input and # for a permissive parsing of the server responses. class CoreModel(DualBaseModel): - pass + class Config: + json_loads = orjson.loads + json_dumps = _orjson_dumps class Duration(int): diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 7435ff8cb0..3a993a6282 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -10,7 +10,7 @@ import sentry_sdk from fastapi import FastAPI, Request, Response, status from fastapi.datastructures import URL -from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from prometheus_client import Counter, Histogram @@ -56,6 +56,7 @@ ) from dstack._internal.server.utils.logging import configure_logging from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, check_client_server_compatibility, error_detail, get_server_client_error_details, @@ -90,7 +91,11 @@ def create_app() -> FastAPI: profiles_sample_rate=settings.SENTRY_PROFILES_SAMPLE_RATE, ) - app = FastAPI(docs_url="/api/docs", lifespan=lifespan) + app = FastAPI( + docs_url="/api/docs", + lifespan=lifespan, + default_response_class=CustomORJSONResponse, + ) app.state.proxy_dependency_injector = ServerProxyDependencyInjector() return app @@ -208,14 +213,14 @@ async def forbidden_error_handler(request: Request, exc: ForbiddenError): msg = "Access denied" if len(exc.args) > 0: msg = exc.args[0] - return JSONResponse( + return CustomORJSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=error_detail(msg), ) @app.exception_handler(ServerClientError) async def server_client_error_handler(request: Request, exc: ServerClientError): - return JSONResponse( + return CustomORJSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={"detail": get_server_client_error_details(exc)}, ) @@ -223,7 +228,7 @@ async def server_client_error_handler(request: Request, exc: ServerClientError): @app.exception_handler(OSError) async def os_error_handler(request, exc: OSError): if exc.errno in [36, 63]: - return JSONResponse( + return CustomORJSONResponse( {"detail": "Filename too long"}, status_code=status.HTTP_400_BAD_REQUEST, ) @@ -309,7 +314,7 @@ async def check_client_version(request: Request, call_next): @app.get("/healthcheck") async def healthcheck(): - return JSONResponse(content={"status": "running"}) + return CustomORJSONResponse(content={"status": "running"}) if ui and Path(__file__).parent.joinpath("statics").exists(): app.mount( @@ -323,7 +328,7 @@ async def custom_http_exception_handler(request, exc): or _is_proxy_request(request) or _is_prometheus_request(request) ): - return JSONResponse( + return CustomORJSONResponse( {"detail": exc.detail}, status_code=status.HTTP_404_NOT_FOUND, ) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index c6a4b60f80..322a08ef8d 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -18,7 +18,10 @@ ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember from dstack._internal.server.services import runs -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/runs", @@ -32,12 +35,15 @@ ) -@root_router.post("/list") +@root_router.post( + "/list", + response_model=List[Run], +) async def list_runs( body: ListRunsRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Run]: +): """ Returns all runs visible to user sorted by descending `submitted_at`. `project_name`, `repo_id`, `username`, and `only_active` can be specified as filters. @@ -47,17 +53,19 @@ async def list_runs( The results are paginated. To get the next page, pass `submitted_at` and `id` of the last run from the previous page as `prev_submitted_at` and `prev_run_id`. """ - return await runs.list_user_runs( - session=session, - user=user, - project_name=body.project_name, - repo_id=body.repo_id, - username=body.username, - only_active=body.only_active, - prev_submitted_at=body.prev_submitted_at, - prev_run_id=body.prev_run_id, - limit=body.limit, - ascending=body.ascending, + return CustomORJSONResponse( + await runs.list_user_runs( + session=session, + user=user, + project_name=body.project_name, + repo_id=body.repo_id, + username=body.username, + only_active=body.only_active, + prev_submitted_at=body.prev_submitted_at, + prev_run_id=body.prev_run_id, + limit=body.limit, + ascending=body.ascending, + ) ) diff --git a/src/dstack/_internal/server/schemas/logs.py b/src/dstack/_internal/server/schemas/logs.py index 267f5612fa..f97d4fde37 100644 --- a/src/dstack/_internal/server/schemas/logs.py +++ b/src/dstack/_internal/server/schemas/logs.py @@ -9,8 +9,8 @@ class PollLogsRequest(CoreModel): run_name: str job_submission_id: UUID4 - start_time: Optional[datetime] - end_time: Optional[datetime] + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None descending: bool = False next_token: Optional[str] = None limit: int = Field(100, ge=0, le=1000) diff --git a/src/dstack/_internal/server/utils/routers.py b/src/dstack/_internal/server/utils/routers.py index f8ca21004b..e6576a1494 100644 --- a/src/dstack/_internal/server/utils/routers.py +++ b/src/dstack/_internal/server/utils/routers.py @@ -1,11 +1,31 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional -from fastapi import HTTPException, Request, status -from fastapi.responses import JSONResponse +import orjson +from fastapi import HTTPException, Request, Response, status from packaging import version from dstack._internal.core.errors import ServerClientError, ServerClientErrorCode from dstack._internal.core.models.common import CoreModel +from dstack._internal.utils.json_utils import get_orjson_options, orjson_default + + +class CustomORJSONResponse(Response): + """ + Custom JSONResponse that uses orjson for serialization. + + It's recommended to return this class from routers directly + to avoid the FastAPI's jsonable_encoder overhead. + See https://fastapi.tiangolo.com/advanced/custom-response/#use-orjsonresponse. + """ + + media_type = "application/json" + + def render(self, content: Any) -> bytes: + return orjson.dumps( + content, + option=get_orjson_options(), + default=orjson_default, + ) class BadRequestDetailsModel(CoreModel): @@ -30,7 +50,7 @@ def get_base_api_additional_responses() -> Dict: """ Returns additional responses for the OpenAPI docs relevant to all API endpoints. The endpoints may override responses to make them as specific as possible. - E.g. an enpoint may specify which error codes it may return in `code`. + E.g. an endpoint may specify which error codes it may return in `code`. """ return { 400: get_bad_request_additional_response(), @@ -102,7 +122,7 @@ def get_request_size(request: Request) -> int: def check_client_server_compatibility( client_version: Optional[str], server_version: Optional[str], -) -> Optional[JSONResponse]: +) -> Optional[CustomORJSONResponse]: """ Returns `JSONResponse` with error if client/server versions are incompatible. Returns `None` otherwise. @@ -116,7 +136,7 @@ def check_client_server_compatibility( try: parsed_client_version = version.parse(client_version) except version.InvalidVersion: - return JSONResponse( + return CustomORJSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={ "detail": get_server_client_error_details( @@ -138,11 +158,11 @@ def error_incompatible_versions( client_version: Optional[str], server_version: str, ask_cli_update: bool, -) -> JSONResponse: +) -> CustomORJSONResponse: msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})." if ask_cli_update: msg += f" Update the dstack CLI: `pip install dstack=={server_version}`." - return JSONResponse( + return CustomORJSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={"detail": get_server_client_error_details(ServerClientError(msg=msg))}, ) diff --git a/src/dstack/_internal/utils/json_utils.py b/src/dstack/_internal/utils/json_utils.py new file mode 100644 index 0000000000..27f7831fa7 --- /dev/null +++ b/src/dstack/_internal/utils/json_utils.py @@ -0,0 +1,17 @@ +import orjson +from pydantic import BaseModel + + +def orjson_default(obj): + if isinstance(obj, float): + # orjson does not convert float subclasses be default + return float(obj) + if isinstance(obj, BaseModel): + # Allows calling orjson.dumps() on pydantic models + # (e.g. to return from the API) + return obj.dict() + raise TypeError + + +def get_orjson_options() -> int: + return orjson.OPT_NON_STR_KEYS diff --git a/src/tests/_internal/server/services/test_logs.py b/src/tests/_internal/server/services/test_logs.py index 19769a3602..553b9dffe9 100644 --- a/src/tests/_internal/server/services/test_logs.py +++ b/src/tests/_internal/server/services/test_logs.py @@ -51,8 +51,8 @@ async def test_writes_logs(self, test_db, session: AsyncSession, tmp_path: Path) / "runner.log" ) assert runner_log_path.read_text() == ( - '{"timestamp": "2023-10-06T10:01:53.234000+00:00", "log_source": "stdout", "message": "SGVsbG8="}\n' - '{"timestamp": "2023-10-06T10:01:53.235000+00:00", "log_source": "stdout", "message": "V29ybGQ="}\n' + '{"timestamp":"2023-10-06T10:01:53.234000+00:00","log_source":"stdout","message":"SGVsbG8="}\n' + '{"timestamp":"2023-10-06T10:01:53.235000+00:00","log_source":"stdout","message":"V29ybGQ="}\n' ) @pytest.mark.asyncio From 262c0f87a6ac4f01cfbadd4c524303c1ac15496b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 7 Jul 2025 10:54:00 +0500 Subject: [PATCH 02/12] Use CustomORJSONResponse for runs API --- src/dstack/_internal/server/routers/runs.py | 39 +++++++++++++-------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 322a08ef8d..ba73b2f332 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -69,12 +69,15 @@ async def list_runs( ) -@project_router.post("/get") +@project_router.post( + "/get", + response_model=Run, +) async def get_run( body: GetRunRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Run: +): """ Returns a run given `run_name` or `id`. If given `run_name`, does not return deleted runs. @@ -89,15 +92,18 @@ async def get_run( ) if run is None: raise ResourceNotExistsError("Run not found") - return run + return CustomORJSONResponse(run) -@project_router.post("/get_plan") +@project_router.post( + "/get_plan", + response_model=RunPlan, +) async def get_plan( body: GetRunPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> RunPlan: +): """ Returns a run plan for the given run spec. This is an optional step before calling `/apply`. @@ -110,15 +116,18 @@ async def get_plan( run_spec=body.run_spec, max_offers=body.max_offers, ) - return run_plan + return CustomORJSONResponse(run_plan) -@project_router.post("/apply") +@project_router.post( + "/apply", + response_model=Run, +) async def apply_plan( body: ApplyRunPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Run: +): """ Creates a new run or updates an existing run. Errors if the expected current resource from the plan does not match the current resource. @@ -126,12 +135,14 @@ async def apply_plan( If the existing run is active and cannot be updated, it must be stopped first. """ user, project = user_project - return await runs.apply_plan( - session=session, - user=user, - project=project, - plan=body.plan, - force=body.force, + return CustomORJSONResponse( + await runs.apply_plan( + session=session, + user=user, + project=project, + plan=body.plan, + force=body.force, + ) ) From 07a09ade6b430f264c403a46d378f3d761d3352f Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 7 Jul 2025 11:00:43 +0500 Subject: [PATCH 03/12] refactor: wrap fleet API responses with CustomORJSONResponse --- src/dstack/_internal/server/routers/fleets.py | 59 +++++++++++-------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index 92aca9e18f..4d106f8546 100644 --- a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -18,7 +18,10 @@ ListFleetsRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/fleets", @@ -45,15 +48,17 @@ async def list_fleets( The results are paginated. To get the next page, pass `created_at` and `id` of the last fleet from the previous page as `prev_created_at` and `prev_id`. """ - return await fleets_services.list_fleets( - session=session, - user=user, - project_name=body.project_name, - only_active=body.only_active, - prev_created_at=body.prev_created_at, - prev_id=body.prev_id, - limit=body.limit, - ascending=body.ascending, + return CustomORJSONResponse( + await fleets_services.list_fleets( + session=session, + user=user, + project_name=body.project_name, + only_active=body.only_active, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) ) @@ -66,7 +71,9 @@ async def list_project_fleets( Returns all fleets in the project. """ _, project = user_project - return await fleets_services.list_project_fleets(session=session, project=project) + return CustomORJSONResponse( + await fleets_services.list_project_fleets(session=session, project=project) + ) @project_router.post("/get") @@ -86,7 +93,7 @@ async def get_fleet( ) if fleet is None: raise ResourceNotExistsError() - return fleet + return CustomORJSONResponse(fleet) @project_router.post("/get_plan") @@ -105,7 +112,7 @@ async def get_plan( user=user, spec=body.spec, ) - return plan + return CustomORJSONResponse(plan) @project_router.post("/apply") @@ -120,12 +127,14 @@ async def apply_plan( Use `force: true` to apply even if the current resource does not match. """ user, project = user_project - return await fleets_services.apply_plan( - session=session, - user=user, - project=project, - plan=body.plan, - force=body.force, + return CustomORJSONResponse( + await fleets_services.apply_plan( + session=session, + user=user, + project=project, + plan=body.plan, + force=body.force, + ) ) @@ -139,11 +148,13 @@ async def create_fleet( Creates a fleet given a fleet configuration. """ user, project = user_project - return await fleets_services.create_fleet( - session=session, - project=project, - user=user, - spec=body.spec, + return CustomORJSONResponse( + await fleets_services.create_fleet( + session=session, + project=project, + user=user, + spec=body.spec, + ) ) From c925da883efce99f42f7ffbc27cbff6fd691639e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 7 Jul 2025 11:00:45 +0500 Subject: [PATCH 04/12] feat: add response_model to decorators and remove return type annotations Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) --- src/dstack/_internal/server/routers/fleets.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index 4d106f8546..3cbab9508c 100644 --- a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -35,12 +35,12 @@ ) -@root_router.post("/list") +@root_router.post("/list", response_model=List[Fleet]) async def list_fleets( body: ListFleetsRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Fleet]: +): """ Returns all fleets and instances within them visible to user sorted by descending `created_at`. `project_name` and `only_active` can be specified as filters. @@ -62,11 +62,11 @@ async def list_fleets( ) -@project_router.post("/list") +@project_router.post("/list", response_model=List[Fleet]) async def list_project_fleets( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> List[Fleet]: +): """ Returns all fleets in the project. """ @@ -76,12 +76,12 @@ async def list_project_fleets( ) -@project_router.post("/get") +@project_router.post("/get", response_model=Fleet) async def get_fleet( body: GetFleetRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Fleet: +): """ Returns a fleet given `name` or `id`. If given `name`, does not return deleted fleets. @@ -96,12 +96,12 @@ async def get_fleet( return CustomORJSONResponse(fleet) -@project_router.post("/get_plan") +@project_router.post("/get_plan", response_model=FleetPlan) async def get_plan( body: GetFleetPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> FleetPlan: +): """ Returns a fleet plan for the given fleet configuration. """ @@ -115,12 +115,12 @@ async def get_plan( return CustomORJSONResponse(plan) -@project_router.post("/apply") +@project_router.post("/apply", response_model=Fleet) async def apply_plan( body: ApplyFleetPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Fleet: +): """ Creates a new fleet or updates an existing fleet. Errors if the expected current resource from the plan does not match the current resource. @@ -138,12 +138,12 @@ async def apply_plan( ) -@project_router.post("/create") +@project_router.post("/create", response_model=Fleet) async def create_fleet( body: CreateFleetRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Fleet: +): """ Creates a fleet given a fleet configuration. """ From fcc8dc5c78cd3233cc078aaae9047ebf13aca150 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 7 Jul 2025 11:07:02 +0500 Subject: [PATCH 05/12] refactor: Consistently use CustomORJSONResponse and response_model in router files Co-authored-by: aider (bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0) --- .../_internal/server/routers/backends.py | 36 ++++++----- src/dstack/_internal/server/routers/files.py | 13 ++-- .../_internal/server/routers/gateways.py | 42 +++++++------ .../_internal/server/routers/instances.py | 28 +++++---- src/dstack/_internal/server/routers/logs.py | 9 ++- .../_internal/server/routers/metrics.py | 19 +++--- .../_internal/server/routers/projects.py | 62 ++++++++++++------- src/dstack/_internal/server/routers/repos.py | 15 +++-- .../_internal/server/routers/secrets.py | 35 ++++++----- src/dstack/_internal/server/routers/server.py | 11 ++-- src/dstack/_internal/server/routers/users.py | 48 ++++++++------ .../_internal/server/routers/volumes.py | 56 +++++++++-------- 12 files changed, 219 insertions(+), 155 deletions(-) diff --git a/src/dstack/_internal/server/routers/backends.py b/src/dstack/_internal/server/routers/backends.py index 7b6056b92b..e9d79318d8 100644 --- a/src/dstack/_internal/server/routers/backends.py +++ b/src/dstack/_internal/server/routers/backends.py @@ -27,7 +27,7 @@ get_backend_config_yaml, update_backend_config_yaml, ) -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses root_router = APIRouter( prefix="/api/backends", @@ -41,35 +41,37 @@ ) -@root_router.post("/list_types") -async def list_backend_types() -> List[BackendType]: - return dstack._internal.core.backends.configurators.list_available_backend_types() +@root_router.post("/list_types", response_model=List[BackendType]) +async def list_backend_types(): + return CustomORJSONResponse( + dstack._internal.core.backends.configurators.list_available_backend_types() + ) -@project_router.post("/create") +@project_router.post("/create", response_model=AnyBackendConfigWithCreds) async def create_backend( body: AnyBackendConfigWithCreds, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> AnyBackendConfigWithCreds: +): _, project = user_project config = await backends.create_backend(session=session, project=project, config=body) if settings.SERVER_CONFIG_ENABLED: await ServerConfigManager().sync_config(session=session) - return config + return CustomORJSONResponse(config) -@project_router.post("/update") +@project_router.post("/update", response_model=AnyBackendConfigWithCreds) async def update_backend( body: AnyBackendConfigWithCreds, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> AnyBackendConfigWithCreds: +): _, project = user_project config = await backends.update_backend(session=session, project=project, config=body) if settings.SERVER_CONFIG_ENABLED: await ServerConfigManager().sync_config(session=session) - return config + return CustomORJSONResponse(config) @project_router.post("/delete") @@ -86,16 +88,16 @@ async def delete_backends( await ServerConfigManager().sync_config(session=session) -@project_router.post("/{backend_name}/config_info") +@project_router.post("/{backend_name}/config_info", response_model=AnyBackendConfigWithCreds) async def get_backend_config_info( backend_name: BackendType, user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> AnyBackendConfigWithCreds: +): _, project = user_project config = await backends.get_backend_config(project=project, backend_type=backend_name) if config is None: raise ResourceNotExistsError() - return config + return CustomORJSONResponse(config) @project_router.post("/create_yaml") @@ -126,10 +128,12 @@ async def update_backend_yaml( ) -@project_router.post("/{backend_name}/get_yaml") +@project_router.post("/{backend_name}/get_yaml", response_model=BackendInfoYAML) async def get_backend_yaml( backend_name: BackendType, user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> BackendInfoYAML: +): _, project = user_project - return await get_backend_config_yaml(project=project, backend_type=backend_name) + return CustomORJSONResponse( + await get_backend_config_yaml(project=project, backend_type=backend_name) + ) diff --git a/src/dstack/_internal/server/routers/files.py b/src/dstack/_internal/server/routers/files.py index 574ef01776..ff7b2a3d57 100644 --- a/src/dstack/_internal/server/routers/files.py +++ b/src/dstack/_internal/server/routers/files.py @@ -12,6 +12,7 @@ from dstack._internal.server.services import files from dstack._internal.server.settings import SERVER_CODE_UPLOAD_LIMIT from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, get_base_api_additional_responses, get_request_size, ) @@ -24,12 +25,12 @@ ) -@router.post("/get_archive_by_hash") +@router.post("/get_archive_by_hash", response_model=FileArchive) async def get_archive_by_hash( body: GetFileArchiveByHashRequest, session: Annotated[AsyncSession, Depends(get_session)], user: Annotated[UserModel, Depends(Authenticated())], -) -> FileArchive: +): archive = await files.get_archive_by_hash( session=session, user=user, @@ -37,16 +38,16 @@ async def get_archive_by_hash( ) if archive is None: raise ResourceNotExistsError() - return archive + return CustomORJSONResponse(archive) -@router.post("/upload_archive") +@router.post("/upload_archive", response_model=FileArchive) async def upload_archive( request: Request, file: UploadFile, session: Annotated[AsyncSession, Depends(get_session)], user: Annotated[UserModel, Depends(Authenticated())], -) -> FileArchive: +): request_size = get_request_size(request) if SERVER_CODE_UPLOAD_LIMIT > 0 and request_size > SERVER_CODE_UPLOAD_LIMIT: diff_size_fmt = sizeof_fmt(request_size) @@ -64,4 +65,4 @@ async def upload_archive( user=user, file=file, ) - return archive + return CustomORJSONResponse(archive) diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index e0e0ad37d1..13be132f48 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -13,7 +13,7 @@ ProjectAdmin, ProjectMemberOrPublicAccess, ) -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses router = APIRouter( prefix="/api/project/{project_name}/gateways", @@ -22,40 +22,44 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[models.Gateway]) async def list_gateways( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), -) -> List[models.Gateway]: +): _, project = user_project - return await gateways.list_project_gateways(session=session, project=project) + return CustomORJSONResponse( + await gateways.list_project_gateways(session=session, project=project) + ) -@router.post("/get") +@router.post("/get", response_model=models.Gateway) async def get_gateway( body: schemas.GetGatewayRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), -) -> models.Gateway: +): _, project = user_project gateway = await gateways.get_gateway_by_name(session=session, project=project, name=body.name) if gateway is None: raise ResourceNotExistsError() - return gateway + return CustomORJSONResponse(gateway) -@router.post("/create") +@router.post("/create", response_model=models.Gateway) async def create_gateway( body: schemas.CreateGatewayRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> models.Gateway: +): user, project = user_project - return await gateways.create_gateway( - session=session, - user=user, - project=project, - configuration=body.configuration, + return CustomORJSONResponse( + await gateways.create_gateway( + session=session, + user=user, + project=project, + configuration=body.configuration, + ) ) @@ -83,13 +87,15 @@ async def set_default_gateway( await gateways.set_default_gateway(session=session, project=project, name=body.name) -@router.post("/set_wildcard_domain") +@router.post("/set_wildcard_domain", response_model=models.Gateway) async def set_gateway_wildcard_domain( body: schemas.SetWildcardDomainRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> models.Gateway: +): _, project = user_project - return await gateways.set_gateway_wildcard_domain( - session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain + return CustomORJSONResponse( + await gateways.set_gateway_wildcard_domain( + session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain + ) ) diff --git a/src/dstack/_internal/server/routers/instances.py b/src/dstack/_internal/server/routers/instances.py index 489b3bf1c2..7d873c8f1b 100644 --- a/src/dstack/_internal/server/routers/instances.py +++ b/src/dstack/_internal/server/routers/instances.py @@ -9,7 +9,7 @@ from dstack._internal.server.models import UserModel from dstack._internal.server.schemas.instances import ListInstancesRequest from dstack._internal.server.security.permissions import Authenticated -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses root_router = APIRouter( prefix="/api/instances", @@ -18,12 +18,12 @@ ) -@root_router.post("/list") +@root_router.post("/list", response_model=List[Instance]) async def list_instances( body: ListInstancesRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Instance]: +): """ Returns all instances visible to user sorted by descending `created_at`. `project_names` and `fleet_ids` can be specified as filters. @@ -31,14 +31,16 @@ async def list_instances( The results are paginated. To get the next page, pass `created_at` and `id` of the last instance from the previous page as `prev_created_at` and `prev_id`. """ - return await instances.list_user_instances( - session=session, - user=user, - project_names=body.project_names, - fleet_ids=body.fleet_ids, - only_active=body.only_active, - prev_created_at=body.prev_created_at, - prev_id=body.prev_id, - limit=body.limit, - ascending=body.ascending, + return CustomORJSONResponse( + await instances.list_user_instances( + session=session, + user=user, + project_names=body.project_names, + fleet_ids=body.fleet_ids, + only_active=body.only_active, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) ) diff --git a/src/dstack/_internal/server/routers/logs.py b/src/dstack/_internal/server/routers/logs.py index a86424ee62..639fa7101b 100644 --- a/src/dstack/_internal/server/routers/logs.py +++ b/src/dstack/_internal/server/routers/logs.py @@ -7,7 +7,7 @@ from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.server.security.permissions import ProjectMember from dstack._internal.server.services import logs -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses router = APIRouter( prefix="/api/project/{project_name}/logs", @@ -18,13 +18,16 @@ @router.post( "/poll", + response_model=JobSubmissionLogs, ) async def poll_logs( body: PollLogsRequest, user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> JobSubmissionLogs: +): _, project = user_project # The runner guarantees logs have different timestamps if throughput < 1k logs / sec. # Otherwise, some logs with duplicated timestamps may be filtered out. # This limitation is imposed by cloud log services that support up to millisecond timestamp resolution. - return await logs.poll_logs_async(project=project, request=body) + return CustomORJSONResponse( + await logs.poll_logs_async(project=project, request=body) + ) diff --git a/src/dstack/_internal/server/routers/metrics.py b/src/dstack/_internal/server/routers/metrics.py index 1d4ffb1db0..56025f2fd0 100644 --- a/src/dstack/_internal/server/routers/metrics.py +++ b/src/dstack/_internal/server/routers/metrics.py @@ -11,7 +11,7 @@ from dstack._internal.server.security.permissions import ProjectMember from dstack._internal.server.services import metrics from dstack._internal.server.services.jobs import get_run_job_model -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses router = APIRouter( prefix="/api/project/{project_name}/metrics", @@ -22,6 +22,7 @@ @router.get( "/job/{run_name}", + response_model=JobMetrics, ) async def get_job_metrics( run_name: str, @@ -32,7 +33,7 @@ async def get_job_metrics( before: Optional[datetime] = None, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> JobMetrics: +): """ Returns job-level metrics such as hardware utilization given `run_name`, `replica_num`, and `job_num`. @@ -63,10 +64,12 @@ async def get_job_metrics( if job_model is None: raise ResourceNotExistsError("Found no job with given parameters") - return await metrics.get_job_metrics( - session=session, - job_model=job_model, - limit=limit, - after=after, - before=before, + return CustomORJSONResponse( + await metrics.get_job_metrics( + session=session, + job_model=job_model, + limit=limit, + after=after, + before=before, + ) ) diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index 1d967c6c8d..2d15907e78 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -23,7 +23,7 @@ ProjectMemberOrPublicAccess, ) from dstack._internal.server.services import projects -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses router = APIRouter( prefix="/api/projects", @@ -32,30 +32,34 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[Project]) async def list_projects( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Project]: +): """ Returns all projects visible to user sorted by descending `created_at`. `members` and `backends` are always empty - call `/api/projects/{project_name}/get` to retrieve them. """ - return await projects.list_user_accessible_projects(session=session, user=user) + return CustomORJSONResponse( + await projects.list_user_accessible_projects(session=session, user=user) + ) -@router.post("/create") +@router.post("/create", response_model=Project) async def create_project( body: CreateProjectRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> Project: - return await projects.create_project( - session=session, - user=user, - project_name=body.project_name, - is_public=body.is_public, +): + return CustomORJSONResponse( + await projects.create_project( + session=session, + user=user, + project_name=body.project_name, + is_public=body.is_public, + ) ) @@ -72,23 +76,26 @@ async def delete_projects( ) -@router.post("/{project_name}/get") +@router.post("/{project_name}/get", response_model=Project) async def get_project( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), -) -> Project: +): _, project = user_project - return projects.project_model_to_project(project) + return CustomORJSONResponse( + projects.project_model_to_project(project) + ) @router.post( "/{project_name}/set_members", + response_model=Project, ) async def set_project_members( body: SetProjectMembersRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManager()), -) -> Project: +): user, project = user_project await projects.set_project_members( session=session, @@ -97,17 +104,20 @@ async def set_project_members( members=body.members, ) await session.refresh(project) - return projects.project_model_to_project(project) + return CustomORJSONResponse( + projects.project_model_to_project(project) + ) @router.post( "/{project_name}/add_members", + response_model=Project, ) async def add_project_members( body: AddProjectMemberRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManagerOrPublicProject()), -) -> Project: +): user, project = user_project await projects.add_project_members( session=session, @@ -116,17 +126,20 @@ async def add_project_members( members=body.members, ) await session.refresh(project) - return projects.project_model_to_project(project) + return CustomORJSONResponse( + projects.project_model_to_project(project) + ) @router.post( "/{project_name}/remove_members", + response_model=Project, ) async def remove_project_members( body: RemoveProjectMemberRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManagerOrSelfLeave()), -) -> Project: +): user, project = user_project await projects.remove_project_members( session=session, @@ -135,17 +148,20 @@ async def remove_project_members( usernames=body.usernames, ) await session.refresh(project) - return projects.project_model_to_project(project) + return CustomORJSONResponse( + projects.project_model_to_project(project) + ) @router.post( "/{project_name}/update", + response_model=Project, ) async def update_project( body: UpdateProjectRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> Project: +): user, project = user_project await projects.update_project( session=session, @@ -154,4 +170,6 @@ async def update_project( is_public=body.is_public, ) await session.refresh(project) - return projects.project_model_to_project(project) + return CustomORJSONResponse( + projects.project_model_to_project(project) + ) diff --git a/src/dstack/_internal/server/routers/repos.py b/src/dstack/_internal/server/routers/repos.py index 32e59f6317..a6ef2c4729 100644 --- a/src/dstack/_internal/server/routers/repos.py +++ b/src/dstack/_internal/server/routers/repos.py @@ -16,6 +16,7 @@ from dstack._internal.server.services import repos from dstack._internal.server.settings import SERVER_CODE_UPLOAD_LIMIT from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, get_base_api_additional_responses, get_request_size, ) @@ -28,21 +29,23 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[RepoHead]) async def list_repos( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> List[RepoHead]: +): _, project = user_project - return await repos.list_repos(session=session, project=project) + return CustomORJSONResponse( + await repos.list_repos(session=session, project=project) + ) -@router.post("/get") +@router.post("/get", response_model=RepoHeadWithCreds) async def get_repo( body: GetRepoRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> RepoHeadWithCreds: +): user, project = user_project repo = await repos.get_repo( session=session, @@ -53,7 +56,7 @@ async def get_repo( ) if repo is None: raise ResourceNotExistsError() - return repo + return CustomORJSONResponse(repo) @router.post("/init") diff --git a/src/dstack/_internal/server/routers/secrets.py b/src/dstack/_internal/server/routers/secrets.py index bbfa26be93..c19f15bccc 100644 --- a/src/dstack/_internal/server/routers/secrets.py +++ b/src/dstack/_internal/server/routers/secrets.py @@ -14,6 +14,7 @@ ) from dstack._internal.server.security.permissions import ProjectAdmin from dstack._internal.server.services import secrets as secrets_services +from dstack._internal.server.utils.routers import CustomORJSONResponse router = APIRouter( prefix="/api/project/{project_name}/secrets", @@ -21,24 +22,26 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[Secret]) async def list_secrets( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> List[Secret]: +): _, project = user_project - return await secrets_services.list_secrets( - session=session, - project=project, + return CustomORJSONResponse( + await secrets_services.list_secrets( + session=session, + project=project, + ) ) -@router.post("/get") +@router.post("/get", response_model=Secret) async def get_secret( body: GetSecretRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> Secret: +): _, project = user_project secret = await secrets_services.get_secret( session=session, @@ -47,21 +50,23 @@ async def get_secret( ) if secret is None: raise ResourceNotExistsError() - return secret + return CustomORJSONResponse(secret) -@router.post("/create_or_update") +@router.post("/create_or_update", response_model=Secret) async def create_or_update_secret( body: CreateOrUpdateSecretRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), -) -> Secret: +): _, project = user_project - return await secrets_services.create_or_update_secret( - session=session, - project=project, - name=body.name, - value=body.value, + return CustomORJSONResponse( + await secrets_services.create_or_update_secret( + session=session, + project=project, + name=body.name, + value=body.value, + ) ) diff --git a/src/dstack/_internal/server/routers/server.py b/src/dstack/_internal/server/routers/server.py index 31e1e04c93..28c742772b 100644 --- a/src/dstack/_internal/server/routers/server.py +++ b/src/dstack/_internal/server/routers/server.py @@ -2,6 +2,7 @@ from dstack._internal import settings from dstack._internal.core.models.server import ServerInfo +from dstack._internal.server.utils.routers import CustomORJSONResponse router = APIRouter( prefix="/api/server", @@ -9,8 +10,10 @@ ) -@router.post("/get_info") -async def get_server_info() -> ServerInfo: - return ServerInfo( - server_version=settings.DSTACK_VERSION, +@router.post("/get_info", response_model=ServerInfo) +async def get_server_info(): + return CustomORJSONResponse( + ServerInfo( + server_version=settings.DSTACK_VERSION, + ) ) diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index 670f9f0a5d..f3ded25ec4 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -16,7 +16,7 @@ ) from dstack._internal.server.security.permissions import Authenticated, GlobalAdmin from dstack._internal.server.services import users -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses router = APIRouter( prefix="/api/users", @@ -25,41 +25,45 @@ ) -@router.post("/list") +@router.post("/list", response_model=List[User]) async def list_users( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[User]: - return await users.list_users_for_user(session=session, user=user) +): + return CustomORJSONResponse( + await users.list_users_for_user(session=session, user=user) + ) -@router.post("/get_my_user") +@router.post("/get_my_user", response_model=User) async def get_my_user( user: UserModel = Depends(Authenticated()), -) -> User: - return users.user_model_to_user(user) +): + return CustomORJSONResponse( + users.user_model_to_user(user) + ) -@router.post("/get_user") +@router.post("/get_user", response_model=UserWithCreds) async def get_user( body: GetUserRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> UserWithCreds: +): res = await users.get_user_with_creds_by_name( session=session, current_user=user, username=body.username ) if res is None: raise ResourceNotExistsError() - return res + return CustomORJSONResponse(res) -@router.post("/create") +@router.post("/create", response_model=User) async def create_user( body: CreateUserRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(GlobalAdmin()), -) -> User: +): res = await users.create_user( session=session, username=body.username, @@ -67,15 +71,17 @@ async def create_user( email=body.email, active=body.active, ) - return users.user_model_to_user(res) + return CustomORJSONResponse( + users.user_model_to_user(res) + ) -@router.post("/update") +@router.post("/update", response_model=User) async def update_user( body: UpdateUserRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(GlobalAdmin()), -) -> User: +): res = await users.update_user( session=session, username=body.username, @@ -85,19 +91,23 @@ async def update_user( ) if res is None: raise ResourceNotExistsError() - return users.user_model_to_user(res) + return CustomORJSONResponse( + users.user_model_to_user(res) + ) -@router.post("/refresh_token") +@router.post("/refresh_token", response_model=UserWithCreds) async def refresh_token( body: RefreshTokenRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> UserWithCreds: +): res = await users.refresh_user_token(session=session, user=user, username=body.username) if res is None: raise ResourceNotExistsError() - return users.user_model_to_user_with_creds(res) + return CustomORJSONResponse( + users.user_model_to_user_with_creds(res) + ) @router.post("/delete") diff --git a/src/dstack/_internal/server/routers/volumes.py b/src/dstack/_internal/server/routers/volumes.py index d4137099fa..ab7750966b 100644 --- a/src/dstack/_internal/server/routers/volumes.py +++ b/src/dstack/_internal/server/routers/volumes.py @@ -15,7 +15,7 @@ ListVolumesRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses root_router = APIRouter( prefix="/api/volumes", @@ -25,12 +25,12 @@ project_router = APIRouter(prefix="/api/project/{project_name}/volumes", tags=["volumes"]) -@root_router.post("/list") +@root_router.post("/list", response_model=List[Volume]) async def list_volumes( body: ListVolumesRequest, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), -) -> List[Volume]: +): """ Returns all volumes visible to user sorted by descending `created_at`. `project_name` and `only_active` can be specified as filters. @@ -38,36 +38,40 @@ async def list_volumes( The results are paginated. To get the next page, pass `created_at` and `id` of the last fleet from the previous page as `prev_created_at` and `prev_id`. """ - return await volumes_services.list_volumes( - session=session, - user=user, - project_name=body.project_name, - only_active=body.only_active, - prev_created_at=body.prev_created_at, - prev_id=body.prev_id, - limit=body.limit, - ascending=body.ascending, + return CustomORJSONResponse( + await volumes_services.list_volumes( + session=session, + user=user, + project_name=body.project_name, + only_active=body.only_active, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) ) -@project_router.post("/list") +@project_router.post("/list", response_model=List[Volume]) async def list_project_volumes( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> List[Volume]: +): """ Returns all volumes in the project. """ _, project = user_project - return await volumes_services.list_project_volumes(session=session, project=project) + return CustomORJSONResponse( + await volumes_services.list_project_volumes(session=session, project=project) + ) -@project_router.post("/get") +@project_router.post("/get", response_model=Volume) async def get_volume( body: GetVolumeRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Volume: +): """ Returns a volume given a volume name. """ @@ -77,24 +81,26 @@ async def get_volume( ) if volume is None: raise ResourceNotExistsError() - return volume + return CustomORJSONResponse(volume) -@project_router.post("/create") +@project_router.post("/create", response_model=Volume) async def create_volume( body: CreateVolumeRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Volume: +): """ Creates a volume given a volume configuration. """ user, project = user_project - return await volumes_services.create_volume( - session=session, - project=project, - user=user, - configuration=body.configuration, + return CustomORJSONResponse( + await volumes_services.create_volume( + session=session, + project=project, + user=user, + configuration=body.configuration, + ) ) From 39bbe3671a3ef48ee34903eaad76e5e9c10b22be Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 7 Jul 2025 18:49:22 +0500 Subject: [PATCH 06/12] Replace validators with dict() for serialization tweaks --- src/dstack/_internal/core/models/common.py | 34 +++- src/dstack/_internal/core/models/resources.py | 23 ++- src/dstack/_internal/core/models/runs.py | 160 ++++++++---------- src/dstack/_internal/server/app.py | 1 - src/dstack/_internal/server/utils/routers.py | 7 +- src/dstack/_internal/utils/json_utils.py | 9 + src/tests/_internal/core/models/test_runs.py | 6 +- .../_internal/server/routers/test_fleets.py | 4 +- 8 files changed, 136 insertions(+), 108 deletions(-) diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index 0a40f6da67..11709d7a04 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -1,8 +1,9 @@ import re from enum import Enum -from typing import Any, Union +from typing import Any, Callable, Union import orjson +from git import Optional from pydantic import Field from pydantic_duality import DualBaseModel from typing_extensions import Annotated @@ -35,6 +36,37 @@ class Config: json_loads = orjson.loads json_dumps = _orjson_dumps + def json( + self, + *, + include: Optional[IncludeExcludeType] = None, + exclude: Optional[IncludeExcludeType] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, # ignore as it's deprecated + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + encoder: Optional[Callable[[Any], Any]] = None, + models_as_dict: bool = True, # does not seems to be needed by dstack or dependencies + **dumps_kwargs: Any, + ) -> str: + """ + Override `json()` method so that it calls `dict()`. + Allows changing how models are serialized by overriding `dict()` only. + By default, `json()` won't call `dict()`, so changes applied in `dict()` won't take place. + """ + data = self.dict( + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + if self.__custom_root_type__: + data = data["__root__"] + return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) + class Duration(int): """ diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py index a0ecff9532..15c80f7166 100644 --- a/src/dstack/_internal/core/models/resources.py +++ b/src/dstack/_internal/core/models/resources.py @@ -382,14 +382,6 @@ def schema_extra(schema: Dict[str, Any]): gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK - # TODO: Remove in 0.20. Added for backward compatibility. - @root_validator - def _post_validate(cls, values): - cpu = values.get("cpu") - if isinstance(cpu, CPUSpec) and cpu.arch in [None, gpuhunt.CPUArchitecture.X86]: - values["cpu"] = cpu.count - return values - def pretty_format(self) -> str: # TODO: Remove in 0.20. Use self.cpu directly cpu = parse_obj_as(CPUSpec, self.cpu) @@ -407,3 +399,18 @@ def pretty_format(self) -> str: resources.update(disk_size=self.disk.size) res = pretty_resources(**resources) return res + + def dict(self, *args, **kwargs) -> Dict: + # super() does not work with pydantic-duality + res = CoreModel.dict(self, *args, **kwargs) + self._update_serialized_cpu(res) + return res + + # TODO: Remove in 0.20. Added for backward compatibility. + def _update_serialized_cpu(self, values: Dict): + cpu = values["cpu"] + if cpu: + arch = cpu.get("arch") + count = cpu.get("count") + if count and arch in [None, gpuhunt.CPUArchitecture.X86.value]: + values["cpu"] = count diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 49691eb504..606de4e336 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -325,56 +325,45 @@ def duration(self) -> timedelta: end_time = self.finished_at return end_time - self.submitted_at - @root_validator - def _status_message(cls, values) -> Dict: - try: - status = values["status"] - termination_reason = values["termination_reason"] - exit_code = values["exit_status"] - except KeyError: - return values - values["status_message"] = JobSubmission._get_status_message( - status=status, - termination_reason=termination_reason, - exit_status=exit_code, - ) - return values + def dict(self, *args, **kwargs) -> Dict: + status_message = self._get_status_message() + error = self._get_error() + # super() does not work with pydantic-duality + res = CoreModel.dict(self, *args, **kwargs) + res["status_message"] = status_message + res["error"] = error + return res - @staticmethod - def _get_status_message( - status: JobStatus, - termination_reason: Optional[JobTerminationReason], - exit_status: Optional[int], - ) -> str: - if status == JobStatus.DONE: + def _get_status_message(self) -> Optional[str]: + if self.status == JobStatus.DONE: return "exited (0)" - elif status == JobStatus.FAILED: - if termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR: - return f"exited ({exit_status})" - elif termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY: + elif self.status == JobStatus.FAILED: + if self.termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR: + return f"exited ({self.exit_status})" + elif ( + self.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + ): return "no offers" - elif termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY: + elif self.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY: return "interrupted" else: return "error" - elif status == JobStatus.TERMINATED: - if termination_reason == JobTerminationReason.TERMINATED_BY_USER: + elif self.status == JobStatus.TERMINATED: + if self.termination_reason == JobTerminationReason.TERMINATED_BY_USER: return "stopped" - elif termination_reason == JobTerminationReason.ABORTED_BY_USER: + elif self.termination_reason == JobTerminationReason.ABORTED_BY_USER: return "aborted" - return status.value + return self.status.value - @root_validator - def _error(cls, values) -> Dict: - try: - termination_reason = values["termination_reason"] - except KeyError: - return values - values["error"] = JobSubmission._get_error(termination_reason=termination_reason) - return values + def _get_error(self) -> Optional[str]: + return JobSubmission._termination_reason_to_error( + termination_reason=self.termination_reason + ) @staticmethod - def _get_error(termination_reason: Optional[JobTerminationReason]) -> Optional[str]: + def _termination_reason_to_error( + termination_reason: Optional[JobTerminationReason], + ) -> Optional[str]: error_mapping = { JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable", JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded", @@ -395,6 +384,12 @@ class Job(CoreModel): job_spec: JobSpec job_submissions: List[JobSubmission] + def get_last_termination_reason(self) -> Optional[JobTerminationReason]: + for submission in reversed(self.job_submissions): + if submission.termination_reason is not None: + return submission.termination_reason + return None + class RunSpec(CoreModel): # TODO: run_name, working_dir are redundant here since they already passed in configuration @@ -525,10 +520,10 @@ class Run(CoreModel): last_processed_at: datetime status: RunStatus status_message: Optional[str] = None - termination_reason: Optional[RunTerminationReason] + termination_reason: Optional[RunTerminationReason] = None run_spec: RunSpec jobs: List[Job] - latest_job_submission: Optional[JobSubmission] + latest_job_submission: Optional[JobSubmission] = None cost: float = 0 service: Optional[ServiceSpec] = None deployment_num: int = 0 # default for compatibility with pre-0.19.14 servers @@ -536,17 +531,22 @@ class Run(CoreModel): error: Optional[str] = None deleted: Optional[bool] = None - @root_validator - def _error(cls, values) -> Dict: - try: - termination_reason = values["termination_reason"] - except KeyError: - return values - values["error"] = Run._get_error(termination_reason=termination_reason) - return values + def dict(self, *args, **kwargs) -> Dict: + status_message = self._get_status_message() + error = self._get_error() + # super() does not work with pydantic-duality + res = CoreModel.dict(self, *args, **kwargs) + res["status_message"] = status_message + res["error"] = error + return res + + def _get_error(self) -> Optional[str]: + return Run._termination_reason_to_error(termination_reason=self.termination_reason) @staticmethod - def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[str]: + def _termination_reason_to_error( + termination_reason: Optional[RunTerminationReason], + ) -> Optional[str]: if termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED: return "retry limit exceeded" elif termination_reason == RunTerminationReason.SERVER_ERROR: @@ -554,54 +554,32 @@ def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[s else: return None - @root_validator - def _status_message(cls, values) -> Dict: - try: - status = values["status"] - jobs: List[Job] = values["jobs"] - retry_on_events = ( - jobs[0].job_spec.retry.on_events if jobs and jobs[0].job_spec.retry else [] - ) - job_status = ( - jobs[0].job_submissions[-1].status - if len(jobs) == 1 and jobs[0].job_submissions - else None - ) - termination_reason = Run.get_last_termination_reason(jobs[0]) if jobs else None - except KeyError: - return values - values["status_message"] = Run._get_status_message( - status=status, - job_status=job_status, - retry_on_events=retry_on_events, - termination_reason=termination_reason, - ) - return values + def _get_status_message(self) -> Optional[str]: + if len(self.jobs) == 0: + return self.status.value - @staticmethod - def get_last_termination_reason(job: "Job") -> Optional[JobTerminationReason]: - for submission in reversed(job.job_submissions): - if submission.termination_reason is not None: - return submission.termination_reason - return None + last_job = self.jobs[0] + last_job_termination_reason = last_job.get_last_termination_reason() - @staticmethod - def _get_status_message( - status: RunStatus, - job_status: Optional[JobStatus], - retry_on_events: List[RetryEvent], - termination_reason: Optional[JobTerminationReason], - ) -> str: - if job_status == JobStatus.PULLING: - return "pulling" + if len(self.jobs) == 1: + # FIXME: Clarify why show "pulling" only in case of one job + if ( + last_job.job_submissions + and last_job.job_submissions[-1].status == JobStatus.PULLING + ): + return "pulling" + + retry_on_events = last_job.job_spec.retry.on_events if last_job.job_spec.retry else [] # Currently, `retrying` is shown only for `no-capacity` events if ( - status in [RunStatus.SUBMITTED, RunStatus.PENDING] - and termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + self.status in [RunStatus.SUBMITTED, RunStatus.PENDING] + and last_job_termination_reason + == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY and RetryEvent.NO_CAPACITY in retry_on_events ): return "retrying" - return status.value + + return self.status.value def is_deployment_in_progress(self) -> bool: return any( diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 3a993a6282..6dd09d766c 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -94,7 +94,6 @@ def create_app() -> FastAPI: app = FastAPI( docs_url="/api/docs", lifespan=lifespan, - default_response_class=CustomORJSONResponse, ) app.state.proxy_dependency_injector = ServerProxyDependencyInjector() return app diff --git a/src/dstack/_internal/server/utils/routers.py b/src/dstack/_internal/server/utils/routers.py index e6576a1494..ee2b94fa79 100644 --- a/src/dstack/_internal/server/utils/routers.py +++ b/src/dstack/_internal/server/utils/routers.py @@ -13,9 +13,12 @@ class CustomORJSONResponse(Response): """ Custom JSONResponse that uses orjson for serialization. - It's recommended to return this class from routers directly - to avoid the FastAPI's jsonable_encoder overhead. + It's recommended to return this class from routers directly instead of + returning pydantic models to avoid the FastAPI's jsonable_encoder overhead. See https://fastapi.tiangolo.com/advanced/custom-response/#use-orjsonresponse. + + Beware that FastAPI skips model validation when responses are returned directly. + If serialization needs to be modified, override `dict()` instead of adding validators. """ media_type = "application/json" diff --git a/src/dstack/_internal/utils/json_utils.py b/src/dstack/_internal/utils/json_utils.py index 27f7831fa7..26658d7c84 100644 --- a/src/dstack/_internal/utils/json_utils.py +++ b/src/dstack/_internal/utils/json_utils.py @@ -1,6 +1,12 @@ import orjson from pydantic import BaseModel +FREEZEGUN = True +try: + from freezegun.api import FakeDatetime +except ImportError: + FREEZEGUN = False + def orjson_default(obj): if isinstance(obj, float): @@ -10,6 +16,9 @@ def orjson_default(obj): # Allows calling orjson.dumps() on pydantic models # (e.g. to return from the API) return obj.dict() + if FREEZEGUN: + if isinstance(obj, FakeDatetime): + return obj.isoformat() raise TypeError diff --git a/src/tests/_internal/core/models/test_runs.py b/src/tests/_internal/core/models/test_runs.py index 851cba9e39..6837a4528c 100644 --- a/src/tests/_internal/core/models/test_runs.py +++ b/src/tests/_internal/core/models/test_runs.py @@ -33,7 +33,7 @@ def test_job_termination_reason_to_retry_event_works_with_all_enum_variants(): assert retry_event is None or isinstance(retry_event, RetryEvent) -# Will fail if JobTerminationReason value is added without updaing JobSubmission._get_error +# Will fail if JobTerminationReason value is added without updating JobSubmission._get_error def test_get_error_returns_expected_messages(): no_error_reasons = [ JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY, @@ -47,7 +47,7 @@ def test_get_error_returns_expected_messages(): ] for reason in JobTerminationReason: - if JobSubmission._get_error(reason) is None: + if JobSubmission._termination_reason_to_error(reason) is None: # Fail no-error reason is not in the list assert reason in no_error_reasons @@ -62,6 +62,6 @@ def test_run_get_error_returns_none_for_specific_reasons(): ] for reason in RunTerminationReason: - if Run._get_error(reason) is None: + if Run._termination_reason_to_error(reason) is None: # Fail no-error reason is not in the list assert reason in no_error_reasons diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 87a970c73c..9da8c1f55e 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -855,8 +855,8 @@ async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncC assert response.json() == { "project_name": project.name, "user": user.name, - "spec": spec.dict(), - "effective_spec": spec.dict(), + "spec": json.loads(spec.json()), + "effective_spec": json.loads(spec.json()), "current_resource": None, "offers": [json.loads(o.json()) for o in offers], "total_offers": len(offers), From a84a6a6875056ad7364bf405ae26c5317d4044cf Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 8 Jul 2025 08:23:58 +0500 Subject: [PATCH 07/12] Fix linting --- .../_internal/server/routers/backends.py | 5 +++- .../_internal/server/routers/gateways.py | 5 +++- .../_internal/server/routers/instances.py | 5 +++- src/dstack/_internal/server/routers/logs.py | 9 ++++--- .../_internal/server/routers/metrics.py | 5 +++- .../_internal/server/routers/projects.py | 25 +++++++------------ src/dstack/_internal/server/routers/repos.py | 4 +-- src/dstack/_internal/server/routers/users.py | 25 +++++++------------ .../_internal/server/routers/volumes.py | 5 +++- 9 files changed, 44 insertions(+), 44 deletions(-) diff --git a/src/dstack/_internal/server/routers/backends.py b/src/dstack/_internal/server/routers/backends.py index e9d79318d8..b43463a905 100644 --- a/src/dstack/_internal/server/routers/backends.py +++ b/src/dstack/_internal/server/routers/backends.py @@ -27,7 +27,10 @@ get_backend_config_yaml, update_backend_config_yaml, ) -from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/backends", diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index 13be132f48..fb03a3d69c 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -13,7 +13,10 @@ ProjectAdmin, ProjectMemberOrPublicAccess, ) -from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/project/{project_name}/gateways", diff --git a/src/dstack/_internal/server/routers/instances.py b/src/dstack/_internal/server/routers/instances.py index 7d873c8f1b..740c51fd6c 100644 --- a/src/dstack/_internal/server/routers/instances.py +++ b/src/dstack/_internal/server/routers/instances.py @@ -9,7 +9,10 @@ from dstack._internal.server.models import UserModel from dstack._internal.server.schemas.instances import ListInstancesRequest from dstack._internal.server.security.permissions import Authenticated -from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/instances", diff --git a/src/dstack/_internal/server/routers/logs.py b/src/dstack/_internal/server/routers/logs.py index 639fa7101b..29685f6f5f 100644 --- a/src/dstack/_internal/server/routers/logs.py +++ b/src/dstack/_internal/server/routers/logs.py @@ -7,7 +7,10 @@ from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.server.security.permissions import ProjectMember from dstack._internal.server.services import logs -from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/project/{project_name}/logs", @@ -28,6 +31,4 @@ async def poll_logs( # The runner guarantees logs have different timestamps if throughput < 1k logs / sec. # Otherwise, some logs with duplicated timestamps may be filtered out. # This limitation is imposed by cloud log services that support up to millisecond timestamp resolution. - return CustomORJSONResponse( - await logs.poll_logs_async(project=project, request=body) - ) + return CustomORJSONResponse(await logs.poll_logs_async(project=project, request=body)) diff --git a/src/dstack/_internal/server/routers/metrics.py b/src/dstack/_internal/server/routers/metrics.py index 56025f2fd0..e61a0d9bfa 100644 --- a/src/dstack/_internal/server/routers/metrics.py +++ b/src/dstack/_internal/server/routers/metrics.py @@ -11,7 +11,10 @@ from dstack._internal.server.security.permissions import ProjectMember from dstack._internal.server.services import metrics from dstack._internal.server.services.jobs import get_run_job_model -from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/project/{project_name}/metrics", diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index 2d15907e78..56d41b6ca0 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -23,7 +23,10 @@ ProjectMemberOrPublicAccess, ) from dstack._internal.server.services import projects -from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/projects", @@ -82,9 +85,7 @@ async def get_project( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), ): _, project = user_project - return CustomORJSONResponse( - projects.project_model_to_project(project) - ) + return CustomORJSONResponse(projects.project_model_to_project(project)) @router.post( @@ -104,9 +105,7 @@ async def set_project_members( members=body.members, ) await session.refresh(project) - return CustomORJSONResponse( - projects.project_model_to_project(project) - ) + return CustomORJSONResponse(projects.project_model_to_project(project)) @router.post( @@ -126,9 +125,7 @@ async def add_project_members( members=body.members, ) await session.refresh(project) - return CustomORJSONResponse( - projects.project_model_to_project(project) - ) + return CustomORJSONResponse(projects.project_model_to_project(project)) @router.post( @@ -148,9 +145,7 @@ async def remove_project_members( usernames=body.usernames, ) await session.refresh(project) - return CustomORJSONResponse( - projects.project_model_to_project(project) - ) + return CustomORJSONResponse(projects.project_model_to_project(project)) @router.post( @@ -170,6 +165,4 @@ async def update_project( is_public=body.is_public, ) await session.refresh(project) - return CustomORJSONResponse( - projects.project_model_to_project(project) - ) + return CustomORJSONResponse(projects.project_model_to_project(project)) diff --git a/src/dstack/_internal/server/routers/repos.py b/src/dstack/_internal/server/routers/repos.py index a6ef2c4729..202732f4f0 100644 --- a/src/dstack/_internal/server/routers/repos.py +++ b/src/dstack/_internal/server/routers/repos.py @@ -35,9 +35,7 @@ async def list_repos( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ): _, project = user_project - return CustomORJSONResponse( - await repos.list_repos(session=session, project=project) - ) + return CustomORJSONResponse(await repos.list_repos(session=session, project=project)) @router.post("/get", response_model=RepoHeadWithCreds) diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index f3ded25ec4..abb6729141 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -16,7 +16,10 @@ ) from dstack._internal.server.security.permissions import Authenticated, GlobalAdmin from dstack._internal.server.services import users -from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) router = APIRouter( prefix="/api/users", @@ -30,18 +33,14 @@ async def list_users( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), ): - return CustomORJSONResponse( - await users.list_users_for_user(session=session, user=user) - ) + return CustomORJSONResponse(await users.list_users_for_user(session=session, user=user)) @router.post("/get_my_user", response_model=User) async def get_my_user( user: UserModel = Depends(Authenticated()), ): - return CustomORJSONResponse( - users.user_model_to_user(user) - ) + return CustomORJSONResponse(users.user_model_to_user(user)) @router.post("/get_user", response_model=UserWithCreds) @@ -71,9 +70,7 @@ async def create_user( email=body.email, active=body.active, ) - return CustomORJSONResponse( - users.user_model_to_user(res) - ) + return CustomORJSONResponse(users.user_model_to_user(res)) @router.post("/update", response_model=User) @@ -91,9 +88,7 @@ async def update_user( ) if res is None: raise ResourceNotExistsError() - return CustomORJSONResponse( - users.user_model_to_user(res) - ) + return CustomORJSONResponse(users.user_model_to_user(res)) @router.post("/refresh_token", response_model=UserWithCreds) @@ -105,9 +100,7 @@ async def refresh_token( res = await users.refresh_user_token(session=session, user=user, username=body.username) if res is None: raise ResourceNotExistsError() - return CustomORJSONResponse( - users.user_model_to_user_with_creds(res) - ) + return CustomORJSONResponse(users.user_model_to_user_with_creds(res)) @router.post("/delete") diff --git a/src/dstack/_internal/server/routers/volumes.py b/src/dstack/_internal/server/routers/volumes.py index ab7750966b..2ac5034707 100644 --- a/src/dstack/_internal/server/routers/volumes.py +++ b/src/dstack/_internal/server/routers/volumes.py @@ -15,7 +15,10 @@ ListVolumesRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember -from dstack._internal.server.utils.routers import CustomORJSONResponse, get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) root_router = APIRouter( prefix="/api/volumes", From 4c9d63a5a2305bc4b1edf92926ca9354d6c79ba5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 8 Jul 2025 08:45:04 +0500 Subject: [PATCH 08/12] Fix asyncpg.pgproto.pgproto.UUID serialization --- src/dstack/_internal/utils/json_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/dstack/_internal/utils/json_utils.py b/src/dstack/_internal/utils/json_utils.py index 26658d7c84..5f3dfce788 100644 --- a/src/dstack/_internal/utils/json_utils.py +++ b/src/dstack/_internal/utils/json_utils.py @@ -1,3 +1,4 @@ +import asyncpg.pgproto.pgproto import orjson from pydantic import BaseModel @@ -8,6 +9,13 @@ FREEZEGUN = False +ASYNCPG = True +try: + import asyncpg.pgproto.pgproto +except ImportError: + ASYNCPG = False + + def orjson_default(obj): if isinstance(obj, float): # orjson does not convert float subclasses be default @@ -16,6 +24,9 @@ def orjson_default(obj): # Allows calling orjson.dumps() on pydantic models # (e.g. to return from the API) return obj.dict() + if ASYNCPG: + if isinstance(obj, asyncpg.pgproto.pgproto.UUID): + return str(obj) if FREEZEGUN: if isinstance(obj, FakeDatetime): return obj.isoformat() From 05602bd65c5d0c215f95dfb853ff88cf8e12ac88 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 8 Jul 2025 08:53:50 +0500 Subject: [PATCH 09/12] Fix extra import --- src/dstack/_internal/utils/json_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dstack/_internal/utils/json_utils.py b/src/dstack/_internal/utils/json_utils.py index 5f3dfce788..4cbccd6408 100644 --- a/src/dstack/_internal/utils/json_utils.py +++ b/src/dstack/_internal/utils/json_utils.py @@ -1,4 +1,3 @@ -import asyncpg.pgproto.pgproto import orjson from pydantic import BaseModel From 820b4f7566a1c5f68dfa70c4d137c81a401cb926 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 8 Jul 2025 09:13:22 +0500 Subject: [PATCH 10/12] Use orjson with indentation for json schema --- src/dstack/_internal/core/models/common.py | 4 ++-- src/dstack/_internal/core/models/configurations.py | 14 ++++++++++++++ src/dstack/_internal/server/utils/routers.py | 4 ++-- src/dstack/_internal/utils/json_utils.py | 2 +- 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index 11709d7a04..d324442ee8 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -8,7 +8,7 @@ from pydantic_duality import DualBaseModel from typing_extensions import Annotated -from dstack._internal.utils.json_utils import get_orjson_options, orjson_default +from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default IncludeExcludeFieldType = Union[int, str] IncludeExcludeSetType = set[IncludeExcludeFieldType] @@ -21,7 +21,7 @@ def _orjson_dumps(v: Any, *, default: Any) -> str: return orjson.dumps( v, - option=get_orjson_options(), + option=get_orjson_default_options(), default=orjson_default, ).decode() diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 97be403ca1..5d8d2bcbc1 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -4,6 +4,7 @@ from pathlib import PurePosixPath from typing import Any, Dict, List, Optional, Union +import orjson from pydantic import Field, ValidationError, conint, constr, root_validator, validator from typing_extensions import Annotated, Literal @@ -18,6 +19,7 @@ from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point +from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default CommandsList = List[str] ValidPort = conint(gt=0, le=65536) @@ -566,6 +568,15 @@ def parse_apply_configuration(data: dict) -> AnyApplyConfiguration: AnyDstackConfiguration = AnyApplyConfiguration +# Custom _orjson_dumps for DstackConfiguration with indentation +def _orjson_dumps(v: Any, *, default: Any) -> str: + return orjson.dumps( + v, + option=get_orjson_default_options() | orjson.OPT_INDENT_2, + default=orjson_default, + ).decode() + + class DstackConfiguration(CoreModel): __root__: Annotated[ AnyDstackConfiguration, @@ -573,6 +584,9 @@ class DstackConfiguration(CoreModel): ] class Config: + json_loads = orjson.loads + json_dumps = _orjson_dumps + @staticmethod def schema_extra(schema: Dict[str, Any]): schema["$schema"] = "http://json-schema.org/draft-07/schema#" diff --git a/src/dstack/_internal/server/utils/routers.py b/src/dstack/_internal/server/utils/routers.py index ee2b94fa79..131ec5cc3a 100644 --- a/src/dstack/_internal/server/utils/routers.py +++ b/src/dstack/_internal/server/utils/routers.py @@ -6,7 +6,7 @@ from dstack._internal.core.errors import ServerClientError, ServerClientErrorCode from dstack._internal.core.models.common import CoreModel -from dstack._internal.utils.json_utils import get_orjson_options, orjson_default +from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default class CustomORJSONResponse(Response): @@ -26,7 +26,7 @@ class CustomORJSONResponse(Response): def render(self, content: Any) -> bytes: return orjson.dumps( content, - option=get_orjson_options(), + option=get_orjson_default_options(), default=orjson_default, ) diff --git a/src/dstack/_internal/utils/json_utils.py b/src/dstack/_internal/utils/json_utils.py index 4cbccd6408..15970ef462 100644 --- a/src/dstack/_internal/utils/json_utils.py +++ b/src/dstack/_internal/utils/json_utils.py @@ -32,5 +32,5 @@ def orjson_default(obj): raise TypeError -def get_orjson_options() -> int: +def get_orjson_default_options() -> int: return orjson.OPT_NON_STR_KEYS From cc8c272e319ae2a49d13385b91f1e72c8a80174b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 8 Jul 2025 09:25:24 +0500 Subject: [PATCH 11/12] Fix generate-json-schema CI --- .github/workflows/build.yml | 4 ++-- .github/workflows/release.yml | 4 ++-- src/dstack/_internal/core/models/common.py | 12 ++---------- .../_internal/core/models/configurations.py | 15 ++++----------- src/dstack/_internal/core/models/profiles.py | 5 +++++ src/dstack/_internal/utils/json_utils.py | 18 ++++++++++++++++++ 6 files changed, 33 insertions(+), 25 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7803dd0aff..e512314018 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -237,8 +237,8 @@ jobs: run: uv sync - name: Generate json schema run: | - uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json - uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json + uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json())" > configuration.json + uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json())" > profiles.json - name: Upload json schema to S3 run: | VERSION=$((${{ github.run_number }} + ${{ env.BUILD_INCREMENT }})) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 391761e596..55a6c92439 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -311,8 +311,8 @@ jobs: run: uv sync - name: Generate json schema run: | - uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json(indent=2))" > configuration.json - uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json(indent=2))" > profiles.json + uv run python -c "from dstack._internal.core.models.configurations import DstackConfiguration; print(DstackConfiguration.schema_json())" > configuration.json + uv run python -c "from dstack._internal.core.models.profiles import ProfilesConfig; print(ProfilesConfig.schema_json())" > profiles.json - name: Upload json schema to S3 run: | VERSION=${GITHUB_REF#refs/tags/} diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index d324442ee8..f7564ece27 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -8,7 +8,7 @@ from pydantic_duality import DualBaseModel from typing_extensions import Annotated -from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default +from dstack._internal.utils.json_utils import pydantic_orjson_dumps IncludeExcludeFieldType = Union[int, str] IncludeExcludeSetType = set[IncludeExcludeFieldType] @@ -18,14 +18,6 @@ IncludeExcludeType = Union[IncludeExcludeSetType, IncludeExcludeDictType] -def _orjson_dumps(v: Any, *, default: Any) -> str: - return orjson.dumps( - v, - option=get_orjson_default_options(), - default=orjson_default, - ).decode() - - # DualBaseModel creates two classes for the model: # one with extra = "forbid" (CoreModel/CoreModel.__request__), # and another with extra = "ignore" (CoreModel.__response__). @@ -34,7 +26,7 @@ def _orjson_dumps(v: Any, *, default: Any) -> str: class CoreModel(DualBaseModel): class Config: json_loads = orjson.loads - json_dumps = _orjson_dumps + json_dumps = pydantic_orjson_dumps def json( self, diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 5d8d2bcbc1..770673b1b7 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -19,7 +19,9 @@ from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point -from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default +from dstack._internal.utils.json_utils import ( + pydantic_orjson_dumps_with_indent, +) CommandsList = List[str] ValidPort = conint(gt=0, le=65536) @@ -568,15 +570,6 @@ def parse_apply_configuration(data: dict) -> AnyApplyConfiguration: AnyDstackConfiguration = AnyApplyConfiguration -# Custom _orjson_dumps for DstackConfiguration with indentation -def _orjson_dumps(v: Any, *, default: Any) -> str: - return orjson.dumps( - v, - option=get_orjson_default_options() | orjson.OPT_INDENT_2, - default=orjson_default, - ).decode() - - class DstackConfiguration(CoreModel): __root__: Annotated[ AnyDstackConfiguration, @@ -585,7 +578,7 @@ class DstackConfiguration(CoreModel): class Config: json_loads = orjson.loads - json_dumps = _orjson_dumps + json_dumps = pydantic_orjson_dumps_with_indent @staticmethod def schema_extra(schema: Dict[str, Any]): diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 62997ce4e4..5a4909dcdf 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -1,12 +1,14 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union, overload +import orjson from pydantic import Field, root_validator, validator from typing_extensions import Annotated, Literal from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel, Duration from dstack._internal.utils.common import list_enum_values_for_annotation +from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent from dstack._internal.utils.tags import tags_validator DEFAULT_RETRY_DURATION = 3600 @@ -343,6 +345,9 @@ class ProfilesConfig(CoreModel): profiles: List[Profile] class Config: + json_loads = orjson.loads + json_dumps = pydantic_orjson_dumps_with_indent + schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"} def default(self) -> Optional[Profile]: diff --git a/src/dstack/_internal/utils/json_utils.py b/src/dstack/_internal/utils/json_utils.py index 15970ef462..9017e94c31 100644 --- a/src/dstack/_internal/utils/json_utils.py +++ b/src/dstack/_internal/utils/json_utils.py @@ -1,3 +1,5 @@ +from typing import Any + import orjson from pydantic import BaseModel @@ -15,6 +17,22 @@ ASYNCPG = False +def pydantic_orjson_dumps(v: Any, *, default: Any) -> str: + return orjson.dumps( + v, + option=get_orjson_default_options(), + default=orjson_default, + ).decode() + + +def pydantic_orjson_dumps_with_indent(v: Any, *, default: Any) -> str: + return orjson.dumps( + v, + option=get_orjson_default_options() | orjson.OPT_INDENT_2, + default=orjson_default, + ).decode() + + def orjson_default(obj): if isinstance(obj, float): # orjson does not convert float subclasses be default From 994c5f76da2fb450bd7e0cd8480738328e9c6f81 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 10 Jul 2025 12:13:18 +0500 Subject: [PATCH 12/12] Fix import --- src/dstack/_internal/core/models/common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index f7564ece27..a139226712 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -1,9 +1,8 @@ import re from enum import Enum -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union import orjson -from git import Optional from pydantic import Field from pydantic_duality import DualBaseModel from typing_extensions import Annotated