diff --git a/docs/docs/concepts/backends.md b/docs/docs/concepts/backends.md index e0f52e6e67..21208d74b9 100644 --- a/docs/docs/concepts/backends.md +++ b/docs/docs/concepts/backends.md @@ -913,6 +913,28 @@ projects: +### CloudRift + +Log into your [CloudRift :material-arrow-top-right-thin:{ .external }](https://console.cloudrift.ai/) console, click `API Keys` in the sidebar and click the button to create a new API key. + +Ensure you've created a project with CloudRift. + +Then proceed to configuring the backend. + +
+ +```yaml +projects: + - name: main + backends: + - type: cloudrift + creds: + type: api_key + api_key: rift_2prgY1d0laOrf2BblTwx2B2d1zcf1zIp4tZYpj5j88qmNgz38pxNlpX3vAo +``` + +
+
 ## On-prem servers
 
 ### SSH fleets
diff --git a/docs/docs/reference/server/config.yml.md b/docs/docs/reference/server/config.yml.md
index a021c138a0..fbe378d8cd 100644
--- a/docs/docs/reference/server/config.yml.md
+++ b/docs/docs/reference/server/config.yml.md
@@ -315,6 +315,23 @@ to configure [backends](../../concepts/backends.md) and other [sever-level setti
       type:
         required: true
 
+##### `projects[n].backends[type=cloudrift]` { #cloudrift data-toc-label="cloudrift" }
+
+#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftBackendConfigWithCreds
+    overrides:
+      show_root_heading: false
+      type:
+        required: true
+      item_id_prefix: cloudrift-
+
+###### `projects[n].backends[type=cloudrift].creds` { #cloudrift-creds data-toc-label="creds" }
+
+#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftAPIKeyCreds
+    overrides:
+      show_root_heading: false
+      type:
+        required: true
+
 ### `encryption` { #encryption data-toc-label="encryption" }
 
 #SCHEMA# dstack._internal.server.services.config.EncryptionConfig
diff --git a/src/dstack/_internal/core/backends/cloudrift/__init__.py b/src/dstack/_internal/core/backends/cloudrift/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/dstack/_internal/core/backends/cloudrift/api_client.py b/src/dstack/_internal/core/backends/cloudrift/api_client.py
new file mode 100644
index 0000000000..51bbfafd69
--- /dev/null
+++ b/src/dstack/_internal/core/backends/cloudrift/api_client.py
@@ -0,0 +1,240 @@
+import re
+from typing import Any, Dict, List, Mapping, Optional, Union
+
+import requests
+from packaging import version
+from requests import Response
+
+from dstack._internal.core.errors import BackendError, BackendInvalidCredentialsError
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai"
+CLOUDRIFT_API_VERSION = "2025-05-29"
+
+
+class RiftClient:
+    """Thin HTTP client for the CloudRift public API (`api/v1`)."""
+
+    def __init__(self, api_key: Optional[str] = None):
+        # Build the URL with string formatting: `os.path.join` uses the OS path
+        # separator and would produce a backslash in the URL on Windows.
+        self.public_api_root = f"{CLOUDRIFT_SERVER_ADDRESS}/api/v1"
+        self.api_key = api_key
+
+    def validate_api_key(self) -> bool:
+        """
+        Validates the API key by making a request to the server.
+        Returns True if the API key is valid, False otherwise.
+        """
+        try:
+            response = self._make_request("auth/me")
+            if isinstance(response, dict):
+                return "email" in response
+            return False
+        except BackendInvalidCredentialsError:
+            return False
+        except Exception as e:
+            # Best effort: any other failure is logged and reported as an
+            # invalid key instead of being raised to the caller.
+            logger.error("Error validating API key: %s", e)
+            return False
+
+    def get_instance_types(self) -> List[Dict]:
+        """Returns the `vm` service instance types from `instance-types/list`."""
+        request_data = {"selector": {"ByServiceAndLocation": {"services": ["vm"]}}}
+        response_data = self._make_request("instance-types/list", request_data)
+        if isinstance(response_data, dict):
+            return response_data.get("instance_types", [])
+        return []
+
+    def list_recipes(self) -> List[Dict]:
+        """Returns the recipe `groups` list of `recipes/list`."""
+        request_data = {}
+        response_data = self._make_request("recipes/list", request_data)
+        if isinstance(response_data, dict):
+            return response_data.get("groups", [])
+        return []
+
+    def get_vm_recipies(self) -> List[Dict]:
+        """
+        Retrieves a list of VM recipes from the CloudRift API.
+        Only recipes with `VirtualMachine` details from groups named `linux`
+        and tagged `vm` (both case-insensitive) are returned.
+        """
+        recipe_group = self.list_recipes()
+        vm_recipes = []
+        for group in recipe_group:
+            tags = group.get("tags", [])
+            has_vm = "vm" in map(str.lower, tags)
+            if group.get("name", "").lower() != "linux" or not has_vm:
+                continue
+
+            recipes = group.get("recipes", [])
+            for recipe in recipes:
+                details = recipe.get("details", {})
+                if details.get("VirtualMachine", False):
+                    vm_recipes.append(recipe)
+
+        return vm_recipes
+
+    def get_vm_image_url(self) -> Optional[str]:
+        """
+        Returns the image URL of the `nvidia-driver` tagged Ubuntu recipe with
+        the highest version in its name, or None if there is no such recipe.
+        """
+        recipes = self.get_vm_recipies()
+        ubuntu_images = []
+        for recipe in recipes:
+            has_nvidia_driver = "nvidia-driver" in recipe.get("tags", [])
+            if not has_nvidia_driver:
+                continue
+
+            recipe_name = recipe.get("name", "")
+            if "Ubuntu" not in recipe_name:
+                continue
+
+            url = recipe["details"].get("VirtualMachine", {}).get("image_url", None)
+            version_match = re.search(r".* (\d+\.\d+)", recipe_name)
+            if url and version_match and version_match.group(1):
+                ubuntu_version = version.parse(version_match.group(1))
+                ubuntu_images.append((ubuntu_version, url))
+
+        ubuntu_images.sort(key=lambda x: x[0])  # Sort by version
+        if ubuntu_images:
+            return ubuntu_images[-1][1]
+
+        return None
+
+    def deploy_instance(
+        self, instance_type: str, region: str, ssh_keys: List[str], cmd: str
+    ) -> List[str]:
+        """
+        Rents a VM of the given instance type in the given datacenter.
+        Returns the `instance_ids` reported by the server (empty if none).
+        Raises BackendError if no suitable VM image is found.
+        """
+        image_url = self.get_vm_image_url()
+        if not image_url:
+            raise BackendError("No suitable VM image found.")
+
+        request_data = {
+            "config": {
+                "VirtualMachine": {
+                    "cloudinit_commands": cmd,
+                    "image_url": image_url,
+                    "ssh_key": {"PublicKeys": ssh_keys},
+                }
+            },
+            "selector": {
+                "ByInstanceTypeAndLocation": {
+                    "datacenters": [region],
+                    "instance_type": instance_type,
+                }
+            },
+            "with_public_ip": True,
+        }
+        logger.debug("Deploying instance with request data: %s", request_data)
+
+        response_data = self._make_request("instances/rent", request_data)
+        if isinstance(response_data, dict):
+            return response_data.get("instance_ids", [])
+        return []
+
+    def list_instances(self, instance_ids: Optional[List[str]] = None) -> List[Dict]:
+        """
+        Lists instances. If `instance_ids` is given, only those instances are
+        requested; otherwise all initializing, active, and deactivating ones.
+        """
+        if instance_ids is not None:
+            # Same `ById` selector that `get_instance_by_id` sends to this endpoint.
+            selector: Dict[str, Any] = {"ById": instance_ids}
+        else:
+            selector = {"ByStatus": ["Initializing", "Active", "Deactivating"]}
+        request_data = {"selector": selector}
+        logger.debug("Listing instances with request data: %s", request_data)
+        response_data = self._make_request("instances/list", request_data)
+        if isinstance(response_data, dict):
+            return response_data.get("instances", [])
+
+        return []
+
+    def get_instance_by_id(self, instance_id: str) -> Optional[Dict]:
+        """Returns the first instance listed for the given ID, or None if there is none."""
+        request_data = {"selector": {"ById": [instance_id]}}
+        logger.debug("Getting instance with request data: %s", request_data)
+        response_data = self._make_request("instances/list", request_data)
+        if isinstance(response_data, dict):
+            instances = response_data.get("instances", [])
+            if isinstance(instances, list) and instances:
+                return instances[0]
+
+        return None
+
+    def terminate_instance(self, instance_id: str) -> bool:
+        """Requests termination; returns True if the server reports a terminated instance."""
+        request_data = {"selector": {"ById": [instance_id]}}
+        logger.debug("Terminating instance with request data: %s", request_data)
+        response_data = self._make_request("instances/terminate", request_data)
+        if isinstance(response_data, dict):
+            info = response_data.get("terminated", [])
+            return len(info) > 0
+
+        return False
+
+    def _make_request(
+        self,
+        endpoint: str,
+        data: Optional[Mapping[str, Any]] = None,
+        method: str = "POST",
+        **kwargs,
+    ) -> Union[Mapping[str, Any], str, Response]:
+        """
+        Sends `data` wrapped as `{"version": ..., "data": ...}` to the endpoint
+        and returns the `data` field of the JSON response. A JSON string response
+        is returned as is; a non-JSON response is returned as the raw `Response`.
+        Raises BackendInvalidCredentialsError on 401/403 responses; other HTTP
+        errors are re-raised.
+        """
+        headers = {}
+        if self.api_key is not None:
+            headers["X-API-Key"] = self.api_key
+
+        # Not named `version` to avoid shadowing the imported `packaging.version` module.
+        api_version = CLOUDRIFT_API_VERSION
+        full_url = f"{self.public_api_root}/{endpoint}"
+
+        try:
+            response = requests.request(
+                method,
+                full_url,
+                headers=headers,
+                json={"version": api_version, "data": data},
+                timeout=15,
+                **kwargs,
+            )
+
+            response.raise_for_status()
+            try:
+                response_json = response.json()
+                if isinstance(response_json, str):
+                    return response_json
+                # String comparison; assumes the server also uses YYYY-MM-DD versions.
+                if api_version < response_json["version"]:
+                    logger.warning(
+                        "The API version %s is lower than the server version %s. ",
+                        api_version,
+                        response_json["version"],
+                    )
+                return response_json["data"]
+            except requests.exceptions.JSONDecodeError:
+                return response
+        except requests.HTTPError as e:
+            if e.response is not None and e.response.status_code in (
+                requests.codes.forbidden,
+                requests.codes.unauthorized,
+            ):
+                raise BackendInvalidCredentialsError(e.response.text)
+            raise
diff --git a/src/dstack/_internal/core/backends/cloudrift/backend.py b/src/dstack/_internal/core/backends/cloudrift/backend.py
new file mode 100644
index 0000000000..cca0c620c6
--- /dev/null
+++ b/src/dstack/_internal/core/backends/cloudrift/backend.py
@@ -0,0 +1,18 @@
+from dstack._internal.core.backends.base.backend import Backend
+from dstack._internal.core.backends.cloudrift.compute import CloudRiftCompute
+from dstack._internal.core.backends.cloudrift.models import CloudRiftConfig
+from dstack._internal.core.models.backends.base import BackendType
+
+
+class CloudRiftBackend(Backend):
+    """Backend whose compute is a `CloudRiftCompute` created from the given config."""
+
+    TYPE = BackendType.CLOUDRIFT
+    COMPUTE_CLASS = CloudRiftCompute
+
+    def __init__(self, config: CloudRiftConfig):
+        self.config = config
+        self._compute = CloudRiftCompute(self.config)
+
+    def compute(self) -> CloudRiftCompute:
+        return self._compute
diff --git a/src/dstack/_internal/core/backends/cloudrift/compute.py b/src/dstack/_internal/core/backends/cloudrift/compute.py
new file mode 100644
index 0000000000..03d9fd74c6
--- /dev/null
+++ b/src/dstack/_internal/core/backends/cloudrift/compute.py
@@ -0,0 +1,138 @@
+from typing import Dict, List, Optional
+
+from dstack._internal.core.backends.base.backend import Compute
+from dstack._internal.core.backends.base.compute import (
+    ComputeWithCreateInstanceSupport,
+    get_shim_commands,
+)
+from dstack._internal.core.backends.base.offers import get_catalog_offers
+from dstack._internal.core.backends.cloudrift.api_client import RiftClient
+from dstack._internal.core.backends.cloudrift.models import CloudRiftConfig
+from dstack._internal.core.errors import
ComputeError
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.instances import (
+    InstanceAvailability,
+    InstanceConfiguration,
+    InstanceOffer,
+    InstanceOfferWithAvailability,
+)
+from dstack._internal.core.models.placement import PlacementGroup
+from dstack._internal.core.models.runs import JobProvisioningData, Requirements
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class CloudRiftCompute(
+    ComputeWithCreateInstanceSupport,
+    Compute,
+):
+    """Compute implementation that provisions CloudRift VMs through `RiftClient`."""
+
+    def __init__(self, config: CloudRiftConfig):
+        super().__init__()
+        self.config = config
+        self.client = RiftClient(self.config.creds.api_key)
+
+    def get_offers(
+        self, requirements: Optional[Requirements] = None
+    ) -> List[InstanceOfferWithAvailability]:
+        """Returns catalog offers for the configured regions with their availability."""
+        offers = get_catalog_offers(
+            backend=BackendType.CLOUDRIFT,
+            locations=self.config.regions or None,
+            requirements=requirements,
+        )
+        return self._get_offers_with_availability(offers)
+
+    def _get_offers_with_availability(
+        self, offers: List[InstanceOffer]
+    ) -> List[InstanceOfferWithAvailability]:
+        """Sets each offer's availability from the free node counts reported by the API."""
+        instance_types_with_availabilities: List[Dict] = self.client.get_instance_types()
+
+        region_availabilities = {}
+        for instance_type in instance_types_with_availabilities:
+            for variant in instance_type["variants"]:
+                for dc, count in variant["available_nodes_per_dc"].items():
+                    if count > 0:
+                        key = (variant["name"], dc)
+                        region_availabilities[key] = InstanceAvailability.AVAILABLE
+
+        availability_offers = []
+        for offer in offers:
+            key = (offer.instance.name, offer.region)
+            availability = region_availabilities.get(key, InstanceAvailability.NOT_AVAILABLE)
+            availability_offers.append(
+                InstanceOfferWithAvailability(**offer.dict(), availability=availability)
+            )
+
+        return availability_offers
+
+    def create_instance(
+        self,
+        instance_offer: InstanceOfferWithAvailability,
+        instance_config: InstanceConfiguration,
+        placement_group: Optional[PlacementGroup],
+    ) -> JobProvisioningData:
+        """Rents a VM for the offer and returns provisioning data with no hostname set yet."""
+        commands = get_shim_commands(authorized_keys=instance_config.get_public_keys())
+        startup_script = " && ".join(commands)
+        logger.debug(
+            f"Creating instance for offer {instance_offer.instance.name} in region {instance_offer.region} with commands: {startup_script}"
+        )
+
+        instance_ids = self.client.deploy_instance(
+            instance_type=instance_offer.instance.name,
+            region=instance_offer.region,
+            ssh_keys=instance_config.get_public_keys(),
+            cmd=startup_script,
+        )
+        if not instance_ids:
+            raise ComputeError(
+                f"Failed to create instance for offer {instance_offer.instance.name} in region {instance_offer.region}."
+            )
+
+        return JobProvisioningData(
+            backend=instance_offer.backend,
+            instance_type=instance_offer.instance,
+            instance_id=instance_ids[0],
+            hostname=None,
+            internal_ip=None,
+            region=instance_offer.region,
+            price=instance_offer.price,
+            username="riftuser",
+            ssh_port=22,
+            dockerized=True,
+            ssh_proxy=None,
+            backend_data=None,
+        )
+
+    def update_provisioning_data(
+        self,
+        provisioning_data: JobProvisioningData,
+        project_ssh_public_key: str,
+        project_ssh_private_key: str,
+    ):
+        """Sets `hostname` once the instance's first VM reports ready; otherwise a no-op."""
+        instance_info = self.client.get_instance_by_id(provisioning_data.instance_id)
+        if not instance_info:
+            return
+
+        if instance_info.get("node_mode", "") != "VirtualMachine":
+            return
+
+        vms = instance_info.get("virtual_machines", [])
+        if not vms:
+            return
+
+        if vms[0].get("ready", False):
+            provisioning_data.hostname = instance_info.get("host_address", None)
+
+    def terminate_instance(
+        self, instance_id: str, region: str, backend_data: Optional[str] = None
+    ):
+        """Terminates the instance; raises ComputeError if nothing was terminated."""
+        terminated = self.client.terminate_instance(instance_id=instance_id)
+        if not terminated:
+            raise ComputeError(f"Failed to terminate instance {instance_id} in region {region}.")
diff --git
a/src/dstack/_internal/core/backends/cloudrift/configurator.py b/src/dstack/_internal/core/backends/cloudrift/configurator.py
new file mode 100644
index 0000000000..62410b0ecd
--- /dev/null
+++ b/src/dstack/_internal/core/backends/cloudrift/configurator.py
@@ -0,0 +1,69 @@
+import json
+
+from dstack._internal.core.backends.base.configurator import (
+    BackendRecord,
+    Configurator,
+    raise_invalid_credentials_error,
+)
+from dstack._internal.core.backends.cloudrift.api_client import RiftClient
+from dstack._internal.core.backends.cloudrift.backend import CloudRiftBackend
+from dstack._internal.core.backends.cloudrift.models import (
+    AnyCloudRiftBackendConfig,
+    AnyCloudRiftCreds,
+    CloudRiftBackendConfig,
+    CloudRiftBackendConfigWithCreds,
+    CloudRiftConfig,
+    CloudRiftCreds,
+    CloudRiftStoredConfig,
+)
+from dstack._internal.core.models.backends.base import (
+    BackendType,
+)
+
+
+class CloudRiftConfigurator(Configurator):
+    """Validates, stores, and loads CloudRift backend configuration."""
+
+    TYPE = BackendType.CLOUDRIFT
+    BACKEND_CLASS = CloudRiftBackend
+
+    def validate_config(
+        self, config: CloudRiftBackendConfigWithCreds, default_creds_enabled: bool
+    ):
+        """Checks the credentials in `config`; `default_creds_enabled` is not used."""
+        self._validate_creds(config.creds)
+
+    def create_backend(
+        self, project_name: str, config: CloudRiftBackendConfigWithCreds
+    ) -> BackendRecord:
+        """Stores the backend config in `config` and the creds in `auth`, both as JSON."""
+        stored_config = CloudRiftStoredConfig(
+            **CloudRiftBackendConfig.__response__.parse_obj(config).dict()
+        )
+        creds = CloudRiftCreds.parse_obj(config.creds)
+        return BackendRecord(config=stored_config.json(), auth=creds.json())
+
+    def get_backend_config(
+        self, record: BackendRecord, include_creds: bool
+    ) -> AnyCloudRiftBackendConfig:
+        """Returns the stored config, including the creds only if `include_creds` is set."""
+        config = self._get_config(record)
+        model = CloudRiftBackendConfigWithCreds if include_creds else CloudRiftBackendConfig
+        return model.__response__.parse_obj(config)
+
+    def get_backend(self, record: BackendRecord) -> CloudRiftBackend:
+        """Creates a `CloudRiftBackend` from the stored record."""
+        return CloudRiftBackend(config=self._get_config(record))
+
+    def _get_config(self, record: BackendRecord) -> CloudRiftConfig:
+        """Combines the record's JSON config and creds into a `CloudRiftConfig`."""
+        stored = json.loads(record.config)
+        creds = CloudRiftCreds.parse_raw(record.auth)
+        return CloudRiftConfig.__response__(**stored, creds=creds)
+
+    def _validate_creds(self, creds: AnyCloudRiftCreds):
+        """Raises an invalid-credentials error unless the API key passes validation."""
+        if not isinstance(creds, CloudRiftCreds):
+            raise_invalid_credentials_error(fields=[["creds"]])
+        if not RiftClient(creds.api_key).validate_api_key():
+            raise_invalid_credentials_error(fields=[["creds", "api_key"]])
diff --git a/src/dstack/_internal/core/backends/cloudrift/models.py b/src/dstack/_internal/core/backends/cloudrift/models.py
new file mode 100644
index 0000000000..62a6726f9a
--- /dev/null
+++ b/src/dstack/_internal/core/backends/cloudrift/models.py
@@ -0,0 +1,46 @@
+from typing import Annotated, List, Literal, Optional, Union
+
+from pydantic import Field
+
+from dstack._internal.core.models.common import CoreModel
+
+
+# API key credentials; the only credentials type this module defines.
+class CloudRiftAPIKeyCreds(CoreModel):
+    type: Annotated[Literal["api_key"], Field(description="The type of credentials")] = "api_key"
+    api_key: Annotated[str, Field(description="The API key")]
+
+
+# Both names currently refer to the API key credentials.
+AnyCloudRiftCreds = CloudRiftAPIKeyCreds
+CloudRiftCreds = AnyCloudRiftCreds
+
+
+# Backend settings without credentials.
+class CloudRiftBackendConfig(CoreModel):
+    type: Annotated[
+        Literal["cloudrift"],
+        Field(description="The type of backend"),
+    ] = "cloudrift"
+    regions: Annotated[
+        Optional[List[str]],
+        Field(description="The list of CloudRift regions. Omit to use all regions"),
+    ] = None
+
+
+# Backend settings together with the credentials.
+class CloudRiftBackendConfigWithCreds(CloudRiftBackendConfig):
+    creds: Annotated[AnyCloudRiftCreds, Field(description="The credentials")]
+
+
+AnyCloudRiftBackendConfig = Union[CloudRiftBackendConfig, CloudRiftBackendConfigWithCreds]
+
+
+# Config as stored in the backend record; it adds no fields to `CloudRiftBackendConfig`.
+class CloudRiftStoredConfig(CloudRiftBackendConfig):
+    pass
+
+
+# Stored config combined with the credentials.
+class CloudRiftConfig(CloudRiftStoredConfig):
+    creds: AnyCloudRiftCreds
diff --git a/src/dstack/_internal/core/backends/configurators.py b/src/dstack/_internal/core/backends/configurators.py
index 43a215d3ab..571d010529 100644
--- a/src/dstack/_internal/core/backends/configurators.py
+++ b/src/dstack/_internal/core/backends/configurators.py
@@ -20,6 +20,15 @@
 except ImportError:
     pass
 
+try:
+    from dstack._internal.core.backends.cloudrift.configurator import (
+        CloudRiftConfigurator,
+    )
+
+    _CONFIGURATOR_CLASSES.append(CloudRiftConfigurator)
+except ImportError:
+    pass
+
 try:
     from dstack._internal.core.backends.cudo.configurator import (
         CudoConfigurator,
diff --git a/src/dstack/_internal/core/backends/models.py b/src/dstack/_internal/core/backends/models.py
index 18567592a5..0b5779db78 100644
--- a/src/dstack/_internal/core/backends/models.py
+++ b/src/dstack/_internal/core/backends/models.py
@@ -8,6 +8,10 @@
     AzureBackendConfig,
     AzureBackendConfigWithCreds,
 )
+from dstack._internal.core.backends.cloudrift.models import (
+    CloudRiftBackendConfig,
+    CloudRiftBackendConfigWithCreds,
+)
 from dstack._internal.core.backends.cudo.models import (
     CudoBackendConfig,
     CudoBackendConfigWithCreds,
@@ -65,6 +69,7 @@
 AnyBackendConfigWithoutCreds = Union[
     AWSBackendConfig,
     AzureBackendConfig,
+    CloudRiftBackendConfig,
     CudoBackendConfig,
     DataCrunchBackendConfig,
     GCPBackendConfig,
@@ -86,6 +91,7 @@
 AnyBackendConfigWithCreds = Union[
     AWSBackendConfigWithCreds,
     AzureBackendConfigWithCreds,
+    CloudRiftBackendConfigWithCreds,
     CudoBackendConfigWithCreds,
     DataCrunchBackendConfigWithCreds,
     GCPBackendConfigWithCreds,
@@ -106,6 +112,7 @@
AnyBackendFileConfigWithCreds = Union[ AWSBackendConfigWithCreds, AzureBackendConfigWithCreds, + CloudRiftBackendConfigWithCreds, CudoBackendConfigWithCreds, DataCrunchBackendConfigWithCreds, GCPBackendFileConfigWithCreds, diff --git a/src/dstack/_internal/core/models/backends/base.py b/src/dstack/_internal/core/models/backends/base.py index 47df1163a1..78aafb142c 100644 --- a/src/dstack/_internal/core/models/backends/base.py +++ b/src/dstack/_internal/core/models/backends/base.py @@ -6,6 +6,7 @@ class BackendType(str, enum.Enum): Attributes: AWS (BackendType): Amazon Web Services AZURE (BackendType): Microsoft Azure + CLOUDRIFT (BackendType): CloudRift CUDO (BackendType): Cudo DSTACK (BackendType): dstack Sky GCP (BackendType): Google Cloud Platform @@ -22,6 +23,7 @@ class BackendType(str, enum.Enum): AWS = "aws" AZURE = "azure" + CLOUDRIFT = "cloudrift" CUDO = "cudo" DATACRUNCH = "datacrunch" DSTACK = "dstack" diff --git a/src/tests/_internal/core/backends/cloudrift/__init__.py b/src/tests/_internal/core/backends/cloudrift/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/core/backends/cloudrift/test_configurator.py b/src/tests/_internal/core/backends/cloudrift/test_configurator.py new file mode 100644 index 0000000000..f12499d890 --- /dev/null +++ b/src/tests/_internal/core/backends/cloudrift/test_configurator.py @@ -0,0 +1,34 @@ +from unittest.mock import patch + +import pytest + +from dstack._internal.core.backends.cloudrift.configurator import ( + CloudRiftConfigurator, +) +from dstack._internal.core.backends.cloudrift.models import ( + CloudRiftBackendConfigWithCreds, + CloudRiftCreds, +) +from dstack._internal.core.errors import BackendInvalidCredentialsError + + +class TestDataCrunchConfigurator: + def test_validate_config_valid(self): + config = CloudRiftBackendConfigWithCreds(creds=CloudRiftCreds(api_key="valid")) + with patch( + "dstack._internal.core.backends.cloudrift.api_client.RiftClient.validate_api_key" 
+ ) as validate_mock: + validate_mock.return_value = True + CloudRiftConfigurator().validate_config(config, default_creds_enabled=True) + + def test_validate_config_invalid(self): + config = CloudRiftBackendConfigWithCreds(creds=CloudRiftCreds(api_key="invalid")) + with ( + patch( + "dstack._internal.core.backends.cloudrift.api_client.RiftClient.validate_api_key" + ) as validate_mock, + pytest.raises(BackendInvalidCredentialsError) as exc_info, + ): + validate_mock.return_value = False + CloudRiftConfigurator().validate_config(config, default_creds_enabled=True) + assert exc_info.value.fields == [["creds", "api_key"]] diff --git a/src/tests/_internal/server/routers/test_backends.py b/src/tests/_internal/server/routers/test_backends.py index 569ab88ee2..6afe36c0c6 100644 --- a/src/tests/_internal/server/routers/test_backends.py +++ b/src/tests/_internal/server/routers/test_backends.py @@ -79,6 +79,7 @@ async def test_returns_backend_types(self, client: AsyncClient): assert response.json() == [ "aws", "azure", + "cloudrift", "cudo", "datacrunch", "gcp",