Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/dstack/_internal/core/backends/aws/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from dstack._internal.core.backends.aws import auth, compute, resources
from dstack._internal.core.backends.aws.backend import AWSBackend
from dstack._internal.core.backends.aws.models import (
AnyAWSBackendConfig,
AWSAccessKeyCreds,
AWSBackendConfig,
AWSBackendConfigWithCreds,
Expand Down Expand Up @@ -52,7 +51,12 @@
MAIN_REGION = "us-east-1"


class AWSConfigurator(Configurator):
class AWSConfigurator(
Configurator[
AWSBackendConfig,
AWSBackendConfigWithCreds,
]
):
TYPE = BackendType.AWS
BACKEND_CLASS = AWSBackend

Expand Down Expand Up @@ -87,12 +91,12 @@ def create_backend(
auth=AWSCreds.parse_obj(config.creds).json(),
)

def get_backend_config(
self, record: BackendRecord, include_creds: bool
) -> AnyAWSBackendConfig:
def get_backend_config_with_creds(self, record: BackendRecord) -> AWSBackendConfigWithCreds:
config = self._get_config(record)
return AWSBackendConfigWithCreds.__response__.parse_obj(config)

def get_backend_config_without_creds(self, record: BackendRecord) -> AWSBackendConfig:
config = self._get_config(record)
if include_creds:
return AWSBackendConfigWithCreds.__response__.parse_obj(config)
return AWSBackendConfig.__response__.parse_obj(config)

def get_backend(self, record: BackendRecord) -> AWSBackend:
Expand Down
18 changes: 11 additions & 7 deletions src/dstack/_internal/core/backends/azure/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from dstack._internal.core.backends.azure import utils as azure_utils
from dstack._internal.core.backends.azure.backend import AzureBackend
from dstack._internal.core.backends.azure.models import (
AnyAzureBackendConfig,
AzureBackendConfig,
AzureBackendConfigWithCreds,
AzureClientCreds,
Expand Down Expand Up @@ -71,7 +70,12 @@
MAIN_LOCATION = "eastus"


class AzureConfigurator(Configurator):
class AzureConfigurator(
Configurator[
AzureBackendConfig,
AzureBackendConfigWithCreds,
]
):
TYPE = BackendType.AZURE
BACKEND_CLASS = AzureBackend

Expand Down Expand Up @@ -130,12 +134,12 @@ def create_backend(
auth=AzureCreds.parse_obj(config.creds).__root__.json(),
)

def get_backend_config(
self, record: BackendRecord, include_creds: bool
) -> AnyAzureBackendConfig:
def get_backend_config_with_creds(self, record: BackendRecord) -> AzureBackendConfigWithCreds:
config = self._get_config(record)
return AzureBackendConfigWithCreds.__response__.parse_obj(config)

def get_backend_config_without_creds(self, record: BackendRecord) -> AzureBackendConfig:
config = self._get_config(record)
if include_creds:
return AzureBackendConfigWithCreds.__response__.parse_obj(config)
return AzureBackendConfig.__response__.parse_obj(config)

def get_backend(self, record: BackendRecord) -> AzureBackend:
Expand Down
45 changes: 22 additions & 23 deletions src/dstack/_internal/core/backends/base/configurator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any, ClassVar, List, Literal, Optional, overload
from typing import Any, ClassVar, Generic, List, Optional, TypeVar
from uuid import UUID

from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.models import (
AnyBackendConfig,
AnyBackendConfigWithCreds,
AnyBackendConfigWithoutCreds,
)
Expand All @@ -16,6 +15,11 @@
# We'll introduce our own base limit that can be customized per backend if required.
TAGS_MAX_NUM = 25

BackendConfigWithoutCredsT = TypeVar(
"BackendConfigWithoutCredsT", bound=AnyBackendConfigWithoutCreds
)
BackendConfigWithCredsT = TypeVar("BackendConfigWithCredsT", bound=AnyBackendConfigWithCreds)


class BackendRecord(CoreModel):
"""
Expand All @@ -40,7 +44,7 @@ class StoredBackendRecord(BackendRecord):
backend_id: UUID


class Configurator(ABC):
class Configurator(ABC, Generic[BackendConfigWithoutCredsT, BackendConfigWithCredsT]):
"""
`Configurator` is responsible for configuring backends
and initializing `Backend` instances from backend configs.
Expand All @@ -53,7 +57,7 @@ class Configurator(ABC):
BACKEND_CLASS: ClassVar[type[Backend]]

@abstractmethod
def validate_config(self, config: AnyBackendConfigWithCreds, default_creds_enabled: bool):
def validate_config(self, config: BackendConfigWithCredsT, default_creds_enabled: bool):
"""
Validates backend config including backend creds and other parameters.
Raises `ServerClientError` or its subclass if config is invalid.
Expand All @@ -62,9 +66,7 @@ def validate_config(self, config: AnyBackendConfigWithCreds, default_creds_enabl
pass

@abstractmethod
def create_backend(
self, project_name: str, config: AnyBackendConfigWithCreds
) -> BackendRecord:
def create_backend(self, project_name: str, config: BackendConfigWithCredsT) -> BackendRecord:
"""
Sets up backend given backend config and returns
text-encoded config and creds to be stored in the DB.
Expand All @@ -78,26 +80,23 @@ def create_backend(
"""
pass

@overload
def get_backend_config(
self, record: StoredBackendRecord, include_creds: Literal[False]
) -> AnyBackendConfigWithoutCreds:
pass

@overload
def get_backend_config(
self, record: StoredBackendRecord, include_creds: Literal[True]
) -> AnyBackendConfigWithCreds:
@abstractmethod
def get_backend_config_with_creds(
self, record: StoredBackendRecord
) -> BackendConfigWithCredsT:
"""
Constructs `BackendConfig` with credentials included.
Used internally and when project admins need to see backend's creds.
"""
pass

@abstractmethod
def get_backend_config(
self, record: StoredBackendRecord, include_creds: bool
) -> AnyBackendConfig:
def get_backend_config_without_creds(
self, record: StoredBackendRecord
) -> BackendConfigWithoutCredsT:
"""
Constructs `BackendConfig` to be returned in API responses.
Project admins may need to see backend's creds. In this case `include_creds` will be `True`.
Otherwise, no sensitive information should be included.
Constructs `BackendConfig` without sensitive information.
Used for API responses where creds should not be exposed.
"""
pass

Expand Down
20 changes: 13 additions & 7 deletions src/dstack/_internal/core/backends/cloudrift/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from dstack._internal.core.backends.cloudrift.api_client import RiftClient
from dstack._internal.core.backends.cloudrift.backend import CloudRiftBackend
from dstack._internal.core.backends.cloudrift.models import (
AnyCloudRiftBackendConfig,
AnyCloudRiftCreds,
CloudRiftBackendConfig,
CloudRiftBackendConfigWithCreds,
Expand All @@ -21,7 +20,12 @@
)


class CloudRiftConfigurator(Configurator):
class CloudRiftConfigurator(
Configurator[
CloudRiftBackendConfig,
CloudRiftBackendConfigWithCreds,
]
):
TYPE = BackendType.CLOUDRIFT
BACKEND_CLASS = CloudRiftBackend

Expand All @@ -40,12 +44,14 @@ def create_backend(
auth=CloudRiftCreds.parse_obj(config.creds).json(),
)

def get_backend_config(
self, record: BackendRecord, include_creds: bool
) -> AnyCloudRiftBackendConfig:
def get_backend_config_with_creds(
self, record: BackendRecord
) -> CloudRiftBackendConfigWithCreds:
config = self._get_config(record)
return CloudRiftBackendConfigWithCreds.__response__.parse_obj(config)

def get_backend_config_without_creds(self, record: BackendRecord) -> CloudRiftBackendConfig:
config = self._get_config(record)
if include_creds:
return CloudRiftBackendConfigWithCreds.__response__.parse_obj(config)
return CloudRiftBackendConfig.__response__.parse_obj(config)

def get_backend(self, record: BackendRecord) -> CloudRiftBackend:
Expand Down
18 changes: 11 additions & 7 deletions src/dstack/_internal/core/backends/cudo/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from dstack._internal.core.backends.cudo import api_client
from dstack._internal.core.backends.cudo.backend import CudoBackend
from dstack._internal.core.backends.cudo.models import (
AnyCudoBackendConfig,
CudoBackendConfig,
CudoBackendConfigWithCreds,
CudoConfig,
Expand All @@ -18,7 +17,12 @@
from dstack._internal.core.models.backends.base import BackendType


class CudoConfigurator(Configurator):
class CudoConfigurator(
Configurator[
CudoBackendConfig,
CudoBackendConfigWithCreds,
]
):
TYPE = BackendType.CUDO
BACKEND_CLASS = CudoBackend

Expand All @@ -35,12 +39,12 @@ def create_backend(
auth=CudoCreds.parse_obj(config.creds).json(),
)

def get_backend_config(
self, record: BackendRecord, include_creds: bool
) -> AnyCudoBackendConfig:
def get_backend_config_with_creds(self, record: BackendRecord) -> CudoBackendConfigWithCreds:
config = self._get_config(record)
return CudoBackendConfigWithCreds.__response__.parse_obj(config)

def get_backend_config_without_creds(self, record: BackendRecord) -> CudoBackendConfig:
config = self._get_config(record)
if include_creds:
return CudoBackendConfigWithCreds.__response__.parse_obj(config)
return CudoBackendConfig.__response__.parse_obj(config)

def get_backend(self, record: BackendRecord) -> CudoBackend:
Expand Down
20 changes: 13 additions & 7 deletions src/dstack/_internal/core/backends/datacrunch/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from dstack._internal.core.backends.datacrunch.backend import DataCrunchBackend
from dstack._internal.core.backends.datacrunch.models import (
AnyDataCrunchBackendConfig,
DataCrunchBackendConfig,
DataCrunchBackendConfigWithCreds,
DataCrunchConfig,
Expand All @@ -22,7 +21,12 @@
)


class DataCrunchConfigurator(Configurator):
class DataCrunchConfigurator(
Configurator[
DataCrunchBackendConfig,
DataCrunchBackendConfigWithCreds,
]
):
TYPE = BackendType.DATACRUNCH
BACKEND_CLASS = DataCrunchBackend

Expand All @@ -41,12 +45,14 @@ def create_backend(
auth=DataCrunchCreds.parse_obj(config.creds).json(),
)

def get_backend_config(
self, record: BackendRecord, include_creds: bool
) -> AnyDataCrunchBackendConfig:
def get_backend_config_with_creds(
self, record: BackendRecord
) -> DataCrunchBackendConfigWithCreds:
config = self._get_config(record)
return DataCrunchBackendConfigWithCreds.__response__.parse_obj(config)

def get_backend_config_without_creds(self, record: BackendRecord) -> DataCrunchBackendConfig:
config = self._get_config(record)
if include_creds:
return DataCrunchBackendConfigWithCreds.__response__.parse_obj(config)
return DataCrunchBackendConfig.__response__.parse_obj(config)

def get_backend(self, record: BackendRecord) -> DataCrunchBackend:
Expand Down
18 changes: 11 additions & 7 deletions src/dstack/_internal/core/backends/gcp/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from dstack._internal.core.backends.gcp import auth, resources
from dstack._internal.core.backends.gcp.backend import GCPBackend
from dstack._internal.core.backends.gcp.models import (
AnyGCPBackendConfig,
GCPBackendConfig,
GCPBackendConfigWithCreds,
GCPConfig,
Expand Down Expand Up @@ -109,7 +108,12 @@
MAIN_REGION = "us-east1"


class GCPConfigurator(Configurator):
class GCPConfigurator(
Configurator[
GCPBackendConfig,
GCPBackendConfigWithCreds,
]
):
TYPE = BackendType.GCP
BACKEND_CLASS = GCPBackend

Expand Down Expand Up @@ -147,12 +151,12 @@ def create_backend(
auth=GCPCreds.parse_obj(config.creds).json(),
)

def get_backend_config(
self, record: BackendRecord, include_creds: bool
) -> AnyGCPBackendConfig:
def get_backend_config_with_creds(self, record: BackendRecord) -> GCPBackendConfigWithCreds:
config = self._get_config(record)
return GCPBackendConfigWithCreds.__response__.parse_obj(config)

def get_backend_config_without_creds(self, record: BackendRecord) -> GCPBackendConfig:
config = self._get_config(record)
if include_creds:
return GCPBackendConfigWithCreds.__response__.parse_obj(config)
return GCPBackendConfig.__response__.parse_obj(config)

def get_backend(self, record: BackendRecord) -> GCPBackend:
Expand Down
20 changes: 13 additions & 7 deletions src/dstack/_internal/core/backends/hotaisle/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from dstack._internal.core.backends.hotaisle.api_client import HotAisleAPIClient
from dstack._internal.core.backends.hotaisle.backend import HotAisleBackend
from dstack._internal.core.backends.hotaisle.models import (
AnyHotAisleBackendConfig,
AnyHotAisleCreds,
HotAisleBackendConfig,
HotAisleBackendConfigWithCreds,
Expand All @@ -20,7 +19,12 @@
)


class HotAisleConfigurator(Configurator):
class HotAisleConfigurator(
Configurator[
HotAisleBackendConfig,
HotAisleBackendConfigWithCreds,
]
):
TYPE = BackendType.HOTAISLE
BACKEND_CLASS = HotAisleBackend

Expand All @@ -37,12 +41,14 @@ def create_backend(
auth=HotAisleCreds.parse_obj(config.creds).json(),
)

def get_backend_config(
self, record: BackendRecord, include_creds: bool
) -> AnyHotAisleBackendConfig:
def get_backend_config_with_creds(
self, record: BackendRecord
) -> HotAisleBackendConfigWithCreds:
config = self._get_config(record)
return HotAisleBackendConfigWithCreds.__response__.parse_obj(config)

def get_backend_config_without_creds(self, record: BackendRecord) -> HotAisleBackendConfig:
config = self._get_config(record)
if include_creds:
return HotAisleBackendConfigWithCreds.__response__.parse_obj(config)
return HotAisleBackendConfig.__response__.parse_obj(config)

def get_backend(self, record: BackendRecord) -> HotAisleBackend:
Expand Down
Loading
Loading