From 4c3411b61b93405f16a6ac432c87eaab4c70b570 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 21 Aug 2025 14:27:36 +0500 Subject: [PATCH] Make Configurator generic --- .../core/backends/aws/configurator.py | 18 +++++--- .../core/backends/azure/configurator.py | 18 +++++--- .../core/backends/base/configurator.py | 45 +++++++++---------- .../core/backends/cloudrift/configurator.py | 20 ++++++--- .../core/backends/cudo/configurator.py | 18 +++++--- .../core/backends/datacrunch/configurator.py | 20 ++++++--- .../core/backends/gcp/configurator.py | 18 +++++--- .../core/backends/hotaisle/configurator.py | 20 ++++++--- .../core/backends/kubernetes/configurator.py | 20 ++++++--- .../core/backends/lambdalabs/configurator.py | 18 +++++--- .../core/backends/nebius/configurator.py | 18 +++++--- .../core/backends/oci/configurator.py | 18 +++++--- .../core/backends/runpod/configurator.py | 18 +++++--- .../backends/template/configurator.py.jinja | 18 +++++--- .../core/backends/tensordock/configurator.py | 20 ++++++--- .../core/backends/vastai/configurator.py | 18 +++++--- .../core/backends/vultr/configurator.py | 15 +++++-- .../server/services/backends/__init__.py | 4 +- 18 files changed, 210 insertions(+), 134 deletions(-) diff --git a/src/dstack/_internal/core/backends/aws/configurator.py b/src/dstack/_internal/core/backends/aws/configurator.py index 894b7f573a..059c65098e 100644 --- a/src/dstack/_internal/core/backends/aws/configurator.py +++ b/src/dstack/_internal/core/backends/aws/configurator.py @@ -7,7 +7,6 @@ from dstack._internal.core.backends.aws import auth, compute, resources from dstack._internal.core.backends.aws.backend import AWSBackend from dstack._internal.core.backends.aws.models import ( - AnyAWSBackendConfig, AWSAccessKeyCreds, AWSBackendConfig, AWSBackendConfigWithCreds, @@ -52,7 +51,12 @@ MAIN_REGION = "us-east-1" -class AWSConfigurator(Configurator): +class AWSConfigurator( + Configurator[ + AWSBackendConfig, + AWSBackendConfigWithCreds, + ] +): TYPE = BackendType.AWS BACKEND_CLASS = AWSBackend @@ -87,12 +91,12 @@ def create_backend( auth=AWSCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyAWSBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> AWSBackendConfigWithCreds: + config = self._get_config(record) + return AWSBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> AWSBackendConfig: config = self._get_config(record) - if include_creds: - return AWSBackendConfigWithCreds.__response__.parse_obj(config) return AWSBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> AWSBackend: diff --git a/src/dstack/_internal/core/backends/azure/configurator.py b/src/dstack/_internal/core/backends/azure/configurator.py index b25244de8b..a44d897ea2 100644 --- a/src/dstack/_internal/core/backends/azure/configurator.py +++ b/src/dstack/_internal/core/backends/azure/configurator.py @@ -24,7 +24,6 @@ from dstack._internal.core.backends.azure import utils as azure_utils from dstack._internal.core.backends.azure.backend import AzureBackend from dstack._internal.core.backends.azure.models import ( - AnyAzureBackendConfig, AzureBackendConfig, AzureBackendConfigWithCreds, AzureClientCreds, @@ -71,7 +70,12 @@ MAIN_LOCATION = "eastus" -class AzureConfigurator(Configurator): +class AzureConfigurator( + Configurator[ + AzureBackendConfig, + AzureBackendConfigWithCreds, + ] +): TYPE = BackendType.AZURE BACKEND_CLASS = AzureBackend @@ -130,12 +134,12 @@ def create_backend( auth=AzureCreds.parse_obj(config.creds).__root__.json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyAzureBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> AzureBackendConfigWithCreds: + config = self._get_config(record) + return AzureBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> AzureBackendConfig: config = self._get_config(record) - if include_creds: - return AzureBackendConfigWithCreds.__response__.parse_obj(config) return AzureBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> AzureBackend: diff --git a/src/dstack/_internal/core/backends/base/configurator.py b/src/dstack/_internal/core/backends/base/configurator.py index f31e978a31..246d3d6118 100644 --- a/src/dstack/_internal/core/backends/base/configurator.py +++ b/src/dstack/_internal/core/backends/base/configurator.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, ClassVar, List, Literal, Optional, overload +from typing import Any, ClassVar, Generic, List, Optional, TypeVar from uuid import UUID from dstack._internal.core.backends.base.backend import Backend from dstack._internal.core.backends.models import ( - AnyBackendConfig, AnyBackendConfigWithCreds, AnyBackendConfigWithoutCreds, ) @@ -16,6 +15,11 @@ # We'll introduce our own base limit that can be customized per backend if required. TAGS_MAX_NUM = 25 +BackendConfigWithoutCredsT = TypeVar( + "BackendConfigWithoutCredsT", bound=AnyBackendConfigWithoutCreds +) +BackendConfigWithCredsT = TypeVar("BackendConfigWithCredsT", bound=AnyBackendConfigWithCreds) + class BackendRecord(CoreModel): """ @@ -40,7 +44,7 @@ class StoredBackendRecord(BackendRecord): backend_id: UUID -class Configurator(ABC): +class Configurator(ABC, Generic[BackendConfigWithoutCredsT, BackendConfigWithCredsT]): """ `Configurator` is responsible for configuring backends and initializing `Backend` instances from backend configs. @@ -53,7 +57,7 @@ class Configurator(ABC): BACKEND_CLASS: ClassVar[type[Backend]] @abstractmethod - def validate_config(self, config: AnyBackendConfigWithCreds, default_creds_enabled: bool): + def validate_config(self, config: BackendConfigWithCredsT, default_creds_enabled: bool): """ Validates backend config including backend creds and other parameters. Raises `ServerClientError` or its subclass if config is invalid. @@ -62,9 +66,7 @@ def validate_config(self, config: AnyBackendConfigWithCreds, default_creds_enabl pass @abstractmethod - def create_backend( - self, project_name: str, config: AnyBackendConfigWithCreds - ) -> BackendRecord: + def create_backend(self, project_name: str, config: BackendConfigWithCredsT) -> BackendRecord: """ Sets up backend given backend config and returns text-encoded config and creds to be stored in the DB. @@ -78,26 +80,23 @@ def create_backend( """ pass - @overload - def get_backend_config( - self, record: StoredBackendRecord, include_creds: Literal[False] - ) -> AnyBackendConfigWithoutCreds: - pass - - @overload - def get_backend_config( - self, record: StoredBackendRecord, include_creds: Literal[True] - ) -> AnyBackendConfigWithCreds: + @abstractmethod + def get_backend_config_with_creds( + self, record: StoredBackendRecord + ) -> BackendConfigWithCredsT: + """ + Constructs `BackendConfig` with credentials included. + Used internally and when project admins need to see backend's creds. + """ pass @abstractmethod - def get_backend_config( - self, record: StoredBackendRecord, include_creds: bool - ) -> AnyBackendConfig: + def get_backend_config_without_creds( + self, record: StoredBackendRecord + ) -> BackendConfigWithoutCredsT: """ - Constructs `BackendConfig` to be returned in API responses. - Project admins may need to see backend's creds. In this case `include_creds` will be `True`. - Otherwise, no sensitive information should be included. + Constructs `BackendConfig` without sensitive information. + Used for API responses where creds should not be exposed. """ pass diff --git a/src/dstack/_internal/core/backends/cloudrift/configurator.py b/src/dstack/_internal/core/backends/cloudrift/configurator.py index 62410b0ecd..b6097d1654 100644 --- a/src/dstack/_internal/core/backends/cloudrift/configurator.py +++ b/src/dstack/_internal/core/backends/cloudrift/configurator.py @@ -8,7 +8,6 @@ from dstack._internal.core.backends.cloudrift.api_client import RiftClient from dstack._internal.core.backends.cloudrift.backend import CloudRiftBackend from dstack._internal.core.backends.cloudrift.models import ( - AnyCloudRiftBackendConfig, AnyCloudRiftCreds, CloudRiftBackendConfig, CloudRiftBackendConfigWithCreds, @@ -21,7 +20,12 @@ ) -class CloudRiftConfigurator(Configurator): +class CloudRiftConfigurator( + Configurator[ + CloudRiftBackendConfig, + CloudRiftBackendConfigWithCreds, + ] +): TYPE = BackendType.CLOUDRIFT BACKEND_CLASS = CloudRiftBackend @@ -40,12 +44,14 @@ def create_backend( auth=CloudRiftCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyCloudRiftBackendConfig: + def get_backend_config_with_creds( + self, record: BackendRecord + ) -> CloudRiftBackendConfigWithCreds: + config = self._get_config(record) + return CloudRiftBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> CloudRiftBackendConfig: config = self._get_config(record) - if include_creds: - return CloudRiftBackendConfigWithCreds.__response__.parse_obj(config) return CloudRiftBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> CloudRiftBackend: diff --git a/src/dstack/_internal/core/backends/cudo/configurator.py b/src/dstack/_internal/core/backends/cudo/configurator.py index 2f435a7826..a4d5b7d94d 100644 --- a/src/dstack/_internal/core/backends/cudo/configurator.py +++ b/src/dstack/_internal/core/backends/cudo/configurator.py @@ -8,7 +8,6 @@ from dstack._internal.core.backends.cudo import api_client from dstack._internal.core.backends.cudo.backend import CudoBackend from dstack._internal.core.backends.cudo.models import ( - AnyCudoBackendConfig, CudoBackendConfig, CudoBackendConfigWithCreds, CudoConfig, @@ -18,7 +17,12 @@ from dstack._internal.core.models.backends.base import BackendType -class CudoConfigurator(Configurator): +class CudoConfigurator( + Configurator[ + CudoBackendConfig, + CudoBackendConfigWithCreds, + ] +): TYPE = BackendType.CUDO BACKEND_CLASS = CudoBackend @@ -35,12 +39,12 @@ def create_backend( auth=CudoCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyCudoBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> CudoBackendConfigWithCreds: + config = self._get_config(record) + return CudoBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> CudoBackendConfig: config = self._get_config(record) - if include_creds: - return CudoBackendConfigWithCreds.__response__.parse_obj(config) return CudoBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> CudoBackend: diff --git a/src/dstack/_internal/core/backends/datacrunch/configurator.py b/src/dstack/_internal/core/backends/datacrunch/configurator.py index fde42a6412..f31d5a69a1 100644 --- a/src/dstack/_internal/core/backends/datacrunch/configurator.py +++ b/src/dstack/_internal/core/backends/datacrunch/configurator.py @@ -10,7 +10,6 @@ ) from dstack._internal.core.backends.datacrunch.backend import DataCrunchBackend from dstack._internal.core.backends.datacrunch.models import ( - AnyDataCrunchBackendConfig, DataCrunchBackendConfig, DataCrunchBackendConfigWithCreds, DataCrunchConfig, @@ -22,7 +21,12 @@ ) -class DataCrunchConfigurator(Configurator): +class DataCrunchConfigurator( + Configurator[ + DataCrunchBackendConfig, + DataCrunchBackendConfigWithCreds, + ] +): TYPE = BackendType.DATACRUNCH BACKEND_CLASS = DataCrunchBackend @@ -41,12 +45,14 @@ def create_backend( auth=DataCrunchCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyDataCrunchBackendConfig: + def get_backend_config_with_creds( + self, record: BackendRecord + ) -> DataCrunchBackendConfigWithCreds: + config = self._get_config(record) + return DataCrunchBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> DataCrunchBackendConfig: config = self._get_config(record) - if include_creds: - return DataCrunchBackendConfigWithCreds.__response__.parse_obj(config) return DataCrunchBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> DataCrunchBackend: diff --git a/src/dstack/_internal/core/backends/gcp/configurator.py b/src/dstack/_internal/core/backends/gcp/configurator.py index f034a869dc..f59decd122 100644 --- a/src/dstack/_internal/core/backends/gcp/configurator.py +++ b/src/dstack/_internal/core/backends/gcp/configurator.py @@ -11,7 +11,6 @@ from dstack._internal.core.backends.gcp import auth, resources from dstack._internal.core.backends.gcp.backend import GCPBackend from dstack._internal.core.backends.gcp.models import ( - AnyGCPBackendConfig, GCPBackendConfig, GCPBackendConfigWithCreds, GCPConfig, @@ -109,7 +108,12 @@ MAIN_REGION = "us-east1" -class GCPConfigurator(Configurator): +class GCPConfigurator( + Configurator[ + GCPBackendConfig, + GCPBackendConfigWithCreds, + ] +): TYPE = BackendType.GCP BACKEND_CLASS = GCPBackend @@ -147,12 +151,12 @@ def create_backend( auth=GCPCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyGCPBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> GCPBackendConfigWithCreds: + config = self._get_config(record) + return GCPBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> GCPBackendConfig: config = self._get_config(record) - if include_creds: - return GCPBackendConfigWithCreds.__response__.parse_obj(config) return GCPBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> GCPBackend: diff --git a/src/dstack/_internal/core/backends/hotaisle/configurator.py b/src/dstack/_internal/core/backends/hotaisle/configurator.py index c7a6a6006e..8f7a6f537f 100644 --- a/src/dstack/_internal/core/backends/hotaisle/configurator.py +++ b/src/dstack/_internal/core/backends/hotaisle/configurator.py @@ -7,7 +7,6 @@ from dstack._internal.core.backends.hotaisle.api_client import HotAisleAPIClient from dstack._internal.core.backends.hotaisle.backend import HotAisleBackend from dstack._internal.core.backends.hotaisle.models import ( - AnyHotAisleBackendConfig, AnyHotAisleCreds, HotAisleBackendConfig, HotAisleBackendConfigWithCreds, @@ -20,7 +19,12 @@ ) -class HotAisleConfigurator(Configurator): +class HotAisleConfigurator( + Configurator[ + HotAisleBackendConfig, + HotAisleBackendConfigWithCreds, + ] +): TYPE = BackendType.HOTAISLE BACKEND_CLASS = HotAisleBackend @@ -37,12 +41,14 @@ def create_backend( auth=HotAisleCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyHotAisleBackendConfig: + def get_backend_config_with_creds( + self, record: BackendRecord + ) -> HotAisleBackendConfigWithCreds: + config = self._get_config(record) + return HotAisleBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> HotAisleBackendConfig: config = self._get_config(record) - if include_creds: - return HotAisleBackendConfigWithCreds.__response__.parse_obj(config) return HotAisleBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> HotAisleBackend: diff --git a/src/dstack/_internal/core/backends/kubernetes/configurator.py b/src/dstack/_internal/core/backends/kubernetes/configurator.py index 5dc99c22dc..93c9965362 100644 --- a/src/dstack/_internal/core/backends/kubernetes/configurator.py +++ b/src/dstack/_internal/core/backends/kubernetes/configurator.py @@ -6,7 +6,6 @@ from dstack._internal.core.backends.kubernetes import utils as kubernetes_utils from dstack._internal.core.backends.kubernetes.backend import KubernetesBackend from dstack._internal.core.backends.kubernetes.models import ( - AnyKubernetesBackendConfig, KubernetesBackendConfig, KubernetesBackendConfigWithCreds, KubernetesConfig, @@ -18,7 +17,12 @@ logger = get_logger(__name__) -class KubernetesConfigurator(Configurator): +class KubernetesConfigurator( + Configurator[ + KubernetesBackendConfig, + KubernetesBackendConfigWithCreds, + ] +): TYPE = BackendType.KUBERNETES BACKEND_CLASS = KubernetesBackend @@ -40,12 +44,14 @@ def create_backend( auth="", ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyKubernetesBackendConfig: + def get_backend_config_with_creds( + self, record: BackendRecord + ) -> KubernetesBackendConfigWithCreds: + config = self._get_config(record) + return KubernetesBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> KubernetesBackendConfig: config = self._get_config(record) - if include_creds: - return KubernetesBackendConfigWithCreds.__response__.parse_obj(config) return KubernetesBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> KubernetesBackend: diff --git a/src/dstack/_internal/core/backends/lambdalabs/configurator.py b/src/dstack/_internal/core/backends/lambdalabs/configurator.py index c4eb73041a..7c99cb2139 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/configurator.py +++ b/src/dstack/_internal/core/backends/lambdalabs/configurator.py @@ -8,7 +8,6 @@ from dstack._internal.core.backends.lambdalabs import api_client from dstack._internal.core.backends.lambdalabs.backend import LambdaBackend from dstack._internal.core.backends.lambdalabs.models import ( - AnyLambdaBackendConfig, LambdaBackendConfig, LambdaBackendConfigWithCreds, LambdaConfig, @@ -20,7 +19,12 @@ ) -class LambdaConfigurator(Configurator): +class LambdaConfigurator( + Configurator[ + LambdaBackendConfig, + LambdaBackendConfigWithCreds, + ] +): TYPE = BackendType.LAMBDA BACKEND_CLASS = LambdaBackend @@ -37,12 +41,12 @@ def create_backend( auth=LambdaCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyLambdaBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> LambdaBackendConfigWithCreds: + config = self._get_config(record) + return LambdaBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> LambdaBackendConfig: config = self._get_config(record) - if include_creds: - return LambdaBackendConfigWithCreds.__response__.parse_obj(config) return LambdaBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> LambdaBackend: diff --git a/src/dstack/_internal/core/backends/nebius/configurator.py b/src/dstack/_internal/core/backends/nebius/configurator.py index 39f86bbcd1..331e153445 100644 --- a/src/dstack/_internal/core/backends/nebius/configurator.py +++ b/src/dstack/_internal/core/backends/nebius/configurator.py @@ -11,7 +11,6 @@ from dstack._internal.core.backends.nebius.backend import NebiusBackend from dstack._internal.core.backends.nebius.fabrics import get_all_infiniband_fabrics from dstack._internal.core.backends.nebius.models import ( - AnyNebiusBackendConfig, NebiusBackendConfig, NebiusBackendConfigWithCreds, NebiusConfig, @@ -22,7 +21,12 @@ from dstack._internal.core.models.backends.base import BackendType -class NebiusConfigurator(Configurator): +class NebiusConfigurator( + Configurator[ + NebiusBackendConfig, + NebiusBackendConfigWithCreds, + ] +): TYPE = BackendType.NEBIUS BACKEND_CLASS = NebiusBackend @@ -60,12 +64,12 @@ def create_backend( auth=NebiusCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyNebiusBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> NebiusBackendConfigWithCreds: + config = self._get_config(record) + return NebiusBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> NebiusBackendConfig: config = self._get_config(record) - if include_creds: - return NebiusBackendConfigWithCreds.__response__.parse_obj(config) return NebiusBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> NebiusBackend: diff --git a/src/dstack/_internal/core/backends/oci/configurator.py b/src/dstack/_internal/core/backends/oci/configurator.py index 004d2b45cb..4558e8bf96 100644 --- a/src/dstack/_internal/core/backends/oci/configurator.py +++ b/src/dstack/_internal/core/backends/oci/configurator.py @@ -10,7 +10,6 @@ from dstack._internal.core.backends.oci.backend import OCIBackend from dstack._internal.core.backends.oci.exceptions import any_oci_exception from dstack._internal.core.backends.oci.models import ( - AnyOCIBackendConfig, OCIBackendConfig, OCIBackendConfigWithCreds, OCIConfig, @@ -42,7 +41,12 @@ ) -class OCIConfigurator(Configurator): +class OCIConfigurator( + Configurator[ + OCIBackendConfig, + OCIBackendConfigWithCreds, + ] +): TYPE = BackendType.OCI BACKEND_CLASS = OCIBackend @@ -83,12 +87,12 @@ def create_backend( auth=OCICreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyOCIBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> OCIBackendConfigWithCreds: + config = self._get_config(record) + return OCIBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> OCIBackendConfig: config = self._get_config(record) - if include_creds: - return OCIBackendConfigWithCreds.__response__.parse_obj(config) return OCIBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> OCIBackend: diff --git a/src/dstack/_internal/core/backends/runpod/configurator.py b/src/dstack/_internal/core/backends/runpod/configurator.py index e37bb4d34a..df023f7179 100644 --- a/src/dstack/_internal/core/backends/runpod/configurator.py +++ b/src/dstack/_internal/core/backends/runpod/configurator.py @@ -8,7 +8,6 @@ from dstack._internal.core.backends.runpod import api_client from dstack._internal.core.backends.runpod.backend import RunpodBackend from dstack._internal.core.backends.runpod.models import ( - AnyRunpodBackendConfig, RunpodBackendConfig, RunpodBackendConfigWithCreds, RunpodConfig, @@ -18,7 +17,12 @@ from dstack._internal.core.models.backends.base import BackendType -class RunpodConfigurator(Configurator): +class RunpodConfigurator( + Configurator[ + RunpodBackendConfig, + RunpodBackendConfigWithCreds, + ] +): TYPE = BackendType.RUNPOD BACKEND_CLASS = RunpodBackend @@ -35,12 +39,12 @@ def create_backend( auth=RunpodCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyRunpodBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> RunpodBackendConfigWithCreds: + config = self._get_config(record) + return RunpodBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> RunpodBackendConfig: config = self._get_config(record) - if include_creds: - return RunpodBackendConfigWithCreds.__response__.parse_obj(config) return RunpodBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> RunpodBackend: diff --git a/src/dstack/_internal/core/backends/template/configurator.py.jinja b/src/dstack/_internal/core/backends/template/configurator.py.jinja index 6330ed7008..47ea303903 100644 --- a/src/dstack/_internal/core/backends/template/configurator.py.jinja +++ b/src/dstack/_internal/core/backends/template/configurator.py.jinja @@ -7,7 +7,6 @@ from dstack._internal.core.backends.base.configurator import ( ) from dstack._internal.core.backends.{{ backend_name|lower }}.backend import {{ backend_name }}Backend from dstack._internal.core.backends.{{ backend_name|lower }}.models import ( - Any{{ backend_name }}BackendConfig, Any{{ backend_name }}Creds, {{ backend_name }}BackendConfig, {{ backend_name }}BackendConfigWithCreds, @@ -20,7 +19,12 @@ from dstack._internal.core.models.backends.base import ( ) -class {{ backend_name }}Configurator(Configurator): +class {{ backend_name }}Configurator( + Configurator[ + {{ backend_name }}BackendConfig, + {{ backend_name }}BackendConfigWithCreds, + ] +): TYPE = BackendType.{{ backend_name|upper }} BACKEND_CLASS = {{ backend_name }}Backend @@ -40,12 +44,12 @@ class {{ backend_name }}Configurator(Configurator): auth={{ backend_name }}Creds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> Any{{ backend_name }}BackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> {{ backend_name }}BackendConfigWithCreds: + config = self._get_config(record) + return {{ backend_name }}BackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> {{ backend_name }}BackendConfig: config = self._get_config(record) - if include_creds: - return {{ backend_name }}BackendConfigWithCreds.__response__.parse_obj(config) return {{ backend_name }}BackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> {{ backend_name }}Backend: diff --git a/src/dstack/_internal/core/backends/tensordock/configurator.py b/src/dstack/_internal/core/backends/tensordock/configurator.py index fcba7c8c75..0582b63431 100644 --- a/src/dstack/_internal/core/backends/tensordock/configurator.py +++ b/src/dstack/_internal/core/backends/tensordock/configurator.py @@ -8,7 +8,6 @@ from dstack._internal.core.backends.tensordock import api_client from dstack._internal.core.backends.tensordock.backend import TensorDockBackend from dstack._internal.core.backends.tensordock.models import ( - AnyTensorDockBackendConfig, TensorDockBackendConfig, TensorDockBackendConfigWithCreds, TensorDockConfig, @@ -23,7 +22,12 @@ REGIONS = [] -class TensorDockConfigurator(Configurator): +class TensorDockConfigurator( + Configurator[ + TensorDockBackendConfig, + TensorDockBackendConfigWithCreds, + ] +): TYPE = BackendType.TENSORDOCK BACKEND_CLASS = TensorDockBackend @@ -44,12 +48,14 @@ def create_backend( auth=TensorDockCreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyTensorDockBackendConfig: + def get_backend_config_with_creds( + self, record: BackendRecord + ) -> TensorDockBackendConfigWithCreds: + config = self._get_config(record) + return TensorDockBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> TensorDockBackendConfig: config = self._get_config(record) - if include_creds: - return TensorDockBackendConfigWithCreds.__response__.parse_obj(config) return TensorDockBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> TensorDockBackend: diff --git a/src/dstack/_internal/core/backends/vastai/configurator.py b/src/dstack/_internal/core/backends/vastai/configurator.py index 6cbb181d31..997842f1a8 100644 --- a/src/dstack/_internal/core/backends/vastai/configurator.py +++ b/src/dstack/_internal/core/backends/vastai/configurator.py @@ -8,7 +8,6 @@ from dstack._internal.core.backends.vastai import api_client from dstack._internal.core.backends.vastai.backend import VastAIBackend from dstack._internal.core.backends.vastai.models import ( - AnyVastAIBackendConfig, VastAIBackendConfig, VastAIBackendConfigWithCreds, VastAIConfig, @@ -23,7 +22,12 @@ REGIONS = [] -class VastAIConfigurator(Configurator): +class VastAIConfigurator( + Configurator[ + VastAIBackendConfig, + VastAIBackendConfigWithCreds, + ] +): TYPE = BackendType.VASTAI BACKEND_CLASS = VastAIBackend @@ -42,12 +46,12 @@ def create_backend( auth=VastAICreds.parse_obj(config.creds).json(), ) - def get_backend_config( - self, record: BackendRecord, include_creds: bool - ) -> AnyVastAIBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> VastAIBackendConfigWithCreds: + config = self._get_config(record) + return VastAIBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> VastAIBackendConfig: config = self._get_config(record) - if include_creds: - return VastAIBackendConfigWithCreds.__response__.parse_obj(config) return VastAIBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> VastAIBackend: diff --git a/src/dstack/_internal/core/backends/vultr/configurator.py b/src/dstack/_internal/core/backends/vultr/configurator.py index b135e23500..39f98a03e8 100644 --- a/src/dstack/_internal/core/backends/vultr/configurator.py +++ b/src/dstack/_internal/core/backends/vultr/configurator.py @@ -23,7 +23,12 @@ REGIONS = [] -class VultrConfigurator(Configurator): +class VultrConfigurator( + Configurator[ + VultrBackendConfig, + VultrBackendConfigWithCreds, + ] +): TYPE = BackendType.VULTR BACKEND_CLASS = VultrBackend @@ -42,10 +47,12 @@ def create_backend( auth=VultrCreds.parse_obj(config.creds).json(), ) - def get_backend_config(self, record: BackendRecord, include_creds: bool) -> VultrBackendConfig: + def get_backend_config_with_creds(self, record: BackendRecord) -> VultrBackendConfigWithCreds: + config = self._get_config(record) + return VultrBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> VultrBackendConfig: config = self._get_config(record) - if include_creds: - return VultrBackendConfigWithCreds.__response__.parse_obj(config) return VultrBackendConfig.__response__.parse_obj(config) def get_backend(self, record: BackendRecord) -> VultrBackend: diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 9711a503bf..7613d75550 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -135,7 +135,7 @@ def get_backend_config_with_creds_from_backend_model( backend_model: BackendModel, ) -> AnyBackendConfigWithCreds: backend_record = get_stored_backend_record(backend_model) - backend_config = configurator.get_backend_config(backend_record, include_creds=True) + backend_config = configurator.get_backend_config_with_creds(backend_record) return backend_config @@ -144,7 +144,7 @@ def get_backend_config_without_creds_from_backend_model( backend_model: BackendModel, ) -> AnyBackendConfigWithoutCreds: backend_record = get_stored_backend_record(backend_model) - backend_config = configurator.get_backend_config(backend_record, include_creds=False) + backend_config = configurator.get_backend_config_without_creds(backend_record) return backend_config