From 63a0f878423d9cf08d594539f8f4980f95ba02ce Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 12:19:24 +0500 Subject: [PATCH 01/54] Fix possible unbound vars in background tasks --- .../_internal/server/background/tasks/process_gateways.py | 2 +- .../_internal/server/background/tasks/process_instances.py | 2 +- .../_internal/server/background/tasks/process_running_jobs.py | 3 +-- src/dstack/_internal/server/background/tasks/process_runs.py | 2 +- .../server/background/tasks/process_submitted_jobs.py | 2 +- .../server/background/tasks/process_terminating_jobs.py | 4 ++-- .../_internal/server/background/tasks/process_volumes.py | 2 +- 7 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/tasks/process_gateways.py index ef6c1aebe6..4ff65e2b3d 100644 --- a/src/dstack/_internal/server/background/tasks/process_gateways.py +++ b/src/dstack/_internal/server/background/tasks/process_gateways.py @@ -49,8 +49,8 @@ async def process_gateways(): if gateway_model is None: return lockset.add(gateway_model.id) + gateway_model_id = gateway_model.id try: - gateway_model_id = gateway_model.id initial_status = gateway_model.status if initial_status == GatewayStatus.SUBMITTED: await _process_submitted_gateway(session=session, gateway_model=gateway_model) diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 5be54f21ce..188915e03f 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -181,8 +181,8 @@ async def _process_next_instance(): if instance is None: return lockset.add(instance.id) + instance_model_id = instance.id try: - instance_model_id = instance.id await _process_instance(session=session, instance=instance) finally: lockset.difference_update([instance_model_id]) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 0a98bc7fae..7624f6e316 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -128,9 +128,8 @@ async def _process_next_running_job(): if job_model is None: return lockset.add(job_model.id) - + job_model_id = job_model.id try: - job_model_id = job_model.id await _process_running_job(session=session, job_model=job_model) finally: lockset.difference_update([job_model_id]) diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 16a84dcb93..e9d13a5009 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -129,8 +129,8 @@ async def _process_next_run(): job_ids = [j.id for j in run_model.jobs] run_lockset.add(run_model.id) job_lockset.update(job_ids) + run_model_id = run_model.id try: - run_model_id = run_model.id await _process_run(session=session, run_model=run_model) finally: run_lockset.difference_update([run_model_id]) diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 9470e39b79..c85715f0e7 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -148,8 +148,8 @@ async def _process_next_submitted_job(): if job_model is None: return lockset.add(job_model.id) + job_model_id = job_model.id try: - job_model_id = job_model.id await _process_submitted_job(session=session, job_model=job_model) finally: lockset.difference_update([job_model_id]) diff --git a/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py b/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py index cd81765636..6a358dcd61 100644 --- a/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_terminating_jobs.py @@ -75,9 +75,9 @@ async def _process_next_terminating_job(): return instance_lockset.add(instance_model.id) job_lockset.add(job_model.id) + job_model_id = job_model.id + instance_model_id = job_model.used_instance_id try: - job_model_id = job_model.id - instance_model_id = job_model.used_instance_id await _process_job( session=session, job_model=job_model, diff --git a/src/dstack/_internal/server/background/tasks/process_volumes.py b/src/dstack/_internal/server/background/tasks/process_volumes.py index 4e37f6997b..534af8d48f 100644 --- a/src/dstack/_internal/server/background/tasks/process_volumes.py +++ b/src/dstack/_internal/server/background/tasks/process_volumes.py @@ -42,8 +42,8 @@ async def process_submitted_volumes(): if volume_model is None: return lockset.add(volume_model.id) + volume_model_id = volume_model.id try: - volume_model_id = volume_model.id await _process_submitted_volume(session=session, volume_model=volume_model) finally: lockset.difference_update([volume_model_id]) From d150224ed3b89ed22f4f2c846503160b40f5ace5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 14:36:54 +0500 Subject: [PATCH 02/54] Move class Config after fields --- src/dstack/_internal/core/models/common.py | 6 +- src/dstack/_internal/core/models/profiles.py | 12 ++- src/dstack/_internal/core/models/resources.py | 84 +++++++++---------- .../_internal/server/services/docker.py | 6 +- 4 files changed, 58 insertions(+), 50 deletions(-) diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py index 4c4e45fd09..772da55274 100644 --- a/src/dstack/_internal/core/models/common.py +++ b/src/dstack/_internal/core/models/common.py @@ -102,12 +102,12 @@ class RegistryAuth(CoreModel): password (str): The password or access token """ - class Config(CoreModel.Config): - frozen = True - username: Annotated[str, Field(description="The username")] password: Annotated[str, Field(description="The password or access token")] + class Config(CoreModel.Config): + frozen = True + class ApplyAction(str, Enum): CREATE = "create" # resource is to be created or overridden diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 5572ae25dd..2b243afba1 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -9,6 +9,7 @@ from dstack._internal.core.models.common import CoreModel, Duration from dstack._internal.utils.common import list_enum_values_for_annotation from dstack._internal.utils.cron import validate_cron +from dstack._internal.utils.json_schema import add_extra_schema_types from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent from dstack._internal.utils.tags import tags_validator @@ -151,7 +152,7 @@ class UtilizationPolicy(CoreModel): ), ] time_window: Annotated[ - Union[int, str], + int, Field( description=( "The time window of metric samples taking into account to measure utilization" @@ -160,6 +161,14 @@ class UtilizationPolicy(CoreModel): ), ] + class Config(CoreModel.Config): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["time_window"], + extra_types=[{"type": "string"}], + ) + @validator("time_window", pre=True) def validate_time_window(cls, v: Union[int, str]) -> int: v = parse_duration(v) @@ -382,7 +391,6 @@ class ProfilesConfig(CoreModel): class Config(CoreModel.Config): json_loads = orjson.loads json_dumps = pydantic_orjson_dumps_with_indent - schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"} def default(self) -> Optional[Profile]: diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py index 13d5dcf2a9..c959d4b059 100644 --- a/src/dstack/_internal/core/models/resources.py +++ b/src/dstack/_internal/core/models/resources.py @@ -130,6 +130,12 @@ def __str__(self): class CPUSpec(CoreModel): + arch: Annotated[ + Optional[gpuhunt.CPUArchitecture], + Field(description="The CPU architecture, one of: `x86`, `arm`"), + ] = None + count: Annotated[Range[int], Field(description="The number of CPU cores")] = DEFAULT_CPU_COUNT + class Config(CoreModel.Config): @staticmethod def schema_extra(schema: Dict[str, Any]): @@ -138,12 +144,6 @@ def schema_extra(schema: Dict[str, Any]): extra_types=[{"type": "integer"}, {"type": "string"}], ) - arch: Annotated[ - Optional[gpuhunt.CPUArchitecture], - Field(description="The CPU architecture, one of: `x86`, `arm`"), - ] = None - count: Annotated[Range[int], Field(description="The number of CPU cores")] = DEFAULT_CPU_COUNT - @classmethod def __get_validators__(cls): yield cls.parse @@ -191,22 +191,6 @@ def _validate_arch(cls, v: Any) -> Any: class GPUSpec(CoreModel): - class Config(CoreModel.Config): - @staticmethod - def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["count"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["memory"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - add_extra_schema_types( - schema["properties"]["total_memory"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) - vendor: Annotated[ Optional[gpuhunt.AcceleratorVendor], Field( @@ -234,6 +218,22 @@ def schema_extra(schema: Dict[str, Any]): Field(description="The minimum compute capability of the GPU (e.g., `7.5`)"), ] = None + class Config(CoreModel.Config): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["count"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["memory"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["total_memory"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + @classmethod def __get_validators__(cls): yield cls.parse @@ -314,6 +314,8 @@ def _vendor_from_string(cls, v: str) -> gpuhunt.AcceleratorVendor: class DiskSpec(CoreModel): + size: Annotated[Range[Memory], Field(description="Disk size")] + class Config(CoreModel.Config): @staticmethod def schema_extra(schema: Dict[str, Any]): @@ -322,8 +324,6 @@ def schema_extra(schema: Dict[str, Any]): extra_types=[{"type": "integer"}, {"type": "string"}], ) - size: Annotated[Range[Memory], Field(description="Disk size")] - @classmethod def __get_validators__(cls): yield cls._parse @@ -340,6 +340,24 @@ def _parse(cls, v: Any) -> Any: class ResourcesSpec(CoreModel): + # TODO: Remove Range[int] in 0.20. Range[int] for backward compatibility only. + cpu: Annotated[Union[CPUSpec, Range[int]], Field(description="The CPU requirements")] = ( + CPUSpec() + ) + memory: Annotated[Range[Memory], Field(description="The RAM size (e.g., `8GB`)")] = ( + DEFAULT_MEMORY_SIZE + ) + shm_size: Annotated[ + Optional[Memory], + Field( + description="The size of shared memory (e.g., `8GB`). " + "If you are using parallel communicating processes (e.g., dataloaders in PyTorch), " + "you may need to configure this" + ), + ] = None + gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None + disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK + class Config(CoreModel.Config): @staticmethod def schema_extra(schema: Dict[str, Any]): @@ -364,24 +382,6 @@ def schema_extra(schema: Dict[str, Any]): extra_types=[{"type": "integer"}, {"type": "string"}], ) - # TODO: Remove Range[int] in 0.20. Range[int] for backward compatibility only. - cpu: Annotated[Union[CPUSpec, Range[int]], Field(description="The CPU requirements")] = ( - CPUSpec() - ) - memory: Annotated[Range[Memory], Field(description="The RAM size (e.g., `8GB`)")] = ( - DEFAULT_MEMORY_SIZE - ) - shm_size: Annotated[ - Optional[Memory], - Field( - description="The size of shared memory (e.g., `8GB`). " - "If you are using parallel communicating processes (e.g., dataloaders in PyTorch), " - "you may need to configure this" - ), - ] = None - gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None - disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK - def pretty_format(self) -> str: # TODO: Remove in 0.20. Use self.cpu directly cpu = parse_obj_as(CPUSpec, self.cpu) diff --git a/src/dstack/_internal/server/services/docker.py b/src/dstack/_internal/server/services/docker.py index 49e8d8e857..40291179df 100644 --- a/src/dstack/_internal/server/services/docker.py +++ b/src/dstack/_internal/server/services/docker.py @@ -32,15 +32,15 @@ def __call__(self, dxf: DXF, response: requests.Response) -> None: class DockerImage(CoreModel): - class Config(CoreModel.Config): - frozen = True - image: str registry: Optional[str] repo: str tag: str digest: Optional[str] + class Config(CoreModel.Config): + frozen = True + class ImageConfig(CoreModel): user: Annotated[Optional[str], Field(alias="User")] = None From 26f1f42cee5d1c175c340d043ee50960fe4f7bfb Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 14:41:39 +0500 Subject: [PATCH 03/54] Fix unbound vars in app.py --- src/dstack/_internal/server/app.py | 39 +++++++++++++++++------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 8e65897710..43c51aec9a 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -109,9 +109,11 @@ async def lifespan(app: FastAPI): _print_dstack_logo() if not check_required_ssh_version(): logger.warning("OpenSSH 8.4+ is required. The dstack server may not work properly") + server_config_manager = None + server_config_loaded = False if settings.SERVER_CONFIG_ENABLED: server_config_manager = ServerConfigManager() - config_loaded = server_config_manager.load_config() + server_config_loaded = server_config_manager.load_config() # Encryption has to be configured before working with users and projects await server_config_manager.apply_encryption() async with get_session_ctx() as session: @@ -125,11 +127,9 @@ async def lifespan(app: FastAPI): session=session, user=admin, ) - if settings.SERVER_CONFIG_ENABLED: - server_config_dir = str(SERVER_CONFIG_FILE_PATH).replace( - os.path.expanduser("~"), "~", 1 - ) - if not config_loaded: + if server_config_manager is not None: + server_config_dir = _get_server_config_dir() + if not server_config_loaded: logger.info("Initializing the default configuration...", {"show_path": False}) await server_config_manager.init_config(session=session) logger.info( @@ -152,6 +152,7 @@ async def lifespan(app: FastAPI): ) if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None: init_default_storage() + scheduler = None if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: scheduler = start_background_tasks() else: @@ -166,7 +167,7 @@ async def lifespan(app: FastAPI): for func in _ON_STARTUP_HOOKS: await func(app) yield - if settings.SERVER_BACKGROUND_PROCESSING_ENABLED: + if scheduler is not None: scheduler.shutdown() PROBES_SCHEDULER.shutdown(wait=False) await gateway_connections_pool.remove_all() @@ -369,6 +370,18 @@ def _is_prometheus_request(request: Request) -> bool: return request.url.path.startswith("/metrics") +def _sentry_traces_sampler(sampling_context: SamplingContext) -> float: + parent_sampling_decision = sampling_context["parent_sampled"] + if parent_sampling_decision is not None: + return float(parent_sampling_decision) + transaction_context = sampling_context["transaction_context"] + name = transaction_context.get("name") + if name is not None: + if name.startswith("background."): + return settings.SENTRY_TRACES_BACKGROUND_SAMPLE_RATE + return settings.SENTRY_TRACES_SAMPLE_RATE + + def _print_dstack_logo(): console.print( """[purple]╱╱╭╮╱╱╭╮╱╱╱╱╱╱╭╮ @@ -385,13 +398,5 @@ def _print_dstack_logo(): ) -def _sentry_traces_sampler(sampling_context: SamplingContext) -> float: - parent_sampling_decision = sampling_context["parent_sampled"] - if parent_sampling_decision is not None: - return float(parent_sampling_decision) - transaction_context = sampling_context["transaction_context"] - name = transaction_context.get("name") - if name is not None: - if name.startswith("background."): - return settings.SENTRY_TRACES_BACKGROUND_SAMPLE_RATE - return settings.SENTRY_TRACES_SAMPLE_RATE +def _get_server_config_dir() -> str: + return str(SERVER_CONFIG_FILE_PATH).replace(os.path.expanduser("~"), "~", 1) From 3484ce127323d2602dc511db4bc5f42533c147a3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 14:44:55 +0500 Subject: [PATCH 04/54] Assert provisioning gateways have compute --- .../_internal/server/background/tasks/process_gateways.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/tasks/process_gateways.py index 4ff65e2b3d..a54cb9e319 100644 --- a/src/dstack/_internal/server/background/tasks/process_gateways.py +++ b/src/dstack/_internal/server/background/tasks/process_gateways.py @@ -165,6 +165,9 @@ async def _process_provisioning_gateway( ) gateway_model = res.unique().scalar_one() + # Provisioning gateways must have compute. + assert gateway_model.gateway_compute is not None + # FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway: # - cannot delete the gateway before it is provisioned because the DB model is locked # - connection retry counter is reset on server restart From ac1ec7bd89b973b5bf84965a0e968e62d4d5f0cb Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 14:49:26 +0500 Subject: [PATCH 05/54] Assert region set for ssh instances --- .../_internal/server/background/tasks/process_instances.py | 1 + src/dstack/_internal/server/models.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 188915e03f..8e2127cd78 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -393,6 +393,7 @@ async def _add_remote(instance: InstanceModel) -> None: return region = instance.region + assert region is not None # always set for ssh instances jpd = JobProvisioningData( backend=BackendType.REMOTE, instance_type=instance_type, diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 915a7c7665..cd8873e73b 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -622,6 +622,7 @@ class InstanceModel(BaseModel): backend: Mapped[Optional[BackendType]] = mapped_column(EnumAsString(BackendType, 100)) backend_data: Mapped[Optional[str]] = mapped_column(Text) + # Not set for cloud fleets that haven't been provisioning offer: Mapped[Optional[str]] = mapped_column(Text) region: Mapped[Optional[str]] = mapped_column(String(2000)) price: Mapped[Optional[float]] = mapped_column(Float) From ef00febd77463c93e0d65817dffa3f02183f4e5c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 14:58:36 +0500 Subject: [PATCH 06/54] Assert region service configurations --- src/dstack/_internal/server/services/services/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index aba2698ec2..68fdc509cf 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -73,6 +73,8 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec: async def _register_service_in_gateway( session: AsyncSession, run_model: RunModel, run_spec: RunSpec, gateway: GatewayModel ) -> ServiceSpec: + assert run_spec.configuration.type == "service" + if gateway.gateway_compute is None: raise ServerClientError("Gateway has no instance associated with it") @@ -127,6 +129,7 @@ async def _register_service_in_gateway( def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec: + assert run_spec.configuration.type == "service" if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT: # Note: if the user sets `https: `, it will be ignored silently # TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted @@ -270,6 +273,7 @@ async def unregister_replica(session: AsyncSession, job_model: JobModel): def _get_service_https(run_spec: RunSpec, configuration: GatewayConfiguration) -> bool: + assert run_spec.configuration.type == "service" if not run_spec.configuration.https: return False if configuration.certificate is not None and configuration.certificate.type == "acm": From 89c0322a3a4ef4a2e839e42b51ec1f1244595bf3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:02:25 +0500 Subject: [PATCH 07/54] Fix logging type annotations --- src/dstack/_internal/server/utils/logging.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/utils/logging.py b/src/dstack/_internal/server/utils/logging.py index 1ea58578bf..03d7d05cb4 100644 --- a/src/dstack/_internal/server/utils/logging.py +++ b/src/dstack/_internal/server/utils/logging.py @@ -31,15 +31,15 @@ def configure_logging(): rename_fields={"name": "logger", "asctime": "timestamp", "levelname": "level"}, ), } - handlers = { + handlers: dict[str, logging.Handler] = { "rich": DstackRichHandler(console=console), "standard": logging.StreamHandler(stream=sys.stdout), "json": logging.StreamHandler(stream=sys.stdout), } if settings.LOG_FORMAT not in formatters: raise ValueError(f"Invalid settings.LOG_FORMAT: {settings.LOG_FORMAT}") - formatter = formatters.get(settings.LOG_FORMAT) - handler = handlers.get(settings.LOG_FORMAT) + formatter = formatters[settings.LOG_FORMAT] + handler = handlers[settings.LOG_FORMAT] handler.setFormatter(formatter) handler.addFilter(AsyncioCancelledErrorFilter()) root_logger = logging.getLogger(None) From d271bcecfd78f6cb9f242b71785f69ab189071bd Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:17:50 +0500 Subject: [PATCH 08/54] Fix redundant add_project_members line --- src/dstack/_internal/server/services/projects.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 992e1be046..b901bc8e21 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -313,7 +313,6 @@ async def add_project_members( member_num=None, commit=False, ) - member_by_user_id[user_to_add.id] = None await session.commit() From 3f9d14706da93d9cd93d57dfb980870281564039 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:30:24 +0500 Subject: [PATCH 09/54] Fix Lockset annotations --- src/dstack/_internal/server/services/locking.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/services/locking.py b/src/dstack/_internal/server/services/locking.py index 4c3b7f938a..71a4aa7bfe 100644 --- a/src/dstack/_internal/server/services/locking.py +++ b/src/dstack/_internal/server/services/locking.py @@ -23,13 +23,13 @@ async def __aexit__(self, exc_type, exc, tb): ... class Lockset(Protocol[T]): - def __contains__(self, item: T) -> bool: ... + def __contains__(self, item: T, /) -> bool: ... def __iter__(self) -> Iterator[T]: ... def __len__(self) -> int: ... - def add(self, item: T) -> None: ... - def discard(self, item: T) -> None: ... - def update(self, other: Iterable[T]) -> None: ... - def difference_update(self, other: Iterable[T]) -> None: ... + def add(self, item: T, /) -> None: ... + def discard(self, item: T, /) -> None: ... + def update(self, other: Iterable[T], /) -> None: ... + def difference_update(self, other: Iterable[T], /) -> None: ... class ResourceLocker: From 738fa50f3ab05c09adbbbff217c52aa175bbcbbd Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:35:10 +0500 Subject: [PATCH 10/54] Assert configuration types in job configurators --- .../_internal/server/services/jobs/configurators/dev.py | 5 +++++ .../_internal/server/services/jobs/configurators/service.py | 1 + .../_internal/server/services/jobs/configurators/task.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py index a10922ef79..20aad1f232 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/dev.py +++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py @@ -18,6 +18,8 @@ class DevEnvironmentJobConfigurator(JobConfigurator): TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]): + assert run_spec.configuration.type == "dev-environment" + if run_spec.configuration.ide == "vscode": __class = VSCodeDesktop elif run_spec.configuration.ide == "cursor": @@ -32,6 +34,8 @@ def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]): super().__init__(run_spec=run_spec, secrets=secrets) def _shell_commands(self) -> List[str]: + assert self.run_spec.configuration.type == "dev-environment" + commands = self.ide.get_install_commands() commands.append(INSTALL_IPYKERNEL) commands += self.run_spec.configuration.setup @@ -56,4 +60,5 @@ def _spot_policy(self) -> SpotPolicy: return self.run_spec.merged_profile.spot_policy or SpotPolicy.ONDEMAND def _ports(self) -> List[PortMapping]: + assert self.run_spec.configuration.type == "dev-environment" return self.run_spec.configuration.ports diff --git a/src/dstack/_internal/server/services/jobs/configurators/service.py b/src/dstack/_internal/server/services/jobs/configurators/service.py index 7cd36f178a..a00216a6d4 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/service.py +++ b/src/dstack/_internal/server/services/jobs/configurators/service.py @@ -9,6 +9,7 @@ class ServiceJobConfigurator(JobConfigurator): TYPE: RunConfigurationType = RunConfigurationType.SERVICE def _shell_commands(self) -> List[str]: + assert self.run_spec.configuration.type == "service" return self.run_spec.configuration.commands def _default_single_branch(self) -> bool: diff --git a/src/dstack/_internal/server/services/jobs/configurators/task.py b/src/dstack/_internal/server/services/jobs/configurators/task.py index 4b1c93ce05..6a0da9f003 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/task.py +++ b/src/dstack/_internal/server/services/jobs/configurators/task.py @@ -10,6 +10,7 @@ class TaskJobConfigurator(JobConfigurator): TYPE: RunConfigurationType = RunConfigurationType.TASK async def get_job_specs(self, replica_num: int) -> List[JobSpec]: + assert self.run_spec.configuration.type == "task" job_specs = [] for job_num in range(self.run_spec.configuration.nodes): job_spec = await self._get_job_spec( @@ -21,6 +22,7 @@ async def get_job_specs(self, replica_num: int) -> List[JobSpec]: return job_specs def _shell_commands(self) -> List[str]: + assert self.run_spec.configuration.type == "task" return self.run_spec.configuration.commands def _default_single_branch(self) -> bool: @@ -33,6 +35,7 @@ def _spot_policy(self) -> SpotPolicy: return self.run_spec.merged_profile.spot_policy or SpotPolicy.ONDEMAND def _ports(self) -> List[PortMapping]: + assert self.run_spec.configuration.type == "task" return self.run_spec.configuration.ports def _working_dir(self) -> Optional[str]: From 961c2bcb113bf3e5bfa8b6a142ef8719df5da64a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:46:20 +0500 Subject: [PATCH 11/54] Fix unbound vars in process_running_jobs --- .../server/background/tasks/process_running_jobs.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 7624f6e316..2b6c173c54 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -169,6 +169,11 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): job = find_job(run.jobs, job_model.replica_num, job_model.job_num) + volumes = [] + secrets = {} + cluster_info = None + repo_creds = None + initial_status = job_model.status if initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]: # Wait until all other jobs in the replica are provisioned @@ -256,6 +261,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): user_ssh_key, ) else: + assert cluster_info is not None logger.debug( "%s: process provisioning job without shim, age=%s", fmt(job_model), @@ -274,7 +280,6 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): repo=repo_model, code_hash=_get_repo_code_hash(run, job), ) - success = await common_utils.run_async( _submit_job_to_runner, server_ssh_private_keys, @@ -308,6 +313,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): else: # fails are not acceptable if initial_status == JobStatus.PULLING: + assert cluster_info is not None logger.debug( "%s: process pulling job with shim, age=%s", fmt(job_model), job_submission.age ) From b5978dc796ad95786bc15762355bebb1a979ca3d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:49:30 +0500 Subject: [PATCH 12/54] Check backend is not None --- src/dstack/_internal/server/services/backends/handlers.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/backends/handlers.py b/src/dstack/_internal/server/services/backends/handlers.py index bcd4b857b0..77f8d9832f 100644 --- a/src/dstack/_internal/server/services/backends/handlers.py +++ b/src/dstack/_internal/server/services/backends/handlers.py @@ -55,7 +55,11 @@ async def _check_active_instances( ) for fleet_model in fleet_models: for instance in fleet_model.instances: - if instance.status.is_active() and instance.backend in backends_types: + if ( + instance.status.is_active() + and instance.backend is not None + and instance.backend in backends_types + ): if error: msg = ( f"Backend {instance.backend.value} has active instances." @@ -83,6 +87,7 @@ async def _check_active_volumes( if ( volume_model.status.is_active() and volume_model.provisioning_data is not None + and volume_model.provisioning_data.backend is not None and volume_model.provisioning_data.backend in backends_types ): if error: From 75c58450e0b6fd9b101526366a3aa6b66d45971a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:55:56 +0500 Subject: [PATCH 13/54] Assert service and jpd in register_service --- src/dstack/_internal/server/services/gateways/client.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index aa4b4823cf..f8c0900792 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -7,7 +7,7 @@ from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import GatewayError -from dstack._internal.core.models.configurations import RateLimit, ServiceConfiguration +from dstack._internal.core.models.configurations import RateLimit from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.runs import JobSpec, JobSubmission, Run, get_service_port from dstack._internal.proxy.gateway.schemas.stats import ServiceStats @@ -85,7 +85,7 @@ async def register_replica( ssh_head_proxy: Optional[SSHConnectionParams], ssh_head_proxy_private_key: Optional[str], ): - assert isinstance(run.run_spec.configuration, ServiceConfiguration) + assert run.run_spec.configuration.type == "service" payload = { "job_id": job_submission.id.hex, "app_port": get_service_port(job_spec, run.run_spec.configuration), @@ -93,6 +93,9 @@ async def register_replica( "ssh_head_proxy_private_key": ssh_head_proxy_private_key, } jpd = job_submission.job_provisioning_data + assert jpd is not None + assert jpd.hostname is not None + assert jpd.ssh_port is not None if not jpd.dockerized: payload.update( { From 20d0f7289cc977ddfc4a1aa645bddf047288f077 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:56:59 +0500 Subject: [PATCH 14/54] Assert jpd in container_ssh_tunnel --- src/dstack/_internal/server/services/ssh.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/ssh.py b/src/dstack/_internal/server/services/ssh.py index 2ab685eadb..a7967d8031 100644 --- a/src/dstack/_internal/server/services/ssh.py +++ b/src/dstack/_internal/server/services/ssh.py @@ -20,10 +20,11 @@ def container_ssh_tunnel( """ Build SSHTunnel for connecting to the container running the specified job. """ - jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw( job.job_provisioning_data ) + assert jpd.hostname is not None + assert jpd.ssh_port is not None if not jpd.dockerized: ssh_destination = f"{jpd.username}@{jpd.hostname}" ssh_port = jpd.ssh_port From 005c74497174164e00c4c40cafe6eca2cb6dfd20 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 18 Aug 2025 15:59:26 +0500 Subject: [PATCH 15/54] Assert jpd in ServerProxyRepo --- src/dstack/_internal/server/services/proxy/repo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 8e12a6daeb..23d155355e 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -74,6 +74,8 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw( job.job_provisioning_data ) + assert jpd.hostname is not None + assert jpd.ssh_port is not None if not jpd.dockerized: ssh_destination = f"{jpd.username}@{jpd.hostname}" ssh_port = jpd.ssh_port From c58d1d51e6deb2e53cad4e4be1a1c632adb064cd Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 11:45:24 +0500 Subject: [PATCH 16/54] Fix configuration.model type annotation --- src/dstack/_internal/core/models/configurations.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index ee0ec61b5f..eb9989329d 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -547,7 +547,7 @@ class ServiceConfigurationParams(CoreModel): ), ] = STRIP_PREFIX_DEFAULT model: Annotated[ - Optional[Union[AnyModel, str]], + Optional[AnyModel], Field( description=( "Mapping of the model for the OpenAI-compatible endpoint provided by `dstack`." @@ -578,6 +578,14 @@ class ServiceConfigurationParams(CoreModel): Field(description="List of probes used to determine job health"), ] = [] + class Config(CoreModel.Config): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["model"], + extra_types=[{"type": "string"}], + ) + @validator("port") def convert_port(cls, v) -> PortMapping: if isinstance(v, int): @@ -586,7 +594,7 @@ def convert_port(cls, v) -> PortMapping: return PortMapping.parse(v) return v - @validator("model") + @validator("model", pre=True) def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]: if isinstance(v, str): return OpenAIChatModel(type="chat", name=v, format="openai") From 05839ffdfad561ef180722736ce38cb453ba9251 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 11:50:55 +0500 Subject: [PATCH 17/54] Assert conf.replicas --- src/dstack/_internal/server/services/services/autoscalers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index 47eabaab31..cd6d06e588 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -120,6 +120,8 @@ def get_desired_count( def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler: + assert conf.replicas.min is not None + assert conf.replicas.max is not None if conf.scaling is None: return ManualScaler( min_replicas=conf.replicas.min, From 6be02a5c539c50ffe976c2aec5ee4c8fe653227d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 12:17:03 +0500 Subject: [PATCH 18/54] Fix abstract AsyncGenerator def --- src/dstack/_internal/proxy/lib/deps.py | 8 ++++++-- src/dstack/_internal/server/services/proxy/repo.py | 2 ++ src/dstack/_internal/server/services/runs.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/proxy/lib/deps.py b/src/dstack/_internal/proxy/lib/deps.py index ae10be7abe..21528899ce 100644 --- a/src/dstack/_internal/proxy/lib/deps.py +++ b/src/dstack/_internal/proxy/lib/deps.py @@ -21,12 +21,16 @@ class ProxyDependencyInjector(ABC): def __init__(self) -> None: self._service_conn_pool = ServiceConnectionPool() + # Abstract AsyncGenerator does not need async def since + # type checkers infer a different type without yield in body. + # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators + @abstractmethod - async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]: + def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]: pass @abstractmethod - async def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]: + def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]: pass async def get_service_connection_pool(self) -> ServiceConnectionPool: diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 23d155355e..3a8204377a 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -177,6 +177,8 @@ def _model_options_to_format_spec(model: AnyModel) -> AnyModelFormat: if model.format == "openai": return OpenAIChatModelFormat(prefix=model.prefix) elif model.format == "tgi": + assert model.chat_template is not None + assert model.eos_token is not None return TGIChatModelFormat( chat_template=model.chat_template, eos_token=model.eos_token, diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 81d34a2ae3..1e4a2d2ad6 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -529,7 +529,7 @@ async def submit_run( initial_status = RunStatus.PENDING initial_replicas = 0 elif run_spec.configuration.type == "service": - initial_replicas = run_spec.configuration.replicas.min + initial_replicas = run_spec.configuration.replicas.min or 0 run_model = RunModel( id=uuid.uuid4(), From 0f3336d8c83f18f6dfcccb4aede71cbbcbcb41d1 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 12:20:34 +0500 Subject: [PATCH 19/54] Fix ProbeConfig type annotations --- .../_internal/core/models/configurations.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index eb9989329d..7a75074518 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -221,7 +221,7 @@ class ProbeConfig(CoreModel): ), ] = None timeout: Annotated[ - Optional[Union[int, str]], + Optional[int], Field( description=( f"Maximum amount of time the HTTP request is allowed to take. Defaults to `{DEFAULT_PROBE_TIMEOUT}s`" @@ -229,7 +229,7 @@ class ProbeConfig(CoreModel): ), ] = None interval: Annotated[ - Optional[Union[int, str]], + Optional[int], Field( description=( "Minimum amount of time between the end of one probe execution" @@ -249,7 +249,19 @@ class ProbeConfig(CoreModel): ), ] = None - @validator("timeout") + class Config(CoreModel.Config): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["timeout"], + extra_types=[{"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["interval"], + extra_types=[{"type": "string"}], + ) + + @validator("timeout", pre=True) def parse_timeout(cls, v: Optional[Union[int, str]]) -> Optional[int]: if v is None: return v @@ -258,7 +270,7 @@ def parse_timeout(cls, v: Optional[Union[int, str]]) -> Optional[int]: raise ValueError(f"Probe timeout cannot be shorter than {MIN_PROBE_TIMEOUT}s") return parsed - @validator("interval") + @validator("interval", pre=True) def parse_interval(cls, v: Optional[Union[int, str]]) -> Optional[int]: if v is None: return v From 1fe4f27e91731aca3a6ce6482c9fddbd3318c6e7 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 14:06:50 +0500 Subject: [PATCH 20/54] Fix max_duration and stop_duration type annotations --- .../_internal/core/models/configurations.py | 22 ++++++++++++++----- src/dstack/_internal/core/models/profiles.py | 22 ++++++++++++++----- .../services/jobs/configurators/base.py | 10 ++++----- src/dstack/api/_public/runs.py | 6 ++--- 4 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 7a75074518..d111bff370 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -519,6 +519,11 @@ class DevEnvironmentConfiguration( ): type: Literal["dev-environment"] = "dev-environment" + class Config(ProfileParams.Config): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + ProfileParams.Config.schema_extra(schema) + class TaskConfigurationParams(CoreModel): nodes: Annotated[int, Field(description="Number of nodes", ge=1)] = 1 @@ -532,6 +537,11 @@ class TaskConfiguration( ): type: Literal["task"] = "task" + class Config(ProfileParams.Config): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + ProfileParams.Config.schema_extra(schema) + class ServiceConfigurationParams(CoreModel): port: Annotated[ @@ -593,6 +603,10 @@ class ServiceConfigurationParams(CoreModel): class Config(CoreModel.Config): @staticmethod def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["replicas"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) add_extra_schema_types( schema["properties"]["model"], extra_types=[{"type": "string"}], @@ -669,13 +683,11 @@ class ServiceConfiguration( ): type: Literal["service"] = "service" - class Config(CoreModel.Config): + class Config(ProfileParams.Config, ServiceConfigurationParams.Config): @staticmethod def schema_extra(schema: Dict[str, Any]): - add_extra_schema_types( - schema["properties"]["replicas"], - extra_types=[{"type": "integer"}, {"type": "string"}], - ) + ProfileParams.Config.schema_extra(schema) + ServiceConfigurationParams.Config.schema_extra(schema) AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration] diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 2b243afba1..10a7d29a7a 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -62,15 +62,17 @@ def parse_duration(v: Optional[Union[int, str]]) -> Optional[int]: return Duration.parse(v) -def parse_max_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str, int]]: +def parse_max_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]: return parse_off_duration(v) -def parse_stop_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str, int]]: +def parse_stop_duration( + v: Optional[Union[int, str, bool]], +) -> Optional[Union[Literal["off"], int]]: return parse_off_duration(v) -def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str, int]]: +def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[Literal["off"], int]]: if v == "off" or v is False: return "off" if v is True: @@ -78,7 +80,7 @@ def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str return parse_duration(v) -def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[Union[str, int]]: +def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[int]: if v == "off" or v == -1: return -1 return parse_duration(v) @@ -256,7 +258,7 @@ class ProfileParams(CoreModel): Field(description="The policy for resubmitting the run. Defaults to `false`"), ] = None max_duration: Annotated[ - Optional[Union[Literal["off"], str, int, bool]], + Optional[Union[Literal["off"], int]], Field( description=( "The maximum duration of a run (e.g., `2h`, `1d`, etc)." @@ -266,7 +268,7 @@ class ProfileParams(CoreModel): ), ] = None stop_duration: Annotated[ - Optional[Union[Literal["off"], str, int, bool]], + Optional[Union[Literal["off"], int]], Field( description=( "The maximum duration of a run graceful stopping." @@ -356,6 +358,14 @@ def schema_extra(schema: Dict[str, Any]) -> None: del schema["properties"]["retry_policy"] del schema["properties"]["termination_policy"] del schema["properties"]["termination_idle_time"] + add_extra_schema_types( + schema["properties"]["max_duration"], + extra_types=[{"type": "boolean"}, {"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["stop_duration"], + extra_types=[{"type": "boolean"}, {"type": "string"}], + ) _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 1a67ad3cf7..a787f7220f 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -258,19 +258,17 @@ def _single_branch(self) -> bool: return self.run_spec.configuration.single_branch def _max_duration(self) -> Optional[int]: - if self.run_spec.merged_profile.max_duration in [None, True]: + if self.run_spec.merged_profile.max_duration is None: return self._default_max_duration() - if self.run_spec.merged_profile.max_duration in ["off", False]: + if self.run_spec.merged_profile.max_duration == "off": return None - # pydantic validator ensures this is int return self.run_spec.merged_profile.max_duration def _stop_duration(self) -> Optional[int]: - if self.run_spec.merged_profile.stop_duration in [None, True]: + if self.run_spec.merged_profile.stop_duration is None: return DEFAULT_STOP_DURATION - if self.run_spec.merged_profile.stop_duration in ["off", False]: + if self.run_spec.merged_profile.stop_duration == "off": return None - # pydantic validator ensures this is int return self.run_spec.merged_profile.stop_duration def _utilization_policy(self) -> Optional[UtilizationPolicy]: diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index e1992068d0..d55b311f46 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -436,7 +436,7 @@ def get_run_plan( ) -> RunPlan: """ Get a run plan. - Use this method to see the run plan before applying the cofiguration. + Use this method to see the run plan before applying the configuration. Args: configuration (Union[Task, Service, DevEnvironment]): The run configuration. @@ -691,8 +691,8 @@ def get_plan( spot_policy=spot_policy, retry=None, utilization_policy=utilization_policy, - max_duration=max_duration, - stop_duration=stop_duration, + max_duration=max_duration, # type: ignore[assignment] + stop_duration=stop_duration, # type: ignore[assignment] max_price=max_price, creation_policy=creation_policy, idle_duration=idle_duration, From fa7d7239e024dafa50b6cf67e0fa8b2bbe315e62 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 14:12:28 +0500 Subject: [PATCH 21/54] Fix idle_duration type annotations --- src/dstack/_internal/cli/services/profile.py | 2 +- src/dstack/_internal/core/models/fleets.py | 6 +++++- src/dstack/_internal/core/models/profiles.py | 6 +++++- src/dstack/_internal/core/services/profiles.py | 4 ++-- src/dstack/api/_public/runs.py | 2 +- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/dstack/_internal/cli/services/profile.py b/src/dstack/_internal/cli/services/profile.py index d57ea2e130..6340719bd2 100644 --- a/src/dstack/_internal/cli/services/profile.py +++ b/src/dstack/_internal/cli/services/profile.py @@ -159,7 +159,7 @@ def apply_profile_args( if args.idle_duration is not None: profile_settings.idle_duration = args.idle_duration elif args.dont_destroy: - profile_settings.idle_duration = "off" + profile_settings.idle_duration = -1 if args.creation_policy_reuse: profile_settings.creation_policy = CreationPolicy.REUSE diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py index 8aaf0d18ee..357f9b5b0c 100644 --- a/src/dstack/_internal/core/models/fleets.py +++ b/src/dstack/_internal/core/models/fleets.py @@ -224,7 +224,7 @@ class InstanceGroupParams(CoreModel): Field(description="The maximum instance price per hour, in dollars", gt=0.0), ] = None idle_duration: Annotated[ - Optional[Union[Literal["off"], str, int]], + Optional[int], Field( description="Time to wait before terminating idle instances. Defaults to `5m` for runs and `3d` for fleets. Use `off` for unlimited duration" ), @@ -243,6 +243,10 @@ def schema_extra(schema: Dict[str, Any], model: Type): schema["properties"]["nodes"], extra_types=[{"type": "integer"}, {"type": "string"}], ) + add_extra_schema_types( + schema["properties"]["idle_duration"], + extra_types=[{"type": "string"}], + ) _validate_idle_duration = validator("idle_duration", pre=True, allow_reuse=True)( parse_idle_duration diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 10a7d29a7a..30aa94fd03 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -293,7 +293,7 @@ class ProfileParams(CoreModel): ), ] = None idle_duration: Annotated[ - Optional[Union[Literal["off"], str, int]], + Optional[int], Field( description=( "Time to wait before terminating idle instances." @@ -366,6 +366,10 @@ def schema_extra(schema: Dict[str, Any]) -> None: schema["properties"]["stop_duration"], extra_types=[{"type": "boolean"}, {"type": "string"}], ) + add_extra_schema_types( + schema["properties"]["idle_duration"], + extra_types=[{"type": "string"}], + ) _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration diff --git a/src/dstack/_internal/core/services/profiles.py b/src/dstack/_internal/core/services/profiles.py index cd268aeac0..71ed2e520e 100644 --- a/src/dstack/_internal/core/services/profiles.py +++ b/src/dstack/_internal/core/services/profiles.py @@ -37,10 +37,10 @@ def get_termination( ) -> Tuple[TerminationPolicy, int]: termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE termination_idle_time = default_termination_idle_time - if profile.idle_duration is not None and int(profile.idle_duration) < 0: + if profile.idle_duration is not None and profile.idle_duration < 0: termination_policy = TerminationPolicy.DONT_DESTROY elif profile.idle_duration is not None: termination_idle_time = profile.idle_duration if termination_policy == TerminationPolicy.DONT_DESTROY: termination_idle_time = -1 - return termination_policy, int(termination_idle_time) + return termination_policy, termination_idle_time diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index d55b311f46..6470e0b392 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -695,7 +695,7 @@ def get_plan( stop_duration=stop_duration, # type: ignore[assignment] max_price=max_price, creation_policy=creation_policy, - idle_duration=idle_duration, + idle_duration=idle_duration, # type: ignore[assignment] ) run_spec = RunSpec( run_name=run_name, From 895b36c26cf6986f47c3dce35ccd8fedae08bf5d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 14:37:19 +0500 Subject: [PATCH 22/54] Fix volumes and files type annotations --- .../_internal/core/models/configurations.py | 39 +++++++++++++------ .../services/jobs/configurators/base.py | 7 +--- src/dstack/api/_public/runs.py | 1 - 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index d111bff370..a721dfbfae 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -385,9 +385,7 @@ class BaseRunConfiguration(CoreModel): ), ), ] = None - volumes: Annotated[ - List[Union[MountPoint, str]], Field(description="The volumes mount points") - ] = [] + volumes: Annotated[List[MountPoint], Field(description="The volumes mount points")] = [] docker: Annotated[ Optional[bool], Field( @@ -395,12 +393,24 @@ class BaseRunConfiguration(CoreModel): ), ] = None files: Annotated[ - list[Union[FilePathMapping, str]], + list[FilePathMapping], Field(description="The local to container file path mappings"), ] = [] # deprecated since 0.18.31; task, service -- no effect; dev-environment -- executed right before `init` setup: CommandsList = [] + class Config(CoreModel.Config): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["volumes"]["items"], + extra_types=[{"type": "string"}], + ) + add_extra_schema_types( + schema["properties"]["files"]["items"], + extra_types=[{"type": "string"}], + ) + @validator("python", pre=True, always=True) def convert_python(cls, v, values) -> Optional[PythonVersion]: if v is not None and values.get("image"): @@ -425,14 +435,14 @@ def _docker(cls, v, values) -> Optional[bool]: # but it's not possible to do so without breaking backwards compatibility. return v - @validator("volumes", each_item=True) - def convert_volumes(cls, v) -> MountPoint: + @validator("volumes", each_item=True, pre=True) + def convert_volumes(cls, v: Union[MountPoint, str]) -> MountPoint: if isinstance(v, str): return parse_mount_point(v) return v - @validator("files", each_item=True) - def convert_files(cls, v) -> FilePathMapping: + @validator("files", each_item=True, pre=True) + def convert_files(cls, v: Union[FilePathMapping, str]) -> FilePathMapping: if isinstance(v, str): return FilePathMapping.parse(v) return v @@ -519,10 +529,11 @@ class DevEnvironmentConfiguration( ): type: Literal["dev-environment"] = "dev-environment" - class Config(ProfileParams.Config): + class Config(ProfileParams.Config, BaseRunConfigurationWithPorts.Config): @staticmethod def schema_extra(schema: Dict[str, Any]): ProfileParams.Config.schema_extra(schema) + BaseRunConfigurationWithPorts.Config.schema_extra(schema) class TaskConfigurationParams(CoreModel): @@ -537,10 +548,11 @@ class TaskConfiguration( ): type: Literal["task"] = "task" - class Config(ProfileParams.Config): + class Config(ProfileParams.Config, BaseRunConfiguration.Config): @staticmethod def schema_extra(schema: Dict[str, Any]): ProfileParams.Config.schema_extra(schema) + BaseRunConfiguration.Config.schema_extra(schema) class ServiceConfigurationParams(CoreModel): @@ -683,10 +695,15 @@ class ServiceConfiguration( ): type: Literal["service"] = "service" - class Config(ProfileParams.Config, ServiceConfigurationParams.Config): + class Config( + ProfileParams.Config, + BaseRunConfigurationWithCommands.Config, + ServiceConfigurationParams.Config, + ): @staticmethod def schema_extra(schema: Dict[str, Any]): ProfileParams.Config.schema_extra(schema) + BaseRunConfigurationWithCommands.Config.schema_extra(schema) ServiceConfigurationParams.Config.schema_extra(schema) diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index a787f7220f..c5f90b54ec 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -3,7 +3,7 @@ import threading from abc import ABC, abstractmethod from pathlib import PurePosixPath -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from cachetools import TTLCache, cached @@ -326,7 +326,7 @@ def _probes(self) -> list[ProbeSpec]: def interpolate_job_volumes( - run_volumes: List[Union[MountPoint, str]], + run_volumes: List[MountPoint], job_num: int, ) -> List[MountPoint]: if len(run_volumes) == 0: @@ -341,9 +341,6 @@ def interpolate_job_volumes( ) job_volumes = [] for mount_point in run_volumes: - if isinstance(mount_point, str): - # pydantic validator ensures strings are converted to MountPoint - continue if not isinstance(mount_point, VolumeMountPoint): job_volumes.append(mount_point.copy()) continue diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 6470e0b392..473c462139 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -812,7 +812,6 @@ def _validate_configuration_files( if configuration_path is not None: base_dir = Path(configuration_path).expanduser().resolve().parent for file_mapping in configuration.files: - assert isinstance(file_mapping, FilePathMapping) path = Path(file_mapping.local_path).expanduser() if not path.is_absolute(): if base_dir is None: From 95a4b999188de8010b30034a6a6940350d692a49 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 14:43:30 +0500 Subject: [PATCH 23/54] Fix retry.duration type annotation --- src/dstack/_internal/core/models/profiles.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 30aa94fd03..79da7e41cb 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -124,10 +124,18 @@ class ProfileRetry(CoreModel): ), ] = None duration: Annotated[ - Optional[Union[int, str]], + Optional[int], Field(description="The maximum period of retrying the run, e.g., `4h` or `1d`"), ] = None + class Config(CoreModel.Config): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["duration"], + extra_types=[{"type": "string"}], + ) + _validate_duration = validator("duration", pre=True, allow_reuse=True)(parse_duration) @root_validator From 407ef862bd257da985e96786d8b73d57b049fccc Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 15:18:40 +0500 Subject: [PATCH 24/54] Do not define Storage implementations when deps missing --- .../server/services/storage/__init__.py | 11 +- .../_internal/server/services/storage/gcs.py | 98 ++++++++--------- .../_internal/server/services/storage/s3.py | 104 +++++++++--------- 3 files changed, 106 insertions(+), 107 deletions(-) diff --git a/src/dstack/_internal/server/services/storage/__init__.py b/src/dstack/_internal/server/services/storage/__init__.py index 14b75c3477..d76a5d4bee 100644 --- a/src/dstack/_internal/server/services/storage/__init__.py +++ b/src/dstack/_internal/server/services/storage/__init__.py @@ -1,9 +1,8 @@ from typing import Optional from dstack._internal.server import settings +from dstack._internal.server.services.storage import gcs, s3 from dstack._internal.server.services.storage.base import BaseStorage -from dstack._internal.server.services.storage.gcs import GCS_AVAILABLE, GCSStorage -from dstack._internal.server.services.storage.s3 import BOTO_AVAILABLE, S3Storage _default_storage = None @@ -20,16 +19,16 @@ def init_default_storage(): ) if settings.SERVER_S3_BUCKET: - if not BOTO_AVAILABLE: + if not s3.BOTO_AVAILABLE: raise ValueError("AWS dependencies are not installed") - _default_storage = S3Storage( + _default_storage = s3.S3Storage( bucket=settings.SERVER_S3_BUCKET, region=settings.SERVER_S3_BUCKET_REGION, ) elif settings.SERVER_GCS_BUCKET: - if not GCS_AVAILABLE: + if not gcs.GCS_AVAILABLE: raise ValueError("GCS dependencies are not installed") - _default_storage = GCSStorage( + _default_storage = gcs.GCSStorage( bucket=settings.SERVER_GCS_BUCKET, ) diff --git a/src/dstack/_internal/server/services/storage/gcs.py b/src/dstack/_internal/server/services/storage/gcs.py index 6c565625e2..a0f9ac568f 100644 --- a/src/dstack/_internal/server/services/storage/gcs.py +++ b/src/dstack/_internal/server/services/storage/gcs.py @@ -8,59 +8,59 @@ from google.cloud.exceptions import NotFound except ImportError: GCS_AVAILABLE = False +else: + class GCSStorage(BaseStorage): + def __init__( + self, + bucket: str, + ): + self._client = storage.Client() + self._bucket = self._client.bucket(bucket) -class GCSStorage(BaseStorage): - def __init__( - self, - bucket: str, - ): - self._client = storage.Client() - self._bucket = self._client.bucket(bucket) + def upload_code( + self, + project_id: str, + repo_id: str, + code_hash: str, + blob: bytes, + ): + key = self._get_code_key(project_id, repo_id, code_hash) + self._upload(key, blob) - def upload_code( - self, - project_id: str, - repo_id: str, - code_hash: str, - blob: bytes, - ): - key = self._get_code_key(project_id, repo_id, code_hash) - self._upload(key, blob) + def get_code( + self, + project_id: str, + repo_id: str, + code_hash: str, + ) -> Optional[bytes]: + key = self._get_code_key(project_id, repo_id, code_hash) + return self._get(key) - def get_code( - self, - project_id: str, - repo_id: str, - code_hash: str, - ) -> Optional[bytes]: - key = self._get_code_key(project_id, repo_id, code_hash) - return self._get(key) + def upload_archive( + self, + user_id: str, + archive_hash: str, + blob: bytes, + ): + key = self._get_archive_key(user_id, archive_hash) + self._upload(key, blob) - def upload_archive( - self, - user_id: str, - archive_hash: str, - blob: bytes, - ): - key = self._get_archive_key(user_id, archive_hash) - self._upload(key, blob) + def get_archive( + self, + user_id: str, + archive_hash: str, + ) -> Optional[bytes]: + key = self._get_archive_key(user_id, archive_hash) + return self._get(key) - def get_archive( - self, - user_id: str, - archive_hash: str, - ) -> Optional[bytes]: - key = self._get_archive_key(user_id, archive_hash) - return self._get(key) + def _upload(self, key: str, blob: bytes): + blob_obj = self._bucket.blob(key) + blob_obj.upload_from_string(blob) - def _upload(self, key: str, blob: bytes): - blob_obj = self._bucket.blob(key) - blob_obj.upload_from_string(blob) - - def _get(self, key: str) -> Optional[bytes]: - try: - blob = self._bucket.blob(key) - except NotFound: - return None - return blob.download_as_bytes() + def _get(self, key: str) -> Optional[bytes]: + try: + blob = self._bucket.blob(key) + except NotFound: + return None + return blob.download_as_bytes() diff --git a/src/dstack/_internal/server/services/storage/s3.py b/src/dstack/_internal/server/services/storage/s3.py index a0b993c731..df4b652d1d 100644 --- a/src/dstack/_internal/server/services/storage/s3.py +++ b/src/dstack/_internal/server/services/storage/s3.py @@ -8,62 +8,62 @@ from boto3 import Session except ImportError: BOTO_AVAILABLE = False +else: + class S3Storage(BaseStorage): + def __init__( + self, + bucket: str, + region: Optional[str] = None, + ): + self._session = Session() + self._client = self._session.client("s3", region_name=region) + self.bucket = bucket -class S3Storage(BaseStorage): - def __init__( - self, - bucket: str, - region: Optional[str] = None, - ): - self._session = Session() - self._client = self._session.client("s3", region_name=region) - self.bucket = bucket + def upload_code( + self, + project_id: str, + repo_id: str, + code_hash: str, + blob: bytes, + ): + key = self._get_code_key(project_id, repo_id, code_hash) + self._upload(key, blob) - def upload_code( - self, - project_id: str, - repo_id: str, - code_hash: str, - blob: bytes, - ): - key = self._get_code_key(project_id, repo_id, code_hash) - self._upload(key, blob) + def get_code( + self, + project_id: str, + repo_id: str, + code_hash: str, + ) -> Optional[bytes]: + key = self._get_code_key(project_id, repo_id, code_hash) + return self._get(key) - def get_code( - self, - project_id: str, - repo_id: str, - code_hash: str, - ) -> Optional[bytes]: - key = self._get_code_key(project_id, repo_id, code_hash) - return self._get(key) + def upload_archive( + self, + user_id: str, + archive_hash: str, + blob: bytes, + ): + key = self._get_archive_key(user_id, archive_hash) + self._upload(key, blob) - def upload_archive( - self, - user_id: str, - archive_hash: str, - blob: bytes, - ): - key = self._get_archive_key(user_id, archive_hash) - self._upload(key, blob) + def get_archive( + self, + user_id: str, + archive_hash: str, + ) -> Optional[bytes]: + key = self._get_archive_key(user_id, archive_hash) + return self._get(key) - def get_archive( - self, - user_id: str, - archive_hash: str, - ) -> Optional[bytes]: - key = self._get_archive_key(user_id, archive_hash) - return self._get(key) + def _upload(self, key: str, blob: bytes): + self._client.put_object(Bucket=self.bucket, Key=key, Body=blob) - def _upload(self, key: str, blob: bytes): - self._client.put_object(Bucket=self.bucket, Key=key, Body=blob) - - def _get(self, key: str) -> Optional[bytes]: - try: - response = self._client.get_object(Bucket=self.bucket, Key=key) - except botocore.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "NoSuchKey": - return None - raise e - return response["Body"].read() + def _get(self, key: str) -> Optional[bytes]: + try: + response = self._client.get_object(Bucket=self.bucket, Key=key) + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + return None + raise e + return response["Body"].read() From 690e96798e96b73ec99014481c7ed6c97e105403 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 15:26:37 +0500 Subject: [PATCH 25/54] Fix gateway domain None --- src/dstack/_internal/core/models/runs.py | 4 ++++ src/dstack/_internal/server/services/services/__init__.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 87a274a0c3..75f3b6b829 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta from enum import Enum from typing import Any, Dict, List, Literal, Optional, Type +from urllib.parse import urlparse from pydantic import UUID4, Field, root_validator from typing_extensions import Annotated @@ -483,6 +484,9 @@ class ServiceSpec(CoreModel): model: Optional[ServiceModelSpec] = None options: Dict[str, Any] = {} + def get_domain(self) -> Optional[str]: + return urlparse(self.url).hostname + class RunStatus(str, Enum): PENDING = "pending" diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 68fdc509cf..a8089a93a9 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -5,7 +5,6 @@ import uuid from datetime import datetime from typing import Optional -from urllib.parse import urlparse import httpx from sqlalchemy import select @@ -102,6 +101,9 @@ async def _register_service_in_gateway( model_url=f"{gateway_protocol}://gateway.{wildcard_domain}", ) + domain = service_spec.get_domain() + assert domain is not None + conn = await get_or_add_gateway_connection(session, gateway.id) try: logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) @@ -109,7 +111,7 @@ async def _register_service_in_gateway( await client.register_service( project=run_model.project.name, run_name=run_model.run_name, - domain=urlparse(service_spec.url).hostname, + domain=domain, service_https=service_https, gateway_https=gateway_https, auth=run_spec.configuration.auth, From b5d3e6c10dfc787892706c6c7cad9ff26e56aecb Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 15:40:02 +0500 Subject: [PATCH 26/54] Overload get_backend_config --- .../core/backends/base/configurator.py | 15 ++++++++++++- .../server/services/backends/__init__.py | 22 ++++++++++++------- .../_internal/server/services/projects.py | 6 ++--- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/dstack/_internal/core/backends/base/configurator.py b/src/dstack/_internal/core/backends/base/configurator.py index 994266c438..f31e978a31 100644 --- a/src/dstack/_internal/core/backends/base/configurator.py +++ b/src/dstack/_internal/core/backends/base/configurator.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod -from typing import Any, ClassVar, List, Optional +from typing import Any, ClassVar, List, Literal, Optional, overload from uuid import UUID from dstack._internal.core.backends.base.backend import Backend from dstack._internal.core.backends.models import ( AnyBackendConfig, AnyBackendConfigWithCreds, + AnyBackendConfigWithoutCreds, ) from dstack._internal.core.errors import BackendInvalidCredentialsError from dstack._internal.core.models.backends.base import BackendType @@ -77,6 +78,18 @@ def create_backend( """ pass + @overload + def get_backend_config( + self, record: StoredBackendRecord, include_creds: Literal[False] + ) -> AnyBackendConfigWithoutCreds: + pass + + @overload + def get_backend_config( + self, record: StoredBackendRecord, include_creds: Literal[True] + ) -> AnyBackendConfigWithCreds: + pass + @abstractmethod def get_backend_config( self, record: StoredBackendRecord, include_creds: bool diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 16d8ae7821..9711a503bf 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -17,8 +17,8 @@ ) from dstack._internal.core.backends.local.backend import LocalBackend from dstack._internal.core.backends.models import ( - AnyBackendConfig, AnyBackendConfigWithCreds, + AnyBackendConfigWithoutCreds, ) from dstack._internal.core.errors import ( BackendError, @@ -126,19 +126,25 @@ async def get_backend_config( ) continue if backend_model.type == backend_type: - return get_backend_config_from_backend_model( - configurator, backend_model, include_creds=True - ) + return get_backend_config_with_creds_from_backend_model(configurator, backend_model) return None -def get_backend_config_from_backend_model( +def get_backend_config_with_creds_from_backend_model( + configurator: Configurator, + backend_model: BackendModel, +) -> AnyBackendConfigWithCreds: + backend_record = get_stored_backend_record(backend_model) + backend_config = configurator.get_backend_config(backend_record, include_creds=True) + return backend_config + + +def get_backend_config_without_creds_from_backend_model( configurator: Configurator, backend_model: BackendModel, - include_creds: bool, -) -> AnyBackendConfig: +) -> AnyBackendConfigWithoutCreds: backend_record = get_stored_backend_record(backend_model) - backend_config = configurator.get_backend_config(backend_record, include_creds=include_creds) + backend_config = configurator.get_backend_config(backend_record, include_creds=False) return backend_config diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index b901bc8e21..2ec37523e4 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -19,7 +19,7 @@ from dstack._internal.server.schemas.projects import MemberSetting from dstack._internal.server.services import users from dstack._internal.server.services.backends import ( - get_backend_config_from_backend_model, + get_backend_config_without_creds_from_backend_model, ) from dstack._internal.server.services.permissions import get_default_permissions from dstack._internal.server.settings import DEFAULT_PROJECT_NAME @@ -543,9 +543,7 @@ def project_model_to_project( b.type.value, ) continue - backend_config = get_backend_config_from_backend_model( - configurator, b, include_creds=False - ) + backend_config = get_backend_config_without_creds_from_backend_model(configurator, b) if isinstance(backend_config, DstackBackendConfig): for backend_type in backend_config.base_backends: backends.append( From edba8a65f9f299cc225fd1a3816eecdc43520f2f Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 15:59:41 +0500 Subject: [PATCH 27/54] Do not define LogStorage implementations when deps missing --- .../server/services/logs/__init__.py | 14 +- .../_internal/server/services/logs/aws.py | 656 +++++++++--------- .../_internal/server/services/logs/gcp.py | 280 ++++---- 3 files changed, 478 insertions(+), 472 deletions(-) diff --git a/src/dstack/_internal/server/services/logs/__init__.py b/src/dstack/_internal/server/services/logs/__init__.py index b38264980d..5b06ff4ad2 100644 --- a/src/dstack/_internal/server/services/logs/__init__.py +++ b/src/dstack/_internal/server/services/logs/__init__.py @@ -7,14 +7,14 @@ from dstack._internal.server.models import ProjectModel from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent -from dstack._internal.server.services.logs.aws import BOTO_AVAILABLE, CloudWatchLogStorage +from dstack._internal.server.services.logs import aws as aws_logs +from dstack._internal.server.services.logs import gcp as gcp_logs from dstack._internal.server.services.logs.base import ( LogStorage, LogStorageError, b64encode_raw_message, ) from dstack._internal.server.services.logs.filelog import FileLogStorage -from dstack._internal.server.services.logs.gcp import GCP_LOGGING_AVAILABLE, GCPLogStorage from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger @@ -29,9 +29,9 @@ def get_log_storage() -> LogStorage: if _log_storage is not None: return _log_storage if settings.SERVER_CLOUDWATCH_LOG_GROUP: - if BOTO_AVAILABLE: + if aws_logs.BOTO_AVAILABLE: try: - _log_storage = CloudWatchLogStorage( + _log_storage = aws_logs.CloudWatchLogStorage( group=settings.SERVER_CLOUDWATCH_LOG_GROUP, region=settings.SERVER_CLOUDWATCH_LOG_REGION, ) @@ -44,9 +44,11 @@ def get_log_storage() -> LogStorage: else: logger.error("Cannot use CloudWatch Logs storage: boto3 is not installed") elif settings.SERVER_GCP_LOGGING_PROJECT: - if GCP_LOGGING_AVAILABLE: + if gcp_logs.GCP_LOGGING_AVAILABLE: try: - _log_storage = GCPLogStorage(project_id=settings.SERVER_GCP_LOGGING_PROJECT) + _log_storage = gcp_logs.GCPLogStorage( + project_id=settings.SERVER_GCP_LOGGING_PROJECT + ) except LogStorageError as e: logger.error("Failed to initialize GCP Logs storage: %s", e) except Exception: diff --git a/src/dstack/_internal/server/services/logs/aws.py b/src/dstack/_internal/server/services/logs/aws.py index 692ae1348e..6f03a7fdca 100644 --- a/src/dstack/_internal/server/services/logs/aws.py +++ b/src/dstack/_internal/server/services/logs/aws.py @@ -24,347 +24,349 @@ ) from dstack._internal.utils.logging import get_logger +logger = get_logger(__name__) + + BOTO_AVAILABLE = True try: import boto3 import botocore.exceptions except ImportError: BOTO_AVAILABLE = False - -logger = get_logger(__name__) - - -class _CloudWatchLogEvent(TypedDict): - timestamp: int # unix time in milliseconds - message: str - - -class CloudWatchLogStorage(LogStorage): - # "The maximum number of log events in a batch is 10,000". - EVENT_MAX_COUNT_IN_BATCH = 10000 - # "The maximum batch size is 1,048,576 bytes" — exactly 1 MiB. "This size is calculated - # as the sum of all event messages in UTF-8, plus 26 bytes for each log event". - BATCH_MAX_SIZE = 1048576 - # "Each log event can be no larger than 256 KB" — KB means KiB; includes MESSAGE_OVERHEAD_SIZE. - MESSAGE_MAX_SIZE = 262144 - # Message size in bytes = len(message.encode("utf-8")) + MESSAGE_OVERHEAD_SIZE. - MESSAGE_OVERHEAD_SIZE = 26 - # "A batch of log events in a single request cannot span more than 24 hours". - BATCH_MAX_SPAN = int(timedelta(hours=24).total_seconds()) * 1000 - # Decrease allowed deltas by possible clock drift between dstack and CloudWatch. - CLOCK_DRIFT = int(timedelta(minutes=10).total_seconds()) * 1000 - # "None of the log events in the batch can be more than 14 days in the past." - PAST_EVENT_MAX_DELTA = int((timedelta(days=14)).total_seconds()) * 1000 - CLOCK_DRIFT - # "None of the log events in the batch can be more than 2 hours in the future." - FUTURE_EVENT_MAX_DELTA = int((timedelta(hours=2)).total_seconds()) * 1000 - CLOCK_DRIFT - # Maximum number of retries when polling for log events to skip empty pages. - MAX_RETRIES = 10 - - def __init__(self, *, group: str, region: Optional[str] = None) -> None: - with self._wrap_boto_errors(): - session = boto3.Session(region_name=region) - self._client = session.client("logs") - self._check_group_exists(group) - self._group = group - self._region = self._client.meta.region_name - # Stores names of already created streams. - # XXX: This set acts as an unbound cache. If this becomes a problem (in case of _very_ long - # running server and/or lots of jobs, consider replacing it with an LRU cache, e.g., - # a simple OrderedDict-based implementation should be OK. - self._streams: Set[str] = set() - - def close(self) -> None: - self._client.close() - - def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs: - log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB - stream = self._get_stream_name( - project.name, request.run_name, request.job_submission_id, log_producer - ) - cw_events: List[_CloudWatchLogEvent] - next_token: Optional[str] = None - with self._wrap_boto_errors(): - try: - cw_events, next_token = self._get_log_events_with_retry(stream, request) - except botocore.exceptions.ClientError as e: - if not self._is_resource_not_found_exception(e): - raise - # Check if the group exists to distinguish between group not found vs stream not found - try: - self._check_group_exists(self._group) - # Group exists, so the error must be due to missing stream - logger.debug("Stream %s not found, returning dummy response", stream) - cw_events = [] - except LogStorageError: - # Group doesn't exist, re-raise the LogStorageError - raise - logs = [ - LogEvent( - timestamp=unix_time_ms_to_datetime(cw_event["timestamp"]), - log_source=LogEventSource.STDOUT, - message=cw_event["message"], +else: + + class _CloudWatchLogEvent(TypedDict): + timestamp: int # unix time in milliseconds + message: str + + class CloudWatchLogStorage(LogStorage): + # "The maximum number of log events in a batch is 10,000". + EVENT_MAX_COUNT_IN_BATCH = 10000 + # "The maximum batch size is 1,048,576 bytes" — exactly 1 MiB. "This size is calculated + # as the sum of all event messages in UTF-8, plus 26 bytes for each log event". + BATCH_MAX_SIZE = 1048576 + # "Each log event can be no larger than 256 KB" — KB means KiB; includes MESSAGE_OVERHEAD_SIZE. + MESSAGE_MAX_SIZE = 262144 + # Message size in bytes = len(message.encode("utf-8")) + MESSAGE_OVERHEAD_SIZE. + MESSAGE_OVERHEAD_SIZE = 26 + # "A batch of log events in a single request cannot span more than 24 hours". + BATCH_MAX_SPAN = int(timedelta(hours=24).total_seconds()) * 1000 + # Decrease allowed deltas by possible clock drift between dstack and CloudWatch. + CLOCK_DRIFT = int(timedelta(minutes=10).total_seconds()) * 1000 + # "None of the log events in the batch can be more than 14 days in the past." + PAST_EVENT_MAX_DELTA = int((timedelta(days=14)).total_seconds()) * 1000 - CLOCK_DRIFT + # "None of the log events in the batch can be more than 2 hours in the future." + FUTURE_EVENT_MAX_DELTA = int((timedelta(hours=2)).total_seconds()) * 1000 - CLOCK_DRIFT + # Maximum number of retries when polling for log events to skip empty pages. + MAX_RETRIES = 10 + + def __init__(self, *, group: str, region: Optional[str] = None) -> None: + with self._wrap_boto_errors(): + session = boto3.Session(region_name=region) + self._client = session.client("logs") + self._check_group_exists(group) + self._group = group + self._region = self._client.meta.region_name + # Stores names of already created streams. + # XXX: This set acts as an unbound cache. If this becomes a problem (in case of _very_ long + # running server and/or lots of jobs, consider replacing it with an LRU cache, e.g., + # a simple OrderedDict-based implementation should be OK. + self._streams: Set[str] = set() + + def close(self) -> None: + self._client.close() + + def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs: + log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB + stream = self._get_stream_name( + project.name, request.run_name, request.job_submission_id, log_producer ) - for cw_event in cw_events - ] - return JobSubmissionLogs( - logs=logs, - external_url=self._get_stream_external_url(stream), - next_token=next_token, - ) - - def _get_log_events_with_retry( - self, stream: str, request: PollLogsRequest - ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]: - current_request = request - previous_next_token = request.next_token - - for attempt in range(self.MAX_RETRIES): - cw_events, next_token = self._get_log_events(stream, current_request) - - if cw_events: - return cw_events, next_token - - if not next_token or next_token == previous_next_token: - return [], None - - previous_next_token = next_token - current_request = PollLogsRequest( - run_name=request.run_name, - job_submission_id=request.job_submission_id, - start_time=request.start_time, - end_time=request.end_time, - descending=request.descending, + cw_events: List[_CloudWatchLogEvent] + next_token: Optional[str] = None + with self._wrap_boto_errors(): + try: + cw_events, next_token = self._get_log_events_with_retry(stream, request) + except botocore.exceptions.ClientError as e: + if not self._is_resource_not_found_exception(e): + raise + # Check if the group exists to distinguish between group not found vs stream not found + try: + self._check_group_exists(self._group) + # Group exists, so the error must be due to missing stream + logger.debug("Stream %s not found, returning dummy response", stream) + cw_events = [] + except LogStorageError: + # Group doesn't exist, re-raise the LogStorageError + raise + logs = [ + LogEvent( + timestamp=unix_time_ms_to_datetime(cw_event["timestamp"]), + log_source=LogEventSource.STDOUT, + message=cw_event["message"], + ) + for cw_event in cw_events + ] + return JobSubmissionLogs( + logs=logs, + external_url=self._get_stream_external_url(stream), next_token=next_token, - limit=request.limit, - diagnose=request.diagnose, ) - if not request.descending: - logger.debug( - "Stream %s: exhausted %d retries without finding logs, returning empty response", - stream, - self.MAX_RETRIES, - ) - # Only return the next token after exhausting retries if going descending— - # AWS CloudWatch guarantees more logs in that case. In ascending mode, - # next token is always returned, even if no logs remain. - # So descending works reliably; ascending has limits if gaps are too large. - # In the future, UI/CLI should handle retries, and we can return next token for ascending too. - return [], next_token if request.descending else None - - def _get_log_events( - self, stream: str, request: PollLogsRequest - ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]: - start_from_head = not request.descending - parameters = { - "logGroupName": self._group, - "logStreamName": stream, - "limit": request.limit, - "startFromHead": start_from_head, - } - - if request.start_time: - parameters["startTime"] = datetime_to_unix_time_ms(request.start_time) - - if request.end_time: - parameters["endTime"] = datetime_to_unix_time_ms(request.end_time) - elif start_from_head: - # When startFromHead=true and no endTime is provided, set endTime to "now" - # to prevent infinite pagination as new logs arrive faster than we can read them - parameters["endTime"] = datetime_to_unix_time_ms(datetime.now(timezone.utc)) - - if request.next_token: - parameters["nextToken"] = request.next_token - - response = self._client.get_log_events(**parameters) - - events = response.get("events", []) - next_token_key = "nextForwardToken" if start_from_head else "nextBackwardToken" - next_token = response.get(next_token_key) - - # TODO: The code below is not going to be used until we migrate from base64-encoded logs to plain text logs. - if request.descending: - events = list(reversed(events)) - - return events, next_token - - def _get_stream_external_url(self, stream: str) -> str: - quoted_group = urllib.parse.quote(self._group, safe="") - quoted_stream = urllib.parse.quote(stream, safe="") - return f"https://console.aws.amazon.com/cloudwatch/home?region={self._region}#logsV2:log-groups/log-group/{quoted_group}/log-events/{quoted_stream}" - - def write_logs( - self, - project: ProjectModel, - run_name: str, - job_submission_id: UUID, - runner_logs: List[RunnerLogEvent], - job_logs: List[RunnerLogEvent], - ): - if len(runner_logs) > 0: - runner_stream = self._get_stream_name( - project.name, run_name, job_submission_id, LogProducer.RUNNER - ) - self._write_logs( - stream=runner_stream, - log_events=runner_logs, - ) - if len(job_logs) > 0: - jog_stream = self._get_stream_name( - project.name, run_name, job_submission_id, LogProducer.JOB - ) - self._write_logs( - stream=jog_stream, - log_events=job_logs, - ) + def _get_log_events_with_retry( + self, stream: str, request: PollLogsRequest + ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]: + current_request = request + previous_next_token = request.next_token + + for attempt in range(self.MAX_RETRIES): + cw_events, next_token = self._get_log_events(stream, current_request) + + if cw_events: + return cw_events, next_token + + if not next_token or next_token == previous_next_token: + return [], None + + previous_next_token = next_token + current_request = PollLogsRequest( + run_name=request.run_name, + job_submission_id=request.job_submission_id, + start_time=request.start_time, + end_time=request.end_time, + descending=request.descending, + next_token=next_token, + limit=request.limit, + diagnose=request.diagnose, + ) - def _write_logs(self, stream: str, log_events: List[RunnerLogEvent]) -> None: - with self._wrap_boto_errors(): - self._ensure_stream_exists(stream) - try: + if not request.descending: + logger.debug( + "Stream %s: exhausted %d retries without finding logs, returning empty response", + stream, + self.MAX_RETRIES, + ) + # Only return the next token after exhausting retries if going descending— + # AWS CloudWatch guarantees more logs in that case. In ascending mode, + # next token is always returned, even if no logs remain. + # So descending works reliably; ascending has limits if gaps are too large. + # In the future, UI/CLI should handle retries, and we can return next token for ascending too. + return [], next_token if request.descending else None + + def _get_log_events( + self, stream: str, request: PollLogsRequest + ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]: + start_from_head = not request.descending + parameters = { + "logGroupName": self._group, + "logStreamName": stream, + "limit": request.limit, + "startFromHead": start_from_head, + } + + if request.start_time: + parameters["startTime"] = datetime_to_unix_time_ms(request.start_time) + + if request.end_time: + parameters["endTime"] = datetime_to_unix_time_ms(request.end_time) + elif start_from_head: + # When startFromHead=true and no endTime is provided, set endTime to "now" + # to prevent infinite pagination as new logs arrive faster than we can read them + parameters["endTime"] = datetime_to_unix_time_ms(datetime.now(timezone.utc)) + + if request.next_token: + parameters["nextToken"] = request.next_token + + response = self._client.get_log_events(**parameters) + + events = response.get("events", []) + next_token_key = "nextForwardToken" if start_from_head else "nextBackwardToken" + next_token = response.get(next_token_key) + + # TODO: The code below is not going to be used until we migrate from base64-encoded logs to plain text logs. + if request.descending: + events = list(reversed(events)) + + return events, next_token + + def _get_stream_external_url(self, stream: str) -> str: + quoted_group = urllib.parse.quote(self._group, safe="") + quoted_stream = urllib.parse.quote(stream, safe="") + return f"https://console.aws.amazon.com/cloudwatch/home?region={self._region}#logsV2:log-groups/log-group/{quoted_group}/log-events/{quoted_stream}" + + def write_logs( + self, + project: ProjectModel, + run_name: str, + job_submission_id: UUID, + runner_logs: List[RunnerLogEvent], + job_logs: List[RunnerLogEvent], + ): + if len(runner_logs) > 0: + runner_stream = self._get_stream_name( + project.name, run_name, job_submission_id, LogProducer.RUNNER + ) + self._write_logs( + stream=runner_stream, + log_events=runner_logs, + ) + if len(job_logs) > 0: + jog_stream = self._get_stream_name( + project.name, run_name, job_submission_id, LogProducer.JOB + ) + self._write_logs( + stream=jog_stream, + log_events=job_logs, + ) + + def _write_logs(self, stream: str, log_events: List[RunnerLogEvent]) -> None: + with self._wrap_boto_errors(): + self._ensure_stream_exists(stream) + try: + self._put_log_events(stream, log_events) + return + except botocore.exceptions.ClientError as e: + if not self._is_resource_not_found_exception(e): + raise + logger.debug("Stream %s not found, recreating", stream) + # The stream is probably deleted due to retention policy, our cache is stale. + self._ensure_stream_exists(stream, force=True) self._put_log_events(stream, log_events) - return - except botocore.exceptions.ClientError as e: - if not self._is_resource_not_found_exception(e): - raise - logger.debug("Stream %s not found, recreating", stream) - # The stream is probably deleted due to retention policy, our cache is stale. - self._ensure_stream_exists(stream, force=True) - self._put_log_events(stream, log_events) - - def _put_log_events(self, stream: str, log_events: List[RunnerLogEvent]) -> None: - # Python docs: "The built-in sorted() function is guaranteed to be stable." - sorted_log_events = sorted(log_events, key=operator.attrgetter("timestamp")) - if tuple(map(id, log_events)) != tuple(map(id, sorted_log_events)): - logger.error( - "Stream %s: events are not in chronological order, something wrong with runner", - stream, - ) - for batch in self._get_batch_iter(stream, sorted_log_events): - self._client.put_log_events( - logGroupName=self._group, - logStreamName=stream, - logEvents=batch, - ) - def _get_batch_iter( - self, stream: str, log_events: List[RunnerLogEvent] - ) -> Iterator[List[_CloudWatchLogEvent]]: - shared_event_iter = iter(log_events) - event_iter = shared_event_iter - while True: - batch, excessive_event = self._get_next_batch(stream, event_iter) - if not batch: - return - yield batch - if excessive_event is not None: - event_iter = itertools.chain([excessive_event], shared_event_iter) - else: - event_iter = shared_event_iter - - def _get_next_batch( - self, stream: str, event_iter: Iterator[RunnerLogEvent] - ) -> Tuple[List[_CloudWatchLogEvent], Optional[RunnerLogEvent]]: - now_timestamp = int(datetime.now(timezone.utc).timestamp() * 1000) - batch: List[_CloudWatchLogEvent] = [] - total_size = 0 - event_count = 0 - first_timestamp: Optional[int] = None - skipped_past_events = 0 - skipped_future_events = 0 - # event that doesn't fit in the current batch - excessive_event: Optional[RunnerLogEvent] = None - for event in event_iter: - # Normally there should not be empty messages. - if not event.message: - continue - timestamp = event.timestamp - if first_timestamp is None: - first_timestamp = timestamp - elif timestamp - first_timestamp > self.BATCH_MAX_SPAN: - excessive_event = event - break - if now_timestamp - timestamp > self.PAST_EVENT_MAX_DELTA: - skipped_past_events += 1 - continue - if timestamp - now_timestamp > self.FUTURE_EVENT_MAX_DELTA: - skipped_future_events += 1 - continue - cw_event = self._runner_log_event_to_cloudwatch_event(event) - message_size = len(event.message) + self.MESSAGE_OVERHEAD_SIZE - if message_size > self.MESSAGE_MAX_SIZE: - # we should never hit this limit, as we use `io.Copy` to copy from pty to logs, - # which under the hood uses 32KiB buffer, see runner/internal/executor/executor.go, - # `execJob` -> `io.Copy(logger, ptmx)` + def _put_log_events(self, stream: str, log_events: List[RunnerLogEvent]) -> None: + # Python docs: "The built-in sorted() function is guaranteed to be stable." + sorted_log_events = sorted(log_events, key=operator.attrgetter("timestamp")) + if tuple(map(id, log_events)) != tuple(map(id, sorted_log_events)): logger.error( - "Stream %s: skipping event %d, message exceeds max size: %d > %d", + "Stream %s: events are not in chronological order, something wrong with runner", stream, - timestamp, - message_size, - self.MESSAGE_MAX_SIZE, ) - continue - if total_size + message_size > self.BATCH_MAX_SIZE: - excessive_event = event - break - batch.append(cw_event) - total_size += message_size - event_count += 1 - if event_count >= self.EVENT_MAX_COUNT_IN_BATCH: - break - if skipped_past_events > 0: - logger.error("Stream %s: skipping %d past event(s)", stream, skipped_past_events) - if skipped_future_events > 0: - logger.error("Stream %s: skipping %d future event(s)", stream, skipped_future_events) - return batch, excessive_event - - def _runner_log_event_to_cloudwatch_event( - self, runner_log_event: RunnerLogEvent - ) -> _CloudWatchLogEvent: - return { - "timestamp": runner_log_event.timestamp, - "message": runner_log_event.message.decode(errors="replace"), - } - - @contextmanager - def _wrap_boto_errors(self) -> Iterator[None]: - try: - yield - except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: - raise LogStorageError(f"CloudWatch Logs error: {type(e).__name__}: {e}") from e - - def _is_resource_not_found_exception(self, exc: "botocore.exceptions.ClientError") -> bool: - try: - return exc.response["Error"]["Code"] == "ResourceNotFoundException" - except KeyError: - return False - - def _check_group_exists(self, name: str) -> None: - try: - self._client.describe_log_streams(logGroupName=name, limit=1) - except botocore.exceptions.ClientError as e: - if self._is_resource_not_found_exception(e): - raise LogStorageError(f"LogGroup '{name}' does not exist") - raise - - def _ensure_stream_exists(self, name: str, *, force: bool = False) -> None: - if not force and name in self._streams: - return - response = self._client.describe_log_streams( - logGroupName=self._group, logStreamNamePrefix=name - ) - for stream in response["logStreams"]: - if stream["logStreamName"] == name: - self._streams.add(name) + for batch in self._get_batch_iter(stream, sorted_log_events): + self._client.put_log_events( + logGroupName=self._group, + logStreamName=stream, + logEvents=batch, + ) + + def _get_batch_iter( + self, stream: str, log_events: List[RunnerLogEvent] + ) -> Iterator[List[_CloudWatchLogEvent]]: + shared_event_iter = iter(log_events) + event_iter = shared_event_iter + while True: + batch, excessive_event = self._get_next_batch(stream, event_iter) + if not batch: + return + yield batch + if excessive_event is not None: + event_iter = itertools.chain([excessive_event], shared_event_iter) + else: + event_iter = shared_event_iter + + def _get_next_batch( + self, stream: str, event_iter: Iterator[RunnerLogEvent] + ) -> Tuple[List[_CloudWatchLogEvent], Optional[RunnerLogEvent]]: + now_timestamp = int(datetime.now(timezone.utc).timestamp() * 1000) + batch: List[_CloudWatchLogEvent] = [] + total_size = 0 + event_count = 0 + first_timestamp: Optional[int] = None + skipped_past_events = 0 + skipped_future_events = 0 + # event that doesn't fit in the current batch + excessive_event: Optional[RunnerLogEvent] = None + for event in event_iter: + # Normally there should not be empty messages. + if not event.message: + continue + timestamp = event.timestamp + if first_timestamp is None: + first_timestamp = timestamp + elif timestamp - first_timestamp > self.BATCH_MAX_SPAN: + excessive_event = event + break + if now_timestamp - timestamp > self.PAST_EVENT_MAX_DELTA: + skipped_past_events += 1 + continue + if timestamp - now_timestamp > self.FUTURE_EVENT_MAX_DELTA: + skipped_future_events += 1 + continue + cw_event = self._runner_log_event_to_cloudwatch_event(event) + message_size = len(event.message) + self.MESSAGE_OVERHEAD_SIZE + if message_size > self.MESSAGE_MAX_SIZE: + # we should never hit this limit, as we use `io.Copy` to copy from pty to logs, + # which under the hood uses 32KiB buffer, see runner/internal/executor/executor.go, + # `execJob` -> `io.Copy(logger, ptmx)` + logger.error( + "Stream %s: skipping event %d, message exceeds max size: %d > %d", + stream, + timestamp, + message_size, + self.MESSAGE_MAX_SIZE, + ) + continue + if total_size + message_size > self.BATCH_MAX_SIZE: + excessive_event = event + break + batch.append(cw_event) + total_size += message_size + event_count += 1 + if event_count >= self.EVENT_MAX_COUNT_IN_BATCH: + break + if skipped_past_events > 0: + logger.error("Stream %s: skipping %d past event(s)", stream, skipped_past_events) + if skipped_future_events > 0: + logger.error( + "Stream %s: skipping %d future event(s)", stream, skipped_future_events + ) + return batch, excessive_event + + def _runner_log_event_to_cloudwatch_event( + self, runner_log_event: RunnerLogEvent + ) -> _CloudWatchLogEvent: + return { + "timestamp": runner_log_event.timestamp, + "message": runner_log_event.message.decode(errors="replace"), + } + + @contextmanager + def _wrap_boto_errors(self) -> Iterator[None]: + try: + yield + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + raise LogStorageError(f"CloudWatch Logs error: {type(e).__name__}: {e}") from e + + def _is_resource_not_found_exception(self, exc: "botocore.exceptions.ClientError") -> bool: + try: + return exc.response["Error"]["Code"] == "ResourceNotFoundException" + except KeyError: + return False + + def _check_group_exists(self, name: str) -> None: + try: + self._client.describe_log_streams(logGroupName=name, limit=1) + except botocore.exceptions.ClientError as e: + if self._is_resource_not_found_exception(e): + raise LogStorageError(f"LogGroup '{name}' does not exist") + raise + + def _ensure_stream_exists(self, name: str, *, force: bool = False) -> None: + if not force and name in self._streams: return - self._client.create_log_stream(logGroupName=self._group, logStreamName=name) - self._streams.add(name) - - def _get_stream_name( - self, - project_name: str, - run_name: str, - job_submission_id: UUID, - producer: LogProducer, - ) -> str: - return f"{project_name}/{run_name}/{job_submission_id}/{producer.value}" + response = self._client.describe_log_streams( + logGroupName=self._group, logStreamNamePrefix=name + ) + for stream in response["logStreams"]: + if stream["logStreamName"] == name: + self._streams.add(name) + return + self._client.create_log_stream(logGroupName=self._group, logStreamName=name) + self._streams.add(name) + + def _get_stream_name( + self, + project_name: str, + run_name: str, + job_submission_id: UUID, + producer: LogProducer, + ) -> str: + return f"{project_name}/{run_name}/{job_submission_id}/{producer.value}" diff --git a/src/dstack/_internal/server/services/logs/gcp.py b/src/dstack/_internal/server/services/logs/gcp.py index 7faa727dc1..f3b03a3e41 100644 --- a/src/dstack/_internal/server/services/logs/gcp.py +++ b/src/dstack/_internal/server/services/logs/gcp.py @@ -20,6 +20,9 @@ from dstack._internal.utils.common import batched from dstack._internal.utils.logging import get_logger +logger = get_logger(__name__) + + GCP_LOGGING_AVAILABLE = True try: import google.api_core.exceptions @@ -28,152 +31,151 @@ from google.cloud.logging_v2.types import ListLogEntriesRequest except ImportError: GCP_LOGGING_AVAILABLE = False - - -logger = get_logger(__name__) - - -class GCPLogStorage(LogStorage): - # Max expected message size from runner is 32KB. - # Max expected LogEntry size is 32KB + metadata < 50KB < 256KB limit. - # With MAX_BATCH_SIZE = 100, max write request size < 5MB < 10 MB limit. - # See: https://cloud.google.com/logging/quotas. - MAX_RUNNER_MESSAGE_SIZE = 32 * 1024 - MAX_BATCH_SIZE = 100 - - # Use the same log name for all run logs so that it's easy to manage all dstack-related logs. - LOG_NAME = "dstack-run-logs" - # Logs from different jobs belong to different "streams". - # GCP Logging has no built-in concepts of streams, so we implement them with labels. - # It should be fast to filter by labels since labels are indexed by default - # (https://cloud.google.com/logging/docs/analyze/custom-index). - - def __init__(self, project_id: str): - self.project_id = project_id - try: - self.client = logging_v2.Client(project=project_id) - self.logger = self.client.logger(name=self.LOG_NAME) - self.logger.list_entries(max_results=1) - # Python client doesn't seem to support dry_run, - # so emit an empty log to check permissions. - self.logger.log_empty() - except google.auth.exceptions.DefaultCredentialsError: - raise LogStorageError("Default credentials not found") - except google.api_core.exceptions.NotFound: - raise LogStorageError(f"Project {project_id} not found") - except google.api_core.exceptions.PermissionDenied: - raise LogStorageError("Insufficient permissions") - - def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs: - # TODO: GCP may return logs in random order when events have the same timestamp. - producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB - stream_name = self._get_stream_name( - project_name=project.name, - run_name=request.run_name, - job_submission_id=request.job_submission_id, - producer=producer, - ) - log_filters = [f'labels.stream = "{stream_name}"'] - if request.start_time: - log_filters.append(f'timestamp > "{request.start_time.isoformat()}"') - if request.end_time: - log_filters.append(f'timestamp < "{request.end_time.isoformat()}"') - log_filter = " AND ".join(log_filters) - - order_by = logging_v2.DESCENDING if request.descending else logging_v2.ASCENDING - try: - # Use low-level API to get access to next_page_token - request_obj = ListLogEntriesRequest( - resource_names=[f"projects/{self.client.project}"], - filter=log_filter, - order_by=order_by, - page_size=request.limit, - page_token=request.next_token, - ) - response = self.client._logging_api._gapic_api.list_log_entries(request=request_obj) - - logs = [ - LogEvent( - timestamp=entry.timestamp, - message=entry.json_payload.get("message"), - log_source=LogEventSource.STDOUT, - ) - for entry in response.entries - ] - next_token = response.next_page_token or None - except google.api_core.exceptions.ResourceExhausted as e: - logger.warning("GCP Logging exception: %s", repr(e)) - # GCP Logging has severely low quota of 60 reads/min for entries.list - raise ServerClientError( - "GCP Logging read request limit exceeded." - " It's recommended to increase default entries.list request quota from 60 per minute." - ) - return JobSubmissionLogs( - logs=logs, - external_url=self._get_stream_extrnal_url(stream_name), - next_token=next_token if len(logs) > 0 else None, - ) - - def write_logs( - self, - project: ProjectModel, - run_name: str, - job_submission_id: UUID, - runner_logs: List[RunnerLogEvent], - job_logs: List[RunnerLogEvent], - ): - producers_with_logs = [(LogProducer.RUNNER, runner_logs), (LogProducer.JOB, job_logs)] - for producer, producer_logs in producers_with_logs: +else: + + class GCPLogStorage(LogStorage): + # Max expected message size from runner is 32KB. + # Max expected LogEntry size is 32KB + metadata < 50KB < 256KB limit. + # With MAX_BATCH_SIZE = 100, max write request size < 5MB < 10 MB limit. + # See: https://cloud.google.com/logging/quotas. + MAX_RUNNER_MESSAGE_SIZE = 32 * 1024 + MAX_BATCH_SIZE = 100 + + # Use the same log name for all run logs so that it's easy to manage all dstack-related logs. + LOG_NAME = "dstack-run-logs" + # Logs from different jobs belong to different "streams". + # GCP Logging has no built-in concepts of streams, so we implement them with labels. + # It should be fast to filter by labels since labels are indexed by default + # (https://cloud.google.com/logging/docs/analyze/custom-index). + + def __init__(self, project_id: str): + self.project_id = project_id + try: + self.client = logging_v2.Client(project=project_id) + self.logger = self.client.logger(name=self.LOG_NAME) + self.logger.list_entries(max_results=1) + # Python client doesn't seem to support dry_run, + # so emit an empty log to check permissions. + self.logger.log_empty() + except google.auth.exceptions.DefaultCredentialsError: + raise LogStorageError("Default credentials not found") + except google.api_core.exceptions.NotFound: + raise LogStorageError(f"Project {project_id} not found") + except google.api_core.exceptions.PermissionDenied: + raise LogStorageError("Insufficient permissions") + + def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs: + # TODO: GCP may return logs in random order when events have the same timestamp. + producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB stream_name = self._get_stream_name( project_name=project.name, - run_name=run_name, - job_submission_id=job_submission_id, + run_name=request.run_name, + job_submission_id=request.job_submission_id, producer=producer, ) - self._write_logs_to_stream( - stream_name=stream_name, - logs=producer_logs, + log_filters = [f'labels.stream = "{stream_name}"'] + if request.start_time: + log_filters.append(f'timestamp > "{request.start_time.isoformat()}"') + if request.end_time: + log_filters.append(f'timestamp < "{request.end_time.isoformat()}"') + log_filter = " AND ".join(log_filters) + + order_by = logging_v2.DESCENDING if request.descending else logging_v2.ASCENDING + try: + # Use low-level API to get access to next_page_token + request_obj = ListLogEntriesRequest( + resource_names=[f"projects/{self.client.project}"], + filter=log_filter, + order_by=order_by, + page_size=request.limit, + page_token=request.next_token, + ) + response = self.client._logging_api._gapic_api.list_log_entries( + request=request_obj + ) + + logs = [ + LogEvent( + timestamp=entry.timestamp, + message=entry.json_payload.get("message"), + log_source=LogEventSource.STDOUT, + ) + for entry in response.entries + ] + next_token = response.next_page_token or None + except google.api_core.exceptions.ResourceExhausted as e: + logger.warning("GCP Logging exception: %s", repr(e)) + # GCP Logging has severely low quota of 60 reads/min for entries.list + raise ServerClientError( + "GCP Logging read request limit exceeded." + " It's recommended to increase default entries.list request quota from 60 per minute." + ) + return JobSubmissionLogs( + logs=logs, + external_url=self._get_stream_extrnal_url(stream_name), + next_token=next_token if len(logs) > 0 else None, ) - def close(self): - self.client.close() - - def _write_logs_to_stream(self, stream_name: str, logs: List[RunnerLogEvent]): - with self.logger.batch() as batcher: - for batch in batched(logs, self.MAX_BATCH_SIZE): - for log in batch: - message = log.message.decode(errors="replace") - timestamp = unix_time_ms_to_datetime(log.timestamp) - if len(log.message) > self.MAX_RUNNER_MESSAGE_SIZE: - logger.error( - "Stream %s: skipping event at %s, message exceeds max size: %d > %d", - stream_name, - timestamp.isoformat(), - len(log.message), - self.MAX_RUNNER_MESSAGE_SIZE, + def write_logs( + self, + project: ProjectModel, + run_name: str, + job_submission_id: UUID, + runner_logs: List[RunnerLogEvent], + job_logs: List[RunnerLogEvent], + ): + producers_with_logs = [(LogProducer.RUNNER, runner_logs), (LogProducer.JOB, job_logs)] + for producer, producer_logs in producers_with_logs: + stream_name = self._get_stream_name( + project_name=project.name, + run_name=run_name, + job_submission_id=job_submission_id, + producer=producer, + ) + self._write_logs_to_stream( + stream_name=stream_name, + logs=producer_logs, + ) + + def close(self): + self.client.close() + + def _write_logs_to_stream(self, stream_name: str, logs: List[RunnerLogEvent]): + with self.logger.batch() as batcher: + for batch in batched(logs, self.MAX_BATCH_SIZE): + for log in batch: + message = log.message.decode(errors="replace") + timestamp = unix_time_ms_to_datetime(log.timestamp) + if len(log.message) > self.MAX_RUNNER_MESSAGE_SIZE: + logger.error( + "Stream %s: skipping event at %s, message exceeds max size: %d > %d", + stream_name, + timestamp.isoformat(), + len(log.message), + self.MAX_RUNNER_MESSAGE_SIZE, + ) + continue + batcher.log_struct( + { + "message": message, + }, + labels={ + "stream": stream_name, + }, + timestamp=timestamp, ) - continue - batcher.log_struct( - { - "message": message, - }, - labels={ - "stream": stream_name, - }, - timestamp=timestamp, - ) - batcher.commit() + batcher.commit() - def _get_stream_name( - self, project_name: str, run_name: str, job_submission_id: UUID, producer: LogProducer - ) -> str: - return f"{project_name}-{run_name}-{job_submission_id}-{producer.value}" + def _get_stream_name( + self, project_name: str, run_name: str, job_submission_id: UUID, producer: LogProducer + ) -> str: + return f"{project_name}-{run_name}-{job_submission_id}-{producer.value}" - def _get_stream_extrnal_url(self, stream_name: str) -> str: - log_name_resource_name = self._get_log_name_resource_name() - query = f'logName="{log_name_resource_name}" AND labels.stream="{stream_name}"' - quoted_query = urllib.parse.quote(query, safe="") - return f"https://console.cloud.google.com/logs/query;query={quoted_query}?project={self.project_id}" + def _get_stream_extrnal_url(self, stream_name: str) -> str: + log_name_resource_name = self._get_log_name_resource_name() + query = f'logName="{log_name_resource_name}" AND labels.stream="{stream_name}"' + quoted_query = urllib.parse.quote(query, safe="") + return f"https://console.cloud.google.com/logs/query;query={quoted_query}?project={self.project_id}" - def _get_log_name_resource_name(self) -> str: - return f"projects/{self.project_id}/logs/{self.LOG_NAME}" + def _get_log_name_resource_name(self) -> str: + return f"projects/{self.project_id}/logs/{self.LOG_NAME}" From f1e02708dea09d99dcaba678da54c6b07a4fbf66 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 16:07:34 +0500 Subject: [PATCH 28/54] Fix filelog typing --- .../_internal/server/services/logs/filelog.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/dstack/_internal/server/services/logs/filelog.py b/src/dstack/_internal/server/services/logs/filelog.py index 823222a409..e4289805c6 100644 --- a/src/dstack/_internal/server/services/logs/filelog.py +++ b/src/dstack/_internal/server/services/logs/filelog.py @@ -48,7 +48,7 @@ def _poll_logs_ascending( ) -> JobSubmissionLogs: start_line = 0 if request.next_token: - start_line = self._next_token(request) + start_line = self._parse_next_token(request.next_token) logs = [] next_token = None @@ -97,7 +97,9 @@ def _poll_logs_ascending( def _poll_logs_descending( self, log_file_path: Path, request: PollLogsRequest ) -> JobSubmissionLogs: - start_offset = self._next_token(request) + start_offset = None + if request.next_token is not None: + start_offset = self._parse_next_token(request.next_token) candidate_logs = [] @@ -123,12 +125,12 @@ def _poll_logs_descending( except FileNotFoundError: return JobSubmissionLogs(logs=[], next_token=None) - logs = [log for log, offset in candidate_logs[: request.limit]] + logs = [log for log, _ in candidate_logs[: request.limit]] next_token = None if len(candidate_logs) > request.limit: # We fetched one more than the limit, so there are more pages. # The next token should point to the start of the last log we are returning. - _last_log_event, last_log_offset = candidate_logs[request.limit - 1] + _, last_log_offset = candidate_logs[request.limit - 1] next_token = str(last_log_offset) return JobSubmissionLogs(logs=logs, next_token=next_token) @@ -245,8 +247,7 @@ def _runner_log_event_to_log_event(self, runner_log_event: RunnerLogEvent) -> Lo message=runner_log_event.message.decode(errors="replace"), ) - def _next_token(self, request: PollLogsRequest) -> Optional[int]: - next_token = request.next_token + def _parse_next_token(self, next_token: str) -> int: if next_token is None: return None try: From bcbe2c766ff5e6005c2eec969201bde256d2ceb8 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 16:17:23 +0500 Subject: [PATCH 29/54] Use async_sessionmaker --- src/dstack/_internal/server/db.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/server/db.py b/src/dstack/_internal/server/db.py index 4e747a8e78..084630add1 100644 --- a/src/dstack/_internal/server/db.py +++ b/src/dstack/_internal/server/db.py @@ -4,8 +4,12 @@ from alembic import command, config from sqlalchemy import AsyncAdaptedQueuePool, event from sqlalchemy.engine.interfaces import DBAPIConnection -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) from sqlalchemy.pool import ConnectionPoolEntry from dstack._internal.server import settings @@ -26,8 +30,8 @@ def __init__(self, url: str, engine: Optional[AsyncEngine] = None): pool_size=settings.DB_POOL_SIZE, max_overflow=settings.DB_MAX_OVERFLOW, ) - self.session_maker = sessionmaker( - bind=self.engine, + self.session_maker = async_sessionmaker( + bind=self.engine, # type: ignore[assignment] expire_on_commit=False, class_=AsyncSession, ) From 8f291f2e344765e4c5e33331c7785a30e40b2af4 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 19 Aug 2025 16:23:33 +0500 Subject: [PATCH 30/54] Assert proxy_jump.ssh_key --- src/dstack/_internal/server/services/fleets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index e02cac6589..4f2b64cc5a 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -504,6 +504,7 @@ async def create_fleet_ssh_instance_model( raise ServerClientError("ssh key or user not specified") if proxy_jump is not None: + assert proxy_jump.ssh_key is not None ssh_proxy = SSHConnectionParams( hostname=proxy_jump.hostname, port=proxy_jump.port or 22, From 84171fe838bc98d94c1f4d4d65b0e488cc6b4fd6 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 10:00:52 +0500 Subject: [PATCH 31/54] Ignore type errors from deps --- src/dstack/_internal/server/services/docker.py | 4 ++-- src/dstack/_internal/server/services/logs/gcp.py | 2 +- src/dstack/_internal/server/utils/provisioning.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dstack/_internal/server/services/docker.py b/src/dstack/_internal/server/services/docker.py index 40291179df..7181edc7d3 100644 --- a/src/dstack/_internal/server/services/docker.py +++ b/src/dstack/_internal/server/services/docker.py @@ -77,7 +77,7 @@ def get_image_config(image_name: str, registry_auth: Optional[RegistryAuth]) -> registry_client = PatchedDXF( host=image.registry or DEFAULT_REGISTRY, repo=image.repo, - auth=DXFAuthAdapter(registry_auth), + auth=DXFAuthAdapter(registry_auth), # type: ignore[assignment] timeout=REGISTRY_REQUEST_TIMEOUT, ) @@ -88,7 +88,7 @@ def get_image_config(image_name: str, registry_auth: Optional[RegistryAuth]) -> ) manifest = ImageManifest.__response__.parse_raw(manifest_resp) config_stream = registry_client.pull_blob(manifest.config.digest) - config_resp = join_byte_stream_checked(config_stream, MAX_CONFIG_OBJECT_SIZE) + config_resp = join_byte_stream_checked(config_stream, MAX_CONFIG_OBJECT_SIZE) # type: ignore[arg-type] if config_resp is None: raise DockerRegistryError( f"Image config object exceeds the size limit of {MAX_CONFIG_OBJECT_SIZE} bytes" diff --git a/src/dstack/_internal/server/services/logs/gcp.py b/src/dstack/_internal/server/services/logs/gcp.py index f3b03a3e41..c1b1a75cf1 100644 --- a/src/dstack/_internal/server/services/logs/gcp.py +++ b/src/dstack/_internal/server/services/logs/gcp.py @@ -90,7 +90,7 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi page_size=request.limit, page_token=request.next_token, ) - response = self.client._logging_api._gapic_api.list_log_entries( + response = self.client._logging_api._gapic_api.list_log_entries( # type: ignore[attr-defined] request=request_obj ) diff --git a/src/dstack/_internal/server/utils/provisioning.py b/src/dstack/_internal/server/utils/provisioning.py index 94a5347343..b77efe7db4 100644 --- a/src/dstack/_internal/server/utils/provisioning.py +++ b/src/dstack/_internal/server/utils/provisioning.py @@ -312,10 +312,10 @@ def get_paramiko_connection( with proxy_ctx as proxy_client, paramiko.SSHClient() as client: proxy_channel: Optional[paramiko.Channel] = None if proxy_client is not None: + transport = proxy_client.get_transport() + assert transport is not None try: - proxy_channel = proxy_client.get_transport().open_channel( - "direct-tcpip", (host, port), ("", 0) - ) + proxy_channel = transport.open_channel("direct-tcpip", (host, port), ("", 0)) except (paramiko.SSHException, OSError) as e: raise ProvisioningError(f"Proxy channel failed: {e}") from e client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) From cca1b01c4e576f485e1614a2d15fe27a42a6ee7b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 10:05:38 +0500 Subject: [PATCH 32/54] Forbid entrypoint for dev-environment Fixes #3002 --- src/dstack/_internal/core/models/configurations.py | 6 ++++++ .../_internal/server/services/jobs/configurators/base.py | 1 + 2 files changed, 7 insertions(+) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index a721dfbfae..4b44255d87 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -535,6 +535,12 @@ def schema_extra(schema: Dict[str, Any]): ProfileParams.Config.schema_extra(schema) BaseRunConfigurationWithPorts.Config.schema_extra(schema) + @validator("entrypoint") + def validate_entrypoint(cls, v: Optional[str]) -> Optional[str]: + if v is not None: + raise ValueError("entrypoint is not supported for dev-environment") + return v + class TaskConfigurationParams(CoreModel): nodes: Annotated[int, Field(description="Number of nodes", ge=1)] = 1 diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index c5f90b54ec..ee9bb05f6e 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -179,6 +179,7 @@ def _shell(self) -> str: async def _commands(self) -> List[str]: if self.run_spec.configuration.entrypoint is not None: # docker-like format + assert self.run_spec.configuration.type != "dev-environment" entrypoint = shlex.split(self.run_spec.configuration.entrypoint) commands = self.run_spec.configuration.commands elif shell_commands := self._shell_commands(): From 731aad3076277181fb2886fde3760496c9b822b3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 10:15:21 +0500 Subject: [PATCH 33/54] Fix vscode and cursor __init__ annotations --- src/dstack/_internal/cli/services/configurators/run.py | 2 +- .../_internal/server/services/jobs/configurators/dev.py | 1 + .../server/services/jobs/configurators/extensions/cursor.py | 4 ++-- .../server/services/jobs/configurators/extensions/vscode.py | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 6133999ef6..71045baa79 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -552,7 +552,7 @@ def _unique_ports_constraint(ports: List[int]): used_ports.add(i) -def _detect_vscode_version(exe: str = "code") -> Optional[str]: +def _detect_vscode_version(exe: str = "code1") -> Optional[str]: try: run = subprocess.run([exe, "--version"], capture_output=True) except FileNotFoundError: diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py index 20aad1f232..bd49a0e898 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/dev.py +++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py @@ -19,6 +19,7 @@ class DevEnvironmentJobConfigurator(JobConfigurator): def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]): assert run_spec.configuration.type == "dev-environment" + assert run_spec.run_name is not None if run_spec.configuration.ide == "vscode": __class = VSCodeDesktop diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py index d0c819d8da..4b9a91dcfd 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py +++ b/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from dstack._internal.core.models.configurations import DEFAULT_REPO_DIR @@ -7,7 +7,7 @@ class CursorDesktop: def __init__( self, run_name: str, - version: str, + version: Optional[str], extensions: List[str], ): self.run_name = run_name diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py index f1a2534de0..ac8358f72b 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py +++ b/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from dstack._internal.core.models.configurations import DEFAULT_REPO_DIR @@ -7,7 +7,7 @@ class VSCodeDesktop: def __init__( self, run_name: str, - version: str, + version: Optional[str], extensions: List[str], ): self.run_name = run_name From 05cfb9142213918895dc11b8b95ae0c2effce05f Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 10:40:37 +0500 Subject: [PATCH 34/54] Pass probe_spec.body as content --- src/dstack/_internal/server/background/tasks/process_probes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/tasks/process_probes.py b/src/dstack/_internal/server/background/tasks/process_probes.py index 5ed9375d13..bc1dc09431 100644 --- a/src/dstack/_internal/server/background/tasks/process_probes.py +++ b/src/dstack/_internal/server/background/tasks/process_probes.py @@ -120,7 +120,7 @@ async def _execute_probe(probe: ProbeModel, probe_spec: ProbeSpec) -> bool: method=probe_spec.method, url="http://dstack" + probe_spec.url, headers=[(h.name, h.value) for h in probe_spec.headers], - data=probe_spec.body, + content=probe_spec.body, timeout=probe_spec.timeout, follow_redirects=False, ) From e27129b8901f53ddfce9ef5356f150a728f08514 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 10:45:49 +0500 Subject: [PATCH 35/54] Assert gateway configuration.name --- src/dstack/_internal/server/services/gateways/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 5a7b50d021..f47b192999 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -93,6 +93,8 @@ async def create_gateway_compute( backend_id: Optional[uuid.UUID] = None, ) -> GatewayComputeModel: assert isinstance(backend_compute, ComputeWithGatewaySupport) + assert configuration.name is not None + private_bytes, public_bytes = generate_rsa_key_pair_bytes() gateway_ssh_private_key = private_bytes.decode() gateway_ssh_public_key = public_bytes.decode() From 62e34bfd440afdc1632eac1b10b2c8dc55848e88 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 11:58:51 +0500 Subject: [PATCH 36/54] Fix unbound next_token --- src/dstack/_internal/server/services/logs/aws.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/logs/aws.py b/src/dstack/_internal/server/services/logs/aws.py index 6f03a7fdca..4e56f0865d 100644 --- a/src/dstack/_internal/server/services/logs/aws.py +++ b/src/dstack/_internal/server/services/logs/aws.py @@ -117,8 +117,9 @@ def _get_log_events_with_retry( ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]: current_request = request previous_next_token = request.next_token + next_token = None - for attempt in range(self.MAX_RETRIES): + for _ in range(self.MAX_RETRIES): cw_events, next_token = self._get_log_events(stream, current_request) if cw_events: From eee29cdfab4a79114e96a5485fcb3cd0bc1aa871 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 12:12:08 +0500 Subject: [PATCH 37/54] Cast path to str --- src/dstack/_internal/server/services/gateways/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/gateways/connection.py b/src/dstack/_internal/server/services/gateways/connection.py index 6b107c34a9..b8df322a1d 100644 --- a/src/dstack/_internal/server/services/gateways/connection.py +++ b/src/dstack/_internal/server/services/gateways/connection.py @@ -67,7 +67,7 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int): # reverse_forwarded_sockets are added later in .open() ) self.tunnel_id = uuid.uuid4() - self._client = GatewayClient(uds=self.gateway_socket_path) + self._client = GatewayClient(uds=str(self.gateway_socket_path)) @staticmethod def _init_symlink_dir(connection_dir: Path) -> Tuple[TemporaryDirectory, Path]: From 54bb2f079f5a3fbc66b184dfc940520442c88c5b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 12:16:13 +0500 Subject: [PATCH 38/54] Fix unbound success var --- .../_internal/server/background/tasks/process_running_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 2b6c173c54..122624972b 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -346,7 +346,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): server_ssh_private_keys, job_provisioning_data, ) - elif initial_status == JobStatus.RUNNING: + else: logger.debug("%s: process running job, age=%s", fmt(job_model), job_submission.age) success = await common_utils.run_async( _process_running, From 4e9cbef47d731d656a60bf49bc303e3ee593d14e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 12:23:23 +0500 Subject: [PATCH 39/54] Fix typing --- .../_internal/server/background/tasks/process_running_jobs.py | 1 + src/dstack/_internal/server/services/proxy/repo.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 122624972b..19cb089b11 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -637,6 +637,7 @@ def _process_pulling_with_shim( is successful """ shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + job_runtime_data = None if shim_client.is_api_v2_supported(): # raises error if shim is down, causes retry task = shim_client.get_task(job_model.id) diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 3a8204377a..ae7ea19f8d 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -142,7 +142,7 @@ async def list_models(self, project_name: str) -> List[ChatModel]: model_options_obj = service_spec.options.get("openai", {}).get("model") if model_spec is None or model_options_obj is None: continue - model_options = pydantic.parse_obj_as(AnyModel, model_options_obj) + model_options = pydantic.parse_obj_as(AnyModel, model_options_obj) # type: ignore[arg-type] model = ChatModel( project_name=project_name, name=model_spec.name, From 4ec44eaab5b7891ec14e06f7be6a7ea2bbd58ca5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 13:38:34 +0500 Subject: [PATCH 40/54] Add pyright config --- pyproject.toml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a14ac0cbbf..8671d0601c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,14 @@ pattern = '\s*|]*>\s*|\s*|]*>\s*|\ replacement = '' ignore-case = true +[tool.pyright] +include = [ + "src/dstack/_internal/server", +] +ignore = [ + "src/dstack/_internal/server/migrations/versions", +] + [dependency-groups] dev = [ "httpx>=0.28.1", From 2d93912e49ed6a82abb0fa1229bbf0bdba9637a5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 13:42:55 +0500 Subject: [PATCH 41/54] Run pyright in CI --- .github/workflows/build.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8a97c67c7d..891316a627 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,6 +34,9 @@ jobs: python-version: 3.11 - run: uv tool install pre-commit - run: pre-commit run -a --show-diff-on-failure + - uses: jakebailey/pyright-action@v2 + with: + pylance-version: latest-release frontend-build: runs-on: ubuntu-latest From d90c9fe5e27a6e74825a747382dfe22719837172 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 13:48:29 +0500 Subject: [PATCH 42/54] Run pyright as part of tests --- .github/workflows/build.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 891316a627..1e4a3ff648 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,9 +34,6 @@ jobs: python-version: 3.11 - run: uv tool install pre-commit - run: pre-commit run -a --show-diff-on-failure - - uses: jakebailey/pyright-action@v2 - with: - pylance-version: latest-release frontend-build: runs-on: ubuntu-latest @@ -85,6 +82,10 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: uv sync --all-extras + - name: Run pyright + uses: jakebailey/pyright-action@v2 + with: + pylance-version: latest-release - name: Download frontend build uses: actions/download-artifact@v4 with: From 0e15cf11e2160e61703d4daa72dcb9c0be8d4740 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 14:05:05 +0500 Subject: [PATCH 43/54] Ignore type for entry_points --- src/dstack/_internal/server/services/plugins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/plugins.py b/src/dstack/_internal/server/services/plugins.py index 8acd101f9c..933ed43052 100644 --- a/src/dstack/_internal/server/services/plugins.py +++ b/src/dstack/_internal/server/services/plugins.py @@ -60,7 +60,7 @@ def load_plugins(enabled_plugins: list[str]): _PLUGINS.clear() entrypoints: dict[str, PluginEntrypoint] = {} plugins_to_load = enabled_plugins.copy() - for entrypoint in entry_points(group="dstack.plugins"): + for entrypoint in entry_points(group="dstack.plugins"): # type: ignore[call-arg] if entrypoint.name not in enabled_plugins: logger.info( ("Found not enabled plugin %s. Plugin will not be loaded."), From a02f6ff349c2acc8d9683e0bb5a676ccc4445638 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 14:06:47 +0500 Subject: [PATCH 44/54] Fix _detect_vscode_version --- src/dstack/_internal/cli/services/configurators/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 71045baa79..6133999ef6 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -552,7 +552,7 @@ def _unique_ports_constraint(ports: List[int]): used_ports.add(i) -def _detect_vscode_version(exe: str = "code1") -> Optional[str]: +def _detect_vscode_version(exe: str = "code") -> Optional[str]: try: run = subprocess.run([exe, "--version"], capture_output=True) except FileNotFoundError: From 4e51e4eda7ce3dd57fc8dd8a9035b10f8dc7ea84 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 14:13:31 +0500 Subject: [PATCH 45/54] Remove run_name assert --- src/dstack/_internal/server/services/jobs/configurators/dev.py | 1 - .../server/services/jobs/configurators/extensions/cursor.py | 2 +- .../server/services/jobs/configurators/extensions/vscode.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py index bd49a0e898..20aad1f232 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/dev.py +++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py @@ -19,7 +19,6 @@ class DevEnvironmentJobConfigurator(JobConfigurator): def __init__(self, run_spec: RunSpec, secrets: Dict[str, str]): assert run_spec.configuration.type == "dev-environment" - assert run_spec.run_name is not None if run_spec.configuration.ide == "vscode": __class = VSCodeDesktop diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py index 4b9a91dcfd..9c5e68d96e 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py +++ b/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py @@ -6,7 +6,7 @@ class CursorDesktop: def __init__( self, - run_name: str, + run_name: Optional[str], version: Optional[str], extensions: List[str], ): diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py index ac8358f72b..a10b254d02 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py +++ b/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py @@ -6,7 +6,7 @@ class VSCodeDesktop: def __init__( self, - run_name: str, + run_name: Optional[str], version: Optional[str], extensions: List[str], ): From 8614e7ec8879b671c7cfaf7e72e09c716fd9cda2 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 14:19:07 +0500 Subject: [PATCH 46/54] Fix add_extra_schema_types for one $ref --- src/dstack/_internal/utils/json_schema.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/utils/json_schema.py b/src/dstack/_internal/utils/json_schema.py index 73ee643179..19bcd0bc62 100644 --- a/src/dstack/_internal/utils/json_schema.py +++ b/src/dstack/_internal/utils/json_schema.py @@ -3,7 +3,9 @@ def add_extra_schema_types(schema_property: dict, extra_types: list[dict]): refs = [schema_property.pop("allOf")[0]] elif "anyOf" in schema_property: refs = schema_property.pop("anyOf") - else: + elif "type" in schema_property: refs = [{"type": schema_property.pop("type")}] + else: + refs = [{"$ref": schema_property.pop("$ref")}] refs.extend(extra_types) schema_property["anyOf"] = refs From 3f4d7d8b8b40a2bdf6516ffa2bb7cd930a455150 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 17:11:19 +0500 Subject: [PATCH 47/54] Replace ConfigurationWith extensions with mixins --- pyproject.toml | 1 + src/dstack/_internal/cli/commands/offer.py | 46 ++++++++++++-- .../cli/services/configurators/__init__.py | 6 +- .../cli/services/configurators/base.py | 10 ++-- .../cli/services/configurators/fleet.py | 2 +- .../cli/services/configurators/gateway.py | 2 +- .../cli/services/configurators/run.py | 60 ++++++++++++------- .../cli/services/configurators/volume.py | 2 +- .../_internal/core/models/configurations.py | 27 +++++---- 9 files changed, 110 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8671d0601c..16ecc0e31e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ ignore-case = true [tool.pyright] include = [ "src/dstack/_internal/server", + "src/dstack/_internal/cli/services/configurators", ] ignore = [ "src/dstack/_internal/server/migrations/versions", diff --git a/src/dstack/_internal/cli/commands/offer.py b/src/dstack/_internal/cli/commands/offer.py index cb920ec275..21ddb71b62 100644 --- a/src/dstack/_internal/cli/commands/offer.py +++ b/src/dstack/_internal/cli/commands/offer.py @@ -4,9 +4,11 @@ from pathlib import Path from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec from dstack._internal.cli.services.configurators.run import ( BaseRunConfigurator, ) +from dstack._internal.cli.services.profile import register_profile_args from dstack._internal.cli.utils.common import console from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.models.configurations import ( @@ -27,11 +29,45 @@ class OfferConfigurator(BaseRunConfigurator): TYPE = ApplyConfigurationType.TASK @classmethod - def register_args( - cls, - parser: argparse.ArgumentParser, - ): - super().register_args(parser, default_max_offers=50) + def register_args(cls, parser: argparse.ArgumentParser): + configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options") + configuration_group.add_argument( + "-n", + "--name", + dest="run_name", + help="The name of the run. If not specified, a random name is assigned", + ) + configuration_group.add_argument( + "--max-offers", + help="Number of offers to show in the run plan", + type=int, + default=50, + ) + cls.register_env_args(configuration_group) + configuration_group.add_argument( + "--cpu", + type=cpu_spec, + help="Request CPU for the run. " + "The format is [code]ARCH[/]:[code]COUNT[/] (all parts are optional)", + dest="cpu_spec", + metavar="SPEC", + ) + configuration_group.add_argument( + "--gpu", + type=gpu_spec, + help="Request GPU for the run. " + "The format is [code]NAME[/]:[code]COUNT[/]:[code]MEMORY[/] (all parts are optional)", + dest="gpu_spec", + metavar="SPEC", + ) + configuration_group.add_argument( + "--disk", + type=disk_spec, + help="Request the size range of disk for the run. Example [code]--disk 100GB..[/].", + metavar="RANGE", + dest="disk_spec", + ) + register_profile_args(parser) # TODO: Support aggregated offers diff --git a/src/dstack/_internal/cli/services/configurators/__init__.py b/src/dstack/_internal/cli/services/configurators/__init__.py index cba23ee31a..94976e0505 100644 --- a/src/dstack/_internal/cli/services/configurators/__init__.py +++ b/src/dstack/_internal/cli/services/configurators/__init__.py @@ -1,6 +1,6 @@ import sys from pathlib import Path -from typing import Dict, Optional, Tuple, Type +from typing import Any, Dict, Optional, Tuple, Type import yaml @@ -24,7 +24,7 @@ APPLY_STDIN_NAME = "-" -apply_configurators_mapping: Dict[ApplyConfigurationType, Type[BaseApplyConfigurator]] = { +apply_configurators_mapping: Dict[ApplyConfigurationType, Type[BaseApplyConfigurator[Any]]] = { cls.TYPE: cls for cls in [ DevEnvironmentConfigurator, @@ -47,7 +47,7 @@ } -def get_apply_configurator_class(configurator_type: str) -> Type[BaseApplyConfigurator]: +def get_apply_configurator_class(configurator_type: str) -> Type[BaseApplyConfigurator[Any]]: return apply_configurators_mapping[ApplyConfigurationType(configurator_type)] diff --git a/src/dstack/_internal/cli/services/configurators/base.py b/src/dstack/_internal/cli/services/configurators/base.py index 39e34693eb..440a31d6c2 100644 --- a/src/dstack/_internal/cli/services/configurators/base.py +++ b/src/dstack/_internal/cli/services/configurators/base.py @@ -1,7 +1,7 @@ import argparse import os from abc import ABC, abstractmethod -from typing import List, Optional, Union, cast +from typing import Generic, List, Optional, TypeVar, Union, cast from dstack._internal.cli.services.args import env_var from dstack._internal.core.errors import ConfigurationError @@ -15,8 +15,10 @@ ArgsParser = Union[argparse._ArgumentGroup, argparse.ArgumentParser] +ApplyConfigurationT = TypeVar("ApplyConfigurationT", bound=AnyApplyConfiguration) -class BaseApplyConfigurator(ABC): + +class BaseApplyConfigurator(ABC, Generic[ApplyConfigurationT]): TYPE: ApplyConfigurationType def __init__(self, api_client: Client): @@ -25,7 +27,7 @@ def __init__(self, api_client: Client): @abstractmethod def apply_configuration( self, - conf: AnyApplyConfiguration, + conf: ApplyConfigurationT, configuration_path: str, command_args: argparse.Namespace, configurator_args: argparse.Namespace, @@ -48,7 +50,7 @@ def apply_configuration( @abstractmethod def delete_configuration( self, - conf: AnyApplyConfiguration, + conf: ApplyConfigurationT, configuration_path: str, command_args: argparse.Namespace, ): diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index cb9d7a2b87..6718f4f0f2 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -46,7 +46,7 @@ logger = get_logger(__name__) -class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator): +class FleetConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator[FleetConfiguration]): TYPE: ApplyConfigurationType = ApplyConfigurationType.FLEET def apply_configuration( diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py index 8651c79ce8..8a22277b17 100644 --- a/src/dstack/_internal/cli/services/configurators/gateway.py +++ b/src/dstack/_internal/cli/services/configurators/gateway.py @@ -27,7 +27,7 @@ from dstack.api._public import Client -class GatewayConfigurator(BaseApplyConfigurator): +class GatewayConfigurator(BaseApplyConfigurator[GatewayConfiguration]): TYPE: ApplyConfigurationType = ApplyConfigurationType.GATEWAY def apply_configuration( diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 6133999ef6..87f1c3a7df 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -3,7 +3,7 @@ import sys import time from pathlib import Path -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, TypeVar import gpuhunt from pydantic import parse_obj_as @@ -33,8 +33,7 @@ from dstack._internal.core.models.configurations import ( AnyRunConfiguration, ApplyConfigurationType, - BaseRunConfiguration, - BaseRunConfigurationWithPorts, + ConfigurationWithPortsParams, DevEnvironmentConfiguration, PortMapping, RunConfigurationType, @@ -63,13 +62,18 @@ logger = get_logger(__name__) +RunConfigurationT = TypeVar("RunConfigurationT", bound=AnyRunConfiguration) -class BaseRunConfigurator(ApplyEnvVarsConfiguratorMixin, BaseApplyConfigurator): + +class BaseRunConfigurator( + ApplyEnvVarsConfiguratorMixin, + BaseApplyConfigurator[RunConfigurationT], +): TYPE: ApplyConfigurationType def apply_configuration( self, - conf: BaseRunConfiguration, + conf: RunConfigurationT, configuration_path: str, command_args: argparse.Namespace, configurator_args: argparse.Namespace, @@ -89,7 +93,7 @@ def apply_configuration( "There are three options:\n" " - Run `dstack init` to initialize the current directory as a repo\n" " - Specify `--repo`\n" - " - Specify `--no-repo` to not use any repo and supress this warning" + " - Specify `--no-repo` to not use any repo and suppress this warning" " (this will be the default in the future versions)" ) if not command_args.yes and not confirm_ask("Continue without the repo?"): @@ -270,7 +274,7 @@ def apply_configuration( def delete_configuration( self, - conf: AnyRunConfiguration, + conf: RunConfigurationT, configuration_path: str, command_args: argparse.Namespace, ): @@ -296,7 +300,7 @@ def delete_configuration( console.print(f"Run [code]{conf.name}[/] deleted") @classmethod - def register_args(cls, parser: argparse.ArgumentParser, default_max_offers: int = 3): + def register_args(cls, parser: argparse.ArgumentParser): configuration_group = parser.add_argument_group(f"{cls.TYPE.value} Options") configuration_group.add_argument( "-n", @@ -308,7 +312,7 @@ def register_args(cls, parser: argparse.ArgumentParser, default_max_offers: int "--max-offers", help="Number of offers to show in the run plan", type=int, - default=default_max_offers, + default=3, ) cls.register_env_args(configuration_group) configuration_group.add_argument( @@ -336,7 +340,7 @@ def register_args(cls, parser: argparse.ArgumentParser, default_max_offers: int ) register_profile_args(parser) - def apply_args(self, conf: BaseRunConfiguration, args: argparse.Namespace, unknown: List[str]): + def apply_args(self, conf: RunConfigurationT, args: argparse.Namespace, unknown: List[str]): apply_profile_args(args, conf) if args.run_name: conf.name = args.run_name @@ -360,7 +364,7 @@ def interpolate_run_args(self, value: List[str], unknown): except InterpolatorError as e: raise ConfigurationError(e.args[0]) - def interpolate_env(self, conf: BaseRunConfiguration): + def interpolate_env(self, conf: RunConfigurationT): env_dict = conf.env.as_dict() interpolator = VariablesInterpolator({"env": env_dict}, skip=["secrets"]) try: @@ -380,7 +384,7 @@ def interpolate_env(self, conf: BaseRunConfiguration): except InterpolatorError as e: raise ConfigurationError(e.args[0]) - def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None: + def validate_gpu_vendor_and_image(self, conf: RunConfigurationT) -> None: """ Infers and sets `resources.gpu.vendor` if not set, requires `image` if the vendor is AMD. """ @@ -441,7 +445,7 @@ def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None: "`image` is required if `resources.gpu.vendor` is `tenstorrent`" ) - def validate_cpu_arch_and_image(self, conf: BaseRunConfiguration) -> None: + def validate_cpu_arch_and_image(self, conf: RunConfigurationT) -> None: """ Infers `resources.cpu.arch` if not set, requires `image` if the architecture is ARM. """ @@ -465,10 +469,9 @@ def validate_cpu_arch_and_image(self, conf: BaseRunConfiguration) -> None: raise ConfigurationError("`image` is required if `resources.cpu.arch` is `arm`") -class RunWithPortsConfigurator(BaseRunConfigurator): +class RunWithPortsConfiguratorMixin: @classmethod - def register_args(cls, parser: argparse.ArgumentParser): - super().register_args(parser) + def register_ports_args(cls, parser: argparse.ArgumentParser): parser.add_argument( "-p", "--port", @@ -485,29 +488,42 @@ def register_args(cls, parser: argparse.ArgumentParser): metavar="HOST", ) - def apply_args( - self, conf: BaseRunConfigurationWithPorts, args: argparse.Namespace, unknown: List[str] + def apply_ports_args( + self, + conf: ConfigurationWithPortsParams, + args: argparse.Namespace, ): - super().apply_args(conf, args, unknown) if args.ports: conf.ports = list(_merge_ports(conf.ports, args.ports).values()) -class TaskConfigurator(RunWithPortsConfigurator): +class TaskConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigurator): TYPE = ApplyConfigurationType.TASK + @classmethod + def register_args(cls, parser: argparse.ArgumentParser): + super().register_args(parser) + cls.register_ports_args(parser) + def apply_args(self, conf: TaskConfiguration, args: argparse.Namespace, unknown: List[str]): super().apply_args(conf, args, unknown) + self.apply_ports_args(conf, args) self.interpolate_run_args(conf.commands, unknown) -class DevEnvironmentConfigurator(RunWithPortsConfigurator): +class DevEnvironmentConfigurator(RunWithPortsConfiguratorMixin, BaseRunConfigurator): TYPE = ApplyConfigurationType.DEV_ENVIRONMENT + @classmethod + def register_args(cls, parser: argparse.ArgumentParser): + super().register_args(parser) + cls.register_ports_args(parser) + def apply_args( self, conf: DevEnvironmentConfiguration, args: argparse.Namespace, unknown: List[str] ): super().apply_args(conf, args, unknown) + self.apply_ports_args(conf, args) if conf.ide == "vscode" and conf.version is None: conf.version = _detect_vscode_version() if conf.version is None: @@ -677,6 +693,8 @@ def render_run_spec_diff(old_spec: RunSpec, new_spec: RunSpec) -> Optional[str]: if type(old_spec.profile) is not type(new_spec.profile): item = NestedListItem("Profile") else: + assert old_spec.profile is not None + assert new_spec.profile is not None item = NestedListItem( "Profile properties:", children=[ diff --git a/src/dstack/_internal/cli/services/configurators/volume.py b/src/dstack/_internal/cli/services/configurators/volume.py index 2a085477ed..72b21e5bb4 100644 --- a/src/dstack/_internal/cli/services/configurators/volume.py +++ b/src/dstack/_internal/cli/services/configurators/volume.py @@ -26,7 +26,7 @@ from dstack.api._public import Client -class VolumeConfigurator(BaseApplyConfigurator): +class VolumeConfigurator(BaseApplyConfigurator[VolumeConfiguration]): TYPE: ApplyConfigurationType = ApplyConfigurationType.VOLUME def apply_configuration( diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 4b44255d87..4b92d6f82b 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -466,7 +466,7 @@ def validate_shell(cls, v) -> Optional[str]: raise ValueError("The value must be `sh`, `bash`, or an absolute path") -class BaseRunConfigurationWithPorts(BaseRunConfiguration): +class ConfigurationWithPortsParams(CoreModel): ports: Annotated[ List[Union[ValidPort, constr(regex=r"^(?:[0-9]+|\*):[0-9]+$"), PortMapping]], Field(description="Port numbers/mapping to expose"), @@ -481,7 +481,7 @@ def convert_ports(cls, v) -> PortMapping: return v -class BaseRunConfigurationWithCommands(BaseRunConfiguration): +class ConfigurationWithCommandsParams(CoreModel): commands: Annotated[CommandsList, Field(description="The shell commands to run")] = [] @root_validator @@ -525,15 +525,18 @@ def parse_inactivity_duration( class DevEnvironmentConfiguration( - ProfileParams, BaseRunConfigurationWithPorts, DevEnvironmentConfigurationParams + ProfileParams, + BaseRunConfiguration, + ConfigurationWithPortsParams, + DevEnvironmentConfigurationParams, ): type: Literal["dev-environment"] = "dev-environment" - class Config(ProfileParams.Config, BaseRunConfigurationWithPorts.Config): + class Config(ProfileParams.Config, BaseRunConfiguration.Config): @staticmethod def schema_extra(schema: Dict[str, Any]): ProfileParams.Config.schema_extra(schema) - BaseRunConfigurationWithPorts.Config.schema_extra(schema) + BaseRunConfiguration.Config.schema_extra(schema) @validator("entrypoint") def validate_entrypoint(cls, v: Optional[str]) -> Optional[str]: @@ -548,8 +551,9 @@ class TaskConfigurationParams(CoreModel): class TaskConfiguration( ProfileParams, - BaseRunConfigurationWithCommands, - BaseRunConfigurationWithPorts, + BaseRunConfiguration, + ConfigurationWithCommandsParams, + ConfigurationWithPortsParams, TaskConfigurationParams, ): type: Literal["task"] = "task" @@ -697,19 +701,22 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: class ServiceConfiguration( - ProfileParams, BaseRunConfigurationWithCommands, ServiceConfigurationParams + ProfileParams, + BaseRunConfiguration, + ConfigurationWithCommandsParams, + ServiceConfigurationParams, ): type: Literal["service"] = "service" class Config( ProfileParams.Config, - BaseRunConfigurationWithCommands.Config, + BaseRunConfiguration.Config, ServiceConfigurationParams.Config, ): @staticmethod def schema_extra(schema: Dict[str, Any]): ProfileParams.Config.schema_extra(schema) - BaseRunConfigurationWithCommands.Config.schema_extra(schema) + BaseRunConfiguration.Config.schema_extra(schema) ServiceConfigurationParams.Config.schema_extra(schema) From 5d251a2d0d81ee07c75249891176dfd893b19560 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 17:15:44 +0500 Subject: [PATCH 48/54] Fix unbound spec_json --- src/dstack/plugins/builtin/rest_plugin/_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dstack/plugins/builtin/rest_plugin/_plugin.py b/src/dstack/plugins/builtin/rest_plugin/_plugin.py index 1a094147ec..210dd50e19 100644 --- a/src/dstack/plugins/builtin/rest_plugin/_plugin.py +++ b/src/dstack/plugins/builtin/rest_plugin/_plugin.py @@ -86,6 +86,7 @@ def _on_apply( spec: ApplySpec, excludes: Optional[Dict] = None, ) -> ApplySpec: + spec_json = None try: spec_request = request_cls(user=user, project=project, spec=spec) spec_json = self._call_plugin_service(spec_request, endpoint, excludes) From 0ae300c11ee7edbd5390563711a9c615f396cca8 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 17:17:09 +0500 Subject: [PATCH 49/54] Remove huggingface api --- src/dstack/api/huggingface/__init__.py | 73 -------------------------- 1 file changed, 73 deletions(-) delete mode 100644 src/dstack/api/huggingface/__init__.py diff --git a/src/dstack/api/huggingface/__init__.py b/src/dstack/api/huggingface/__init__.py deleted file mode 100644 index 83e5491172..0000000000 --- a/src/dstack/api/huggingface/__init__.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Dict, Optional - -from dstack.api._public.huggingface.finetuning.sft import FineTuningTask - - -class SFTFineTuningTask(FineTuningTask): - def __init__( - self, - model_name: str, - dataset_name: str, - env: Dict[str, str], - new_model_name: Optional[str] = None, - report_to: Optional[str] = None, - per_device_train_batch_size: int = 4, - per_device_eval_batch_size: int = 4, - gradient_accumulation_steps: int = 1, - learning_rate: float = 2e-4, - max_grad_norm: float = 0.3, - weight_decay: float = 0.001, - lora_alpha: int = 16, - lora_dropout: float = 0.1, - lora_r: int = 64, - max_seq_length: Optional[int] = None, - use_4bit: bool = True, - use_nested_quant: bool = True, - bnb_4bit_compute_dtype: str = "float16", - bnb_4bit_quant_type: str = "nf4", - num_train_epochs: float = 1, - fp16: bool = False, - bf16: bool = False, - packing: bool = False, - gradient_checkpointing: bool = True, - optim: str = "paged_adamw_32bit", - lr_scheduler_type: str = "constant", - max_steps: int = -1, - warmup_ratio: float = 0.03, - group_by_length: bool = True, - save_steps: int = 0, - logging_steps: int = 25, - ): - super().__init__( - model_name, - dataset_name, - new_model_name, - env, - report_to, - per_device_train_batch_size, - per_device_eval_batch_size, - gradient_accumulation_steps, - learning_rate, - max_grad_norm, - weight_decay, - lora_alpha, - lora_dropout, - lora_r, - max_seq_length, - use_4bit, - use_nested_quant, - bnb_4bit_compute_dtype, - bnb_4bit_quant_type, - num_train_epochs, - fp16, - bf16, - packing, - gradient_checkpointing, - optim, - lr_scheduler_type, - max_steps, - warmup_ratio, - group_by_length, - save_steps, - logging_steps, - ) From 9fa0937c845eae8656be92745598068245de9493 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 17:19:11 +0500 Subject: [PATCH 50/54] Type check plugins --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 16ecc0e31e..5141346ebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ ignore-case = true [tool.pyright] include = [ + "src/dstack/plugins", "src/dstack/_internal/server", "src/dstack/_internal/cli/services/configurators", ] From 95ea41cdc214e91166989d57a9c7ad89b48902ee Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 17:24:24 +0500 Subject: [PATCH 51/54] Fix BaseApplyConfigurator generics --- .../_internal/cli/services/configurators/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/cli/services/configurators/__init__.py b/src/dstack/_internal/cli/services/configurators/__init__.py index 94976e0505..91768bdcd3 100644 --- a/src/dstack/_internal/cli/services/configurators/__init__.py +++ b/src/dstack/_internal/cli/services/configurators/__init__.py @@ -1,6 +1,6 @@ import sys from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Type +from typing import Dict, Optional, Tuple, Type import yaml @@ -24,7 +24,9 @@ APPLY_STDIN_NAME = "-" -apply_configurators_mapping: Dict[ApplyConfigurationType, Type[BaseApplyConfigurator[Any]]] = { +apply_configurators_mapping: Dict[ + ApplyConfigurationType, Type[BaseApplyConfigurator[AnyApplyConfiguration]] +] = { cls.TYPE: cls for cls in [ DevEnvironmentConfigurator, @@ -47,7 +49,9 @@ } -def get_apply_configurator_class(configurator_type: str) -> Type[BaseApplyConfigurator[Any]]: +def get_apply_configurator_class( + configurator_type: str, +) -> Type[BaseApplyConfigurator[AnyApplyConfiguration]]: return apply_configurators_mapping[ApplyConfigurationType(configurator_type)] From b0230f4396a6cac9a88ae4f44806f63fa8e6862c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 20 Aug 2025 17:30:44 +0500 Subject: [PATCH 52/54] Type check core/services --- pyproject.toml | 1 + src/dstack/_internal/core/services/repos.py | 6 +++--- src/dstack/_internal/core/services/ssh/ports.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5141346ebb..6d184cbfe4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ ignore-case = true include = [ "src/dstack/plugins", "src/dstack/_internal/server", + "src/dstack/_internal/core/services", "src/dstack/_internal/cli/services/configurators", ] ignore = [ diff --git a/src/dstack/_internal/core/services/repos.py b/src/dstack/_internal/core/services/repos.py index bd6026d11b..61ff9b3abb 100644 --- a/src/dstack/_internal/core/services/repos.py +++ b/src/dstack/_internal/core/services/repos.py @@ -84,7 +84,7 @@ def get_local_repo_credentials( def check_remote_repo_credentials_https(url: GitRepoURL, oauth_token: str) -> RemoteRepoCreds: try: - git.cmd.Git().ls_remote(url.as_https(oauth_token), env=dict(GIT_TERMINAL_PROMPT="0")) + git.cmd.Git().ls_remote(url.as_https(oauth_token), env=dict(GIT_TERMINAL_PROMPT="0")) # type: ignore[attr-defined] except GitCommandError: masked = len(oauth_token[:-4]) * "*" + oauth_token[-4:] raise InvalidRepoCredentialsError( @@ -111,7 +111,7 @@ def check_remote_repo_credentials_ssh(url: GitRepoURL, identity_file: PathLike) private_key = f.read() try: - git.cmd.Git().ls_remote( + git.cmd.Git().ls_remote( # type: ignore[attr-defined] url.as_ssh(), env=dict(GIT_SSH_COMMAND=make_ssh_command_for_git(identity_file)) ) except GitCommandError: @@ -131,7 +131,7 @@ def get_default_branch(remote_url: str) -> Optional[str]: Get the default branch of a remote Git repository. """ try: - output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD") + output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD") # type: ignore[attr-defined] for line in output.splitlines(): if line.startswith("ref:"): return line.split()[1].split("/")[-1] diff --git a/src/dstack/_internal/core/services/ssh/ports.py b/src/dstack/_internal/core/services/ssh/ports.py index 1462958fe7..f0716e6158 100644 --- a/src/dstack/_internal/core/services/ssh/ports.py +++ b/src/dstack/_internal/core/services/ssh/ports.py @@ -74,7 +74,7 @@ def _listen(port: int) -> Optional[socket.socket]: try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if IS_WINDOWS: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) # type: ignore[attr-defined] sock.bind(("", port)) return sock except socket.error as e: From 1db1121ad0058af592b94ebb0a13600975fdf4bd Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 21 Aug 2025 11:14:56 +0500 Subject: [PATCH 53/54] Fix services.gpus typing --- src/dstack/_internal/server/routers/gpus.py | 7 +- src/dstack/_internal/server/services/gpus.py | 95 ++++++++++---------- 2 files changed, 49 insertions(+), 53 deletions(-) diff --git a/src/dstack/_internal/server/routers/gpus.py b/src/dstack/_internal/server/routers/gpus.py index 521ace1594..45f0e8bf1f 100644 --- a/src/dstack/_internal/server/routers/gpus.py +++ b/src/dstack/_internal/server/routers/gpus.py @@ -1,9 +1,7 @@ from typing import Tuple from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.gpus import ListGpusRequest, ListGpusResponse from dstack._internal.server.security.permissions import ProjectMember @@ -20,10 +18,7 @@ @project_router.post("/list", response_model=ListGpusResponse, response_model_exclude_none=True) async def list_gpus( body: ListGpusRequest, - session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> ListGpusResponse: _, project = user_project - return await list_gpus_grouped( - session=session, project=project, run_spec=body.run_spec, group_by=body.group_by - ) + return await list_gpus_grouped(project=project, run_spec=body.run_spec, group_by=body.group_by) diff --git a/src/dstack/_internal/server/services/gpus.py b/src/dstack/_internal/server/services/gpus.py index 0ec347be00..c2ddcd2fd8 100644 --- a/src/dstack/_internal/server/services/gpus.py +++ b/src/dstack/_internal/server/services/gpus.py @@ -1,8 +1,8 @@ from typing import Dict, List, Literal, Optional, Tuple -from sqlalchemy.ext.asyncio import AsyncSession - from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceOfferWithAvailability from dstack._internal.core.models.profiles import SpotPolicy from dstack._internal.core.models.resources import Range @@ -15,10 +15,43 @@ ListGpusResponse, ) from dstack._internal.server.services.offers import get_offers_by_requirements +from dstack._internal.utils.common import get_or_error + + +async def list_gpus_grouped( + project: ProjectModel, + run_spec: RunSpec, + group_by: Optional[List[Literal["backend", "region", "count"]]] = None, +) -> ListGpusResponse: + """Retrieves available GPU specifications based on a run spec, with optional grouping.""" + offers = await _get_gpu_offers(project=project, run_spec=run_spec) + backend_gpus = _process_offers_into_backend_gpus(offers) + group_by_set = set(group_by) if group_by else set() + if "region" in group_by_set and "backend" not in group_by_set: + raise ServerClientError("Cannot group by 'region' without also grouping by 'backend'") + + # Determine grouping strategy based on combination + has_backend = "backend" in group_by_set + has_region = "region" in group_by_set + has_count = "count" in group_by_set + if has_backend and has_region and has_count: + gpus = _get_gpus_grouped_by_backend_region_and_count(backend_gpus) + elif has_backend and has_count: + gpus = _get_gpus_grouped_by_backend_and_count(backend_gpus) + elif has_backend and has_region: + gpus = _get_gpus_grouped_by_backend_and_region(backend_gpus) + elif has_backend: + gpus = _get_gpus_grouped_by_backend(backend_gpus) + elif has_count: + gpus = _get_gpus_grouped_by_count(backend_gpus) + else: + gpus = _get_gpus_with_no_grouping(backend_gpus) + + return ListGpusResponse(gpus=gpus) async def _get_gpu_offers( - session: AsyncSession, project: ProjectModel, run_spec: RunSpec + project: ProjectModel, run_spec: RunSpec ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: """Fetches all available instance offers that match the run spec's GPU requirements.""" profile = run_spec.merged_profile @@ -28,7 +61,6 @@ async def _get_gpu_offers( spot=get_policy_map(profile.spot_policy, default=SpotPolicy.AUTO), reservation=profile.reservation, ) - return await get_offers_by_requirements( project=project, profile=profile, @@ -45,7 +77,7 @@ def _process_offers_into_backend_gpus( offers: List[Tuple[Backend, InstanceOfferWithAvailability]], ) -> List[BackendGpus]: """Transforms raw offers into a structured list of BackendGpus, aggregating GPU info.""" - backend_data: Dict[str, Dict] = {} + backend_data: Dict[BackendType, Dict] = {} for backend, offer in offers: backend_type = backend.TYPE @@ -111,7 +143,7 @@ def _process_offers_into_backend_gpus( return backend_gpus_list -def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: str): +def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: BackendType): """Updates an existing GpuGroup with new data from another GPU offer.""" spot_type: Literal["spot", "on-demand"] = "spot" if gpu.spot else "on-demand" @@ -122,6 +154,12 @@ def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: str): if row.backends and backend_type not in row.backends: row.backends.append(backend_type) + # FIXME: Consider using non-optional range + assert row.count.min is not None + assert row.count.max is not None + assert row.price.min is not None + assert row.price.max is not None + row.count.min = min(row.count.min, gpu.count) row.count.max = max(row.count.max, gpu.count) per_gpu_price = gpu.price / gpu.count @@ -194,7 +232,7 @@ def _get_gpus_grouped_by_backend(backend_gpus: List[BackendGpus]) -> List[GpuGro not any(av.is_available() for av in g.availability), g.price.min, g.price.max, - g.backend.value, + get_or_error(g.backend).value, g.name, g.memory_mib, ), @@ -229,7 +267,7 @@ def _get_gpus_grouped_by_backend_and_region(backend_gpus: List[BackendGpus]) -> not any(av.is_available() for av in g.availability), g.price.min, g.price.max, - g.backend.value, + get_or_error(g.backend).value, g.region, g.name, g.memory_mib, @@ -299,7 +337,7 @@ def _get_gpus_grouped_by_backend_and_count(backend_gpus: List[BackendGpus]) -> L not any(av.is_available() for av in g.availability), g.price.min, g.price.max, - g.backend.value, + get_or_error(g.backend).value, g.count.min, g.name, g.memory_mib, @@ -344,47 +382,10 @@ def _get_gpus_grouped_by_backend_region_and_count( not any(av.is_available() for av in g.availability), g.price.min, g.price.max, - g.backend.value, + get_or_error(g.backend).value, g.region, g.count.min, g.name, g.memory_mib, ), ) - - -async def list_gpus_grouped( - session: AsyncSession, - project: ProjectModel, - run_spec: RunSpec, - group_by: Optional[List[Literal["backend", "region", "count"]]] = None, -) -> ListGpusResponse: - """Retrieves available GPU specifications based on a run spec, with optional grouping.""" - offers = await _get_gpu_offers(session, project, run_spec) - backend_gpus = _process_offers_into_backend_gpus(offers) - - group_by_set = set(group_by) if group_by else set() - - if "region" in group_by_set and "backend" not in group_by_set: - from dstack._internal.core.errors import ServerClientError - - raise ServerClientError("Cannot group by 'region' without also grouping by 'backend'") - - # Determine grouping strategy based on combination - has_backend = "backend" in group_by_set - has_region = "region" in group_by_set - has_count = "count" in group_by_set - if has_backend and has_region and has_count: - gpus = _get_gpus_grouped_by_backend_region_and_count(backend_gpus) - elif has_backend and has_count: - gpus = _get_gpus_grouped_by_backend_and_count(backend_gpus) - elif has_backend and has_region: - gpus = _get_gpus_grouped_by_backend_and_region(backend_gpus) - elif has_backend: - gpus = _get_gpus_grouped_by_backend(backend_gpus) - elif has_count: - gpus = _get_gpus_grouped_by_count(backend_gpus) - else: - gpus = _get_gpus_with_no_grouping(backend_gpus) - - return ListGpusResponse(gpus=gpus) From 8f51c0577d54fbdd2f2a2cdaf4121f995c92f91d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 21 Aug 2025 12:06:29 +0500 Subject: [PATCH 54/54] Document pyright in Contributing --- contributing/DEVELOPMENT.md | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/contributing/DEVELOPMENT.md b/contributing/DEVELOPMENT.md index 775a0691e4..4428f811b3 100644 --- a/contributing/DEVELOPMENT.md +++ b/contributing/DEVELOPMENT.md @@ -25,12 +25,26 @@ uv sync --all-extras Alternatively, if you want to manage virtual environments by yourself, you can install `dstack` into the activated virtual environment with `uv sync --all-extras --active`. -## 4. (Recommended) Install pre-commits: +## 4. (Recommended) Install pre-commit hooks: + +Code formatting and linting can be done automatically on each commit with `pre-commit` hooks: ```shell uv run pre-commit install ``` -## 5. Frontend +## 5. (Recommended) Use pyright: + +The CI runs `pyright` for type checking `dstack` Python code. +So we recommend you configure your IDE to use `pyright`/`pylance` with `standard` type checking mode. + +You can also install `pyright` and run it from the CLI: + +```shell +uv tool install pyright +pyright -p . +``` + +## 6. Frontend See [FRONTEND.md](FRONTEND.md) for the details on how to build and develop the frontend.