From bd1adc8e62a2e44394d0e68c048bbf911295e5bd Mon Sep 17 00:00:00 2001 From: Dmitry Trifonov Date: Tue, 27 May 2025 10:26:21 -0700 Subject: [PATCH 1/8] WIP --- pyproject.toml | 2 +- .../core/backends/cloudrift/__init__.py | 0 .../core/backends/cloudrift/backend.py | 16 +++ .../core/backends/cloudrift/compute.py | 125 ++++++++++++++++++ .../core/backends/cloudrift/configurator.py | 70 ++++++++++ .../core/backends/cloudrift/models.py | 40 ++++++ .../_internal/core/models/backends/base.py | 2 + 7 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 src/dstack/_internal/core/backends/cloudrift/__init__.py create mode 100644 src/dstack/_internal/core/backends/cloudrift/backend.py create mode 100644 src/dstack/_internal/core/backends/cloudrift/compute.py create mode 100644 src/dstack/_internal/core/backends/cloudrift/configurator.py create mode 100644 src/dstack/_internal/core/backends/cloudrift/models.py diff --git a/pyproject.toml b/pyproject.toml index 8ec49b0ea3..46bea5ee59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "python-multipart>=0.0.16", "filelock", "psutil", - "gpuhunt>=0.1.3,<0.2.0", + "gpuhunt>=0.1.4,<0.2.0", "argcomplete>=3.5.0", ] diff --git a/src/dstack/_internal/core/backends/cloudrift/__init__.py b/src/dstack/_internal/core/backends/cloudrift/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/dstack/_internal/core/backends/cloudrift/backend.py b/src/dstack/_internal/core/backends/cloudrift/backend.py new file mode 100644 index 0000000000..cca0c620c6 --- /dev/null +++ b/src/dstack/_internal/core/backends/cloudrift/backend.py @@ -0,0 +1,16 @@ +from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.cloudrift.compute import CloudRiftCompute +from dstack._internal.core.backends.cloudrift.models import CloudRiftConfig +from dstack._internal.core.models.backends.base import BackendType + + +class CloudRiftBackend(Backend): + TYPE = BackendType.CLOUDRIFT + COMPUTE_CLASS = CloudRiftCompute + + def __init__(self, config: CloudRiftConfig): + self.config = config + self._compute = CloudRiftCompute(self.config) + + def compute(self) -> CloudRiftCompute: + return self._compute diff --git a/src/dstack/_internal/core/backends/cloudrift/compute.py b/src/dstack/_internal/core/backends/cloudrift/compute.py new file mode 100644 index 0000000000..c6b32dcf66 --- /dev/null +++ b/src/dstack/_internal/core/backends/cloudrift/compute.py @@ -0,0 +1,125 @@ +from typing import Dict, List, Optional, Union + +import requests + +from dstack._internal.core.backends.base.backend import Compute +from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, +) +from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.cloudrift.models import CloudRiftConfig +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceConfiguration, + InstanceOffer, + InstanceOfferWithAvailability, +) +from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run +from dstack._internal.core.models.volumes import Volume +from dstack._internal.utils.logging import get_logger +from src.dstack._internal.core.backends.cloudrift.models import CloudRiftAPIKeyCreds + +logger = get_logger(__name__) + + +CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai" +CLOUDRIFT_API_VERSION = "2025-03-21" + + +class CloudRiftCompute( + ComputeWithCreateInstanceSupport, + Compute, +): + def __init__(self, config: CloudRiftConfig): + super().__init__() + self.config = config + + def get_offers( + self, requirements: Optional[Requirements] = None + ) -> List[InstanceOfferWithAvailability]: + offers = get_catalog_offers( + backend=BackendType.CLOUDRIFT, + locations=self.config.regions or None, + requirements=requirements, + ) + offers_with_availabilities = self._get_offers_with_availability(offers) + return offers_with_availabilities + + def _get_offers_with_availability( + self, offers: List[InstanceOffer] + ) -> List[InstanceOfferWithAvailability]: + instance_types_with_availabilities: List[Dict] = _get_instance_types() + + region_availabilities = {} + for instance_type in instance_types_with_availabilities: + for variant in instance_type["variants"]: + for dc, count in variant["available_nodes_per_dc"].items(): + if count > 0: + key = (variant["name"], dc) + region_availabilities[key] = InstanceAvailability.AVAILABLE + + availability_offers = [] + for offer in offers: + key = (offer.instance.name, offer.region) + availability = region_availabilities.get(key, InstanceAvailability.NOT_AVAILABLE) + availability_offers.append( + InstanceOfferWithAvailability(**offer.dict(), availability=availability) + ) + + return availability_offers + + def create_instance( + self, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + ) -> JobProvisioningData: + # TODO: Implement if backend supports creating instances (VM-based). + # Delete if backend can only run jobs (container-based). + raise NotImplementedError() + + def run_job( + self, + run: Run, + job: Job, + instance_offer: InstanceOfferWithAvailability, + project_ssh_public_key: str, + project_ssh_private_key: str, + volumes: List[Volume], + ) -> JobProvisioningData: + # TODO: Implement if create_instance() is not implemented. Delete otherwise. + raise NotImplementedError() + + def terminate_instance( + self, instance_id: str, region: str, backend_data: Optional[str] = None + ): + raise NotImplementedError() + + +def _get_instance_types(): + request_data = {"selector": {"ByServiceAndLocation": {"services": ["vm"]}}} + response_data = _make_request("instance-types/list", request_data) + return response_data["instance_types"] + + +def _make_request(endpoint: str, request_data: dict) -> Union[dict, str, None]: + response = requests.request( + "POST", + f"{CLOUDRIFT_SERVER_ADDRESS}/api/v1/{endpoint}", + json={"version": CLOUDRIFT_API_VERSION, "data": request_data}, + timeout=5.0, + ) + if not response.ok: + response.raise_for_status() + try: + response_json = response.json() + if isinstance(response_json, str): + return response_json + return response_json["data"] + except requests.exceptions.JSONDecodeError: + return None + + +if __name__ == "__main__": + compute = CloudRiftCompute(CloudRiftConfig(creds=CloudRiftAPIKeyCreds(api_key="asdasdasd"))) + print(compute.get_offers()) diff --git a/src/dstack/_internal/core/backends/cloudrift/configurator.py b/src/dstack/_internal/core/backends/cloudrift/configurator.py new file mode 100644 index 0000000000..4c1d7baf55 --- /dev/null +++ b/src/dstack/_internal/core/backends/cloudrift/configurator.py @@ -0,0 +1,70 @@ +import json + +from dstack._internal.core.backends.base.configurator import ( + BackendRecord, + Configurator, + raise_invalid_credentials_error, +) +from dstack._internal.core.backends.cloudrift.backend import CloudRiftBackend +from dstack._internal.core.backends.cloudrift.models import ( + AnyCloudRiftBackendConfig, + AnyCloudRiftCreds, + CloudRiftBackendConfig, + CloudRiftBackendConfigWithCreds, + CloudRiftConfig, + CloudRiftCreds, + CloudRiftStoredConfig, +) +from dstack._internal.core.models.backends.base import ( + BackendType, +) + +# TODO: Add all supported regions and default regions +REGIONS = [] + + +class CloudRiftConfigurator(Configurator): + TYPE = BackendType.CLOUDRIFT + BACKEND_CLASS = CloudRiftBackend + + def validate_config( + self, config: CloudRiftBackendConfigWithCreds, default_creds_enabled: bool + ): + self._validate_creds(config.creds) + # TODO: Validate additional config parameters if any + + def create_backend( + self, project_name: str, config: CloudRiftBackendConfigWithCreds + ) -> BackendRecord: + if config.regions is None: + config.regions = REGIONS + return BackendRecord( + config=CloudRiftStoredConfig( + **CloudRiftBackendConfig.__response__.parse_obj(config).dict() + ).json(), + auth=CloudRiftCreds.parse_obj(config.creds).json(), + ) + + def get_backend_config( + self, record: BackendRecord, include_creds: bool + ) -> AnyCloudRiftBackendConfig: + config = self._get_config(record) + if include_creds: + return CloudRiftBackendConfigWithCreds.__response__.parse_obj(config) + return CloudRiftBackendConfig.__response__.parse_obj(config) + + def get_backend(self, record: BackendRecord) -> CloudRiftBackend: + config = self._get_config(record) + return CloudRiftBackend(config=config) + + def _get_config(self, record: BackendRecord) -> CloudRiftConfig: + return CloudRiftConfig.__response__( + **json.loads(record.config), + creds=CloudRiftCreds.parse_raw(record.auth), + ) + + def _validate_creds(self, creds: AnyCloudRiftCreds): + # TODO: Implement API key or other creds validation + # if valid: + # return + raise_invalid_credentials_error(fields=[["creds", "api_key"]]) diff --git a/src/dstack/_internal/core/backends/cloudrift/models.py b/src/dstack/_internal/core/backends/cloudrift/models.py new file mode 100644 index 0000000000..62a6726f9a --- /dev/null +++ b/src/dstack/_internal/core/backends/cloudrift/models.py @@ -0,0 +1,40 @@ +from typing import Annotated, List, Literal, Optional, Union + +from pydantic import Field + +from dstack._internal.core.models.common import CoreModel + + +class CloudRiftAPIKeyCreds(CoreModel): + type: Annotated[Literal["api_key"], Field(description="The type of credentials")] = "api_key" + api_key: Annotated[str, Field(description="The API key")] + + +AnyCloudRiftCreds = CloudRiftAPIKeyCreds +CloudRiftCreds = AnyCloudRiftCreds + + +class CloudRiftBackendConfig(CoreModel): + type: Annotated[ + Literal["cloudrift"], + Field(description="The type of backend"), + ] = "cloudrift" + regions: Annotated[ + Optional[List[str]], + Field(description="The list of CloudRift regions. Omit to use all regions"), + ] = None + + +class CloudRiftBackendConfigWithCreds(CloudRiftBackendConfig): + creds: Annotated[AnyCloudRiftCreds, Field(description="The credentials")] + + +AnyCloudRiftBackendConfig = Union[CloudRiftBackendConfig, CloudRiftBackendConfigWithCreds] + + +class CloudRiftStoredConfig(CloudRiftBackendConfig): + pass + + +class CloudRiftConfig(CloudRiftStoredConfig): + creds: AnyCloudRiftCreds diff --git a/src/dstack/_internal/core/models/backends/base.py b/src/dstack/_internal/core/models/backends/base.py index 47df1163a1..78aafb142c 100644 --- a/src/dstack/_internal/core/models/backends/base.py +++ b/src/dstack/_internal/core/models/backends/base.py @@ -6,6 +6,7 @@ class BackendType(str, enum.Enum): Attributes: AWS (BackendType): Amazon Web Services AZURE (BackendType): Microsoft Azure + CLOUDRIFT (BackendType): CloudRift CUDO (BackendType): Cudo DSTACK (BackendType): dstack Sky GCP (BackendType): Google Cloud Platform @@ -22,6 +23,7 @@ class BackendType(str, enum.Enum): AWS = "aws" AZURE = "azure" + CLOUDRIFT = "cloudrift" CUDO = "cudo" DATACRUNCH = "datacrunch" DSTACK = "dstack" From ccf8df54f2c45537f3b6a54ac7c63cb79642975e Mon Sep 17 00:00:00 2001 From: Slawek Date: Sun, 1 Jun 2025 02:21:41 -0700 Subject: [PATCH 2/8] added instance renting logic --- .../core/backends/cloudrift/api_client.py | 225 ++++++++++++++++++ .../core/backends/cloudrift/compute.py | 108 +++++---- .../core/backends/cloudrift/configurator.py | 12 +- .../_internal/core/backends/configurators.py | 9 + src/dstack/_internal/core/backends/models.py | 7 + 5 files changed, 305 insertions(+), 56 deletions(-) create mode 100644 src/dstack/_internal/core/backends/cloudrift/api_client.py diff --git a/src/dstack/_internal/core/backends/cloudrift/api_client.py b/src/dstack/_internal/core/backends/cloudrift/api_client.py new file mode 100644 index 0000000000..58cae15b1e --- /dev/null +++ b/src/dstack/_internal/core/backends/cloudrift/api_client.py @@ -0,0 +1,225 @@ +import os +import re +from typing import Any, Dict, List, Mapping, Optional, Union + +import requests +from packaging import version +from requests import Response + +from dstack._internal.core.errors import BackendError, BackendInvalidCredentialsError +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai" +CLOUDRIFT_API_VERSION = "2025-03-21" + + +class RiftClient: + def __init__(self, api_key: Optional[str] = None): + self.server_address = CLOUDRIFT_SERVER_ADDRESS + self.public_api_root = os.path.join(CLOUDRIFT_SERVER_ADDRESS, "api/v1") + self.internal_api_root = os.path.join(CLOUDRIFT_SERVER_ADDRESS, "internal") + self.api_key = api_key + + def validate_api_key(self) -> bool: + """ + Validates the API key by making a request to the server. + Returns True if the API key is valid, False otherwise. + """ + try: + response = self._make_request("auth/me") + if isinstance(response, dict): + return response.get("email", False) + return False + except BackendInvalidCredentialsError: + return False + except Exception as e: + logger.error(f"Error validating API key: {e}") + return False + + def get_instance_types(self) -> List[Dict]: + request_data = {"selector": {"ByServiceAndLocation": {"services": ["vm"]}}} + response_data = self._make_request("instance-types/list", request_data) + if isinstance(response_data, dict): + return response_data.get("instance_types", []) + return [] + + def list_recipies(self) -> List[Dict]: + request_data = {} + response_data = self._make_request("recipes/list", request_data) + if isinstance(response_data, dict): + return response_data.get("groups", []) + return [] + + def get_vm_recipies(self) -> List[Dict]: + """ + Retrieves a list of VM recipes from the CloudRift API. + Returns a list of dictionaries containing recipe information. + """ + recipe_group = self.list_recipies() + vm_recipes = [] + for group in recipe_group: + tags = group.get("tags ", []) + has_vm = "vm" in tags + if group.get("name", "").lower() != "linux" and not has_vm: + continue + + recipes = group.get("recipes", []) + for recipe in recipes: + details = recipe.get("details", {}) + if details.get("VirtualMachine", False): + vm_recipes.append(recipe) + + return vm_recipes + + def get_vm_image_url(self) -> str | None: + recipes = self.get_vm_recipies() + ubuntu_images = [] + for recipe in recipes: + has_nvidia_driver = "nvidia-driver" in recipe.get("tags", []) + if not has_nvidia_driver: + continue + + recipe_name = recipe.get("name", "") + if "Ubuntu" not in recipe_name: + continue + + url = recipe["details"].get("VirtualMachine", {}).get("image_url", None) + version_match = re.search(r".* (\d+\.\d+)", recipe_name) + if url and version_match and version_match.group(1): + ubuntu_version = version.parse(version_match.group(1)) + ubuntu_images.append((ubuntu_version, url)) + + ubuntu_images.sort(key=lambda x: x[0]) # Sort by version + if ubuntu_images: + return ubuntu_images[-1][1] + + return None + + def deploy_instance(self, instance_type: str, region: str, ssh_keys: List[str]) -> List[str]: + image_url = self.get_vm_image_url() + if not image_url: + raise BackendError("No suitable VM image found.") + + request_data = { + "config": { + "VirtualMachine": { + # "cloudinit_url": "", + "image_url": image_url, + "ssh_key": {"PublicKeys": ssh_keys}, + } + }, + "selector": { + "ByInstanceTypeAndLocation": { + "datacenters": [region], + "instance_type": instance_type, + } + }, + "with_public_ip": True, + } + logger.debug("Deploying instance with request data: %s", request_data) + + response_data = self._make_request("instances/rent", request_data) + if isinstance(response_data, dict): + return response_data.get("instance_ids", []) + return [] + + def list_instances(self, instance_ids: Optional[List[str]] = None) -> List[Dict]: + request_data = { + "selector": { + "ByStatus": ["Initializing", "Active", "Deactivating"], + } + } + logger.debug("Listing instances with request data: %s", request_data) + response_data = self._make_request("instances/list", request_data) + if isinstance(response_data, dict): + return response_data.get("instances", []) + + return [] + + def get_instance_by_id(self, instance_id: str) -> Optional[Dict]: + request_data = {"selector": {"ById": [instance_id]}} + logger.debug("Getting instance with request data: %s", request_data) + response_data = self._make_request("instances/list", request_data) + if isinstance(response_data, dict): + instances = response_data.get("instances", []) + if isinstance(instances, list) and len(instances) > 0: + return instances[0] + + return None + + def is_instance_ready(self, instance_id: str) -> bool: + """ + Checks if the instance with the given ID is ready. + Returns True if the instance is ready, False otherwise. + """ + instance_info = self.get_instance_by_id(instance_id) + if instance_info: + instance_type = instance_info.get("node_mode", "") + if instance_type == "VirtualMachine": + vms = instance_info.get("virtual_machines", []) + if len(vms) > 0: + vm_ready = vms[0].get("ready", False) + return vm_ready + else: + return instance_info.get("status", "") == "Active" + return False + + def terminate_instance(self, instance_id: str) -> bool: + request_data = {"selector": {"ById": [instance_id]}} + logger.debug("Terminating instance with request data: %s", request_data) + response_data = self._make_request("instances/terminate", request_data) + if isinstance(response_data, dict): + info = response_data.get("terminated", []) + return len(info) > 0 + + return False + + def _make_request( + self, + endpoint: str, + data: Optional[Mapping[str, Any]] = None, + method: str = "POST", + **kwargs, + ) -> Union[Mapping[str, Any], str, Response]: + headers = {} + if self.api_key is not None: + headers["X-API-Key"] = self.api_key + + version = CLOUDRIFT_API_VERSION + full_url = f"{self.public_api_root}/{endpoint}" + + try: + response = requests.request( + method, + full_url, + headers=headers, + json={"version": version, "data": data}, + timeout=120, + **kwargs, + ) + + if not response.ok: + response.raise_for_status() + try: + response_json = response.json() + if isinstance(response_json, str): + return response_json + if version is not None and version < response_json["version"]: + logger.warning( + "The API version %s is lower than the server version %s. ", + version, + response_json["version"], + ) + return response_json["data"] + except requests.exceptions.JSONDecodeError: + return response + except requests.HTTPError as e: + if e.response is not None and e.response.status_code in ( + requests.codes.forbidden, + requests.codes.unauthorized, + ): + raise BackendInvalidCredentialsError(e.response.text) + raise diff --git a/src/dstack/_internal/core/backends/cloudrift/compute.py b/src/dstack/_internal/core/backends/cloudrift/compute.py index c6b32dcf66..be10067308 100644 --- a/src/dstack/_internal/core/backends/cloudrift/compute.py +++ b/src/dstack/_internal/core/backends/cloudrift/compute.py @@ -1,13 +1,13 @@ -from typing import Dict, List, Optional, Union - -import requests +from typing import Dict, List, Optional from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( ComputeWithCreateInstanceSupport, ) from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.cloudrift.api_client import RiftClient from dstack._internal.core.backends.cloudrift.models import CloudRiftConfig +from dstack._internal.core.errors import ComputeError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -15,18 +15,12 @@ InstanceOffer, InstanceOfferWithAvailability, ) -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume +from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger -from src.dstack._internal.core.backends.cloudrift.models import CloudRiftAPIKeyCreds logger = get_logger(__name__) -CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai" -CLOUDRIFT_API_VERSION = "2025-03-21" - - class CloudRiftCompute( ComputeWithCreateInstanceSupport, Compute, @@ -34,6 +28,7 @@ class CloudRiftCompute( def __init__(self, config: CloudRiftConfig): super().__init__() self.config = config + self.client = RiftClient(self.config.creds.api_key) def get_offers( self, requirements: Optional[Requirements] = None @@ -43,13 +38,14 @@ def get_offers( locations=self.config.regions or None, requirements=requirements, ) + offers_with_availabilities = self._get_offers_with_availability(offers) return offers_with_availabilities def _get_offers_with_availability( self, offers: List[InstanceOffer] ) -> List[InstanceOfferWithAvailability]: - instance_types_with_availabilities: List[Dict] = _get_instance_types() + instance_types_with_availabilities: List[Dict] = self.client.get_instance_types() region_availabilities = {} for instance_type in instance_types_with_availabilities: @@ -66,6 +62,9 @@ def _get_offers_with_availability( availability_offers.append( InstanceOfferWithAvailability(**offer.dict(), availability=availability) ) + logger.debug( + f"Offer {offer.instance.name} in region {offer.region} has availability: {availability}" + ) return availability_offers @@ -74,52 +73,57 @@ def create_instance( instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration, ) -> JobProvisioningData: - # TODO: Implement if backend supports creating instances (VM-based). - # Delete if backend can only run jobs (container-based). - raise NotImplementedError() + # TODO: add commands to cloud-init + # commands = get_shim_commands(authorized_keys=instance_config.get_public_keys()) + # logger.debug( + # f"Creating instance for offer {instance_offer.instance.name} in region {instance_offer.region} with commands: {commands}" + # ) + + instance_ids = self.client.deploy_instance( + instance_type=instance_offer.instance.name, + region=instance_offer.region, + ssh_keys=instance_config.get_public_keys(), + ) + + if len(instance_ids) == 0: + raise ComputeError( + f"Failed to create instance for offer {instance_offer.instance.name} in region {instance_offer.region}." + ) + + return JobProvisioningData( + backend=instance_offer.backend, + instance_type=instance_offer.instance, + instance_id=instance_ids[0], + hostname=None, + internal_ip=None, + region=instance_offer.region, + price=instance_offer.price, + username="riftuser", + ssh_port=22, + dockerized=True, + ssh_proxy=None, + backend_data=None, + ) - def run_job( + def update_provisioning_data( self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, + provisioning_data: JobProvisioningData, project_ssh_public_key: str, project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - # TODO: Implement if create_instance() is not implemented. Delete otherwise. - raise NotImplementedError() + ): + instance_info = self.client.get_instance_by_id(provisioning_data.instance_id) + if instance_info: + vms = instance_info.get("virtual_machines", []) + if len(vms) > 0: + vm_ready = vms[0].get("ready", False) + if vm_ready: + provisioning_data.hostname = instance_info.get("host_address", None) + + pass def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): - raise NotImplementedError() - - -def _get_instance_types(): - request_data = {"selector": {"ByServiceAndLocation": {"services": ["vm"]}}} - response_data = _make_request("instance-types/list", request_data) - return response_data["instance_types"] - - -def _make_request(endpoint: str, request_data: dict) -> Union[dict, str, None]: - response = requests.request( - "POST", - f"{CLOUDRIFT_SERVER_ADDRESS}/api/v1/{endpoint}", - json={"version": CLOUDRIFT_API_VERSION, "data": request_data}, - timeout=5.0, - ) - if not response.ok: - response.raise_for_status() - try: - response_json = response.json() - if isinstance(response_json, str): - return response_json - return response_json["data"] - except requests.exceptions.JSONDecodeError: - return None - - -if __name__ == "__main__": - compute = CloudRiftCompute(CloudRiftConfig(creds=CloudRiftAPIKeyCreds(api_key="asdasdasd"))) - print(compute.get_offers()) + terminated = self.client.terminate_instance(instance_id=instance_id) + if not terminated: + raise ComputeError(f"Failed to terminate instance {instance_id} in region {region}.") diff --git a/src/dstack/_internal/core/backends/cloudrift/configurator.py b/src/dstack/_internal/core/backends/cloudrift/configurator.py index 4c1d7baf55..0b36d25138 100644 --- a/src/dstack/_internal/core/backends/cloudrift/configurator.py +++ b/src/dstack/_internal/core/backends/cloudrift/configurator.py @@ -5,6 +5,7 @@ Configurator, raise_invalid_credentials_error, ) +from dstack._internal.core.backends.cloudrift.api_client import RiftClient from dstack._internal.core.backends.cloudrift.backend import CloudRiftBackend from dstack._internal.core.backends.cloudrift.models import ( AnyCloudRiftBackendConfig, @@ -22,6 +23,8 @@ # TODO: Add all supported regions and default regions REGIONS = [] +CLOUDRIFT_API_URL = "https://api.cloudrift.ai" + class CloudRiftConfigurator(Configurator): TYPE = BackendType.CLOUDRIFT @@ -64,7 +67,8 @@ def _get_config(self, record: BackendRecord) -> CloudRiftConfig: ) def _validate_creds(self, creds: AnyCloudRiftCreds): - # TODO: Implement API key or other creds validation - # if valid: - # return - raise_invalid_credentials_error(fields=[["creds", "api_key"]]) + if not isinstance(creds, CloudRiftCreds): + raise_invalid_credentials_error(fields=[["creds"]]) + client = RiftClient(creds.api_key) + if not client.validate_api_key(): + raise_invalid_credentials_error(fields=[["creds", "api_key"]]) diff --git a/src/dstack/_internal/core/backends/configurators.py b/src/dstack/_internal/core/backends/configurators.py index 43a215d3ab..571d010529 100644 --- a/src/dstack/_internal/core/backends/configurators.py +++ b/src/dstack/_internal/core/backends/configurators.py @@ -20,6 +20,15 @@ except ImportError: pass +try: + from dstack._internal.core.backends.cloudrift.configurator import ( + CloudRiftConfigurator, + ) + + _CONFIGURATOR_CLASSES.append(CloudRiftConfigurator) +except ImportError: + pass + try: from dstack._internal.core.backends.cudo.configurator import ( CudoConfigurator, diff --git a/src/dstack/_internal/core/backends/models.py b/src/dstack/_internal/core/backends/models.py index 18567592a5..0b5779db78 100644 --- a/src/dstack/_internal/core/backends/models.py +++ b/src/dstack/_internal/core/backends/models.py @@ -8,6 +8,10 @@ AzureBackendConfig, AzureBackendConfigWithCreds, ) +from dstack._internal.core.backends.cloudrift.models import ( + CloudRiftBackendConfig, + CloudRiftBackendConfigWithCreds, +) from dstack._internal.core.backends.cudo.models import ( CudoBackendConfig, CudoBackendConfigWithCreds, @@ -65,6 +69,7 @@ AnyBackendConfigWithoutCreds = Union[ AWSBackendConfig, AzureBackendConfig, + CloudRiftBackendConfig, CudoBackendConfig, DataCrunchBackendConfig, GCPBackendConfig, @@ -86,6 +91,7 @@ AnyBackendConfigWithCreds = Union[ AWSBackendConfigWithCreds, AzureBackendConfigWithCreds, + CloudRiftBackendConfigWithCreds, CudoBackendConfigWithCreds, DataCrunchBackendConfigWithCreds, GCPBackendConfigWithCreds, @@ -106,6 +112,7 @@ AnyBackendFileConfigWithCreds = Union[ AWSBackendConfigWithCreds, AzureBackendConfigWithCreds, + CloudRiftBackendConfigWithCreds, CudoBackendConfigWithCreds, DataCrunchBackendConfigWithCreds, GCPBackendFileConfigWithCreds, From d5dbaacfc4ed6624c64169b57cffb2fe31d26e27 Mon Sep 17 00:00:00 2001 From: Slawek Date: Thu, 5 Jun 2025 00:43:04 -0700 Subject: [PATCH 3/8] pass custom commands --- .../_internal/core/backends/cloudrift/api_client.py | 5 ++++- .../_internal/core/backends/cloudrift/compute.py | 12 +++++++----- .../core/backends/cloudrift/configurator.py | 2 -- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/dstack/_internal/core/backends/cloudrift/api_client.py b/src/dstack/_internal/core/backends/cloudrift/api_client.py index 58cae15b1e..edb3f60d39 100644 --- a/src/dstack/_internal/core/backends/cloudrift/api_client.py +++ b/src/dstack/_internal/core/backends/cloudrift/api_client.py @@ -98,7 +98,9 @@ def get_vm_image_url(self) -> str | None: return None - def deploy_instance(self, instance_type: str, region: str, ssh_keys: List[str]) -> List[str]: + def deploy_instance( + self, instance_type: str, region: str, ssh_keys: List[str], cmd: str + ) -> List[str]: image_url = self.get_vm_image_url() if not image_url: raise BackendError("No suitable VM image found.") @@ -107,6 +109,7 @@ def deploy_instance(self, instance_type: str, region: str, ssh_keys: List[str]) "config": { "VirtualMachine": { # "cloudinit_url": "", + "cloudinit_commands": cmd, "image_url": image_url, "ssh_key": {"PublicKeys": ssh_keys}, } diff --git a/src/dstack/_internal/core/backends/cloudrift/compute.py b/src/dstack/_internal/core/backends/cloudrift/compute.py index be10067308..38bef645dc 100644 --- a/src/dstack/_internal/core/backends/cloudrift/compute.py +++ b/src/dstack/_internal/core/backends/cloudrift/compute.py @@ -3,6 +3,7 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( ComputeWithCreateInstanceSupport, + get_shim_commands, ) from dstack._internal.core.backends.base.offers import get_catalog_offers from dstack._internal.core.backends.cloudrift.api_client import RiftClient @@ -73,16 +74,17 @@ def create_instance( instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration, ) -> JobProvisioningData: - # TODO: add commands to cloud-init - # commands = get_shim_commands(authorized_keys=instance_config.get_public_keys()) - # logger.debug( - # f"Creating instance for offer {instance_offer.instance.name} in region {instance_offer.region} with commands: {commands}" - # ) + commands = get_shim_commands(authorized_keys=instance_config.get_public_keys()) + startup_script = " ".join([" && ".join(commands)]) + logger.debug( + f"Creating instance for offer {instance_offer.instance.name} in region {instance_offer.region} with commands: {startup_script}" + ) instance_ids = self.client.deploy_instance( instance_type=instance_offer.instance.name, region=instance_offer.region, ssh_keys=instance_config.get_public_keys(), + cmd=startup_script, ) if len(instance_ids) == 0: diff --git a/src/dstack/_internal/core/backends/cloudrift/configurator.py b/src/dstack/_internal/core/backends/cloudrift/configurator.py index 0b36d25138..124fea2be9 100644 --- a/src/dstack/_internal/core/backends/cloudrift/configurator.py +++ b/src/dstack/_internal/core/backends/cloudrift/configurator.py @@ -23,8 +23,6 @@ # TODO: Add all supported regions and default regions REGIONS = [] -CLOUDRIFT_API_URL = "https://api.cloudrift.ai" - class CloudRiftConfigurator(Configurator): TYPE = BackendType.CLOUDRIFT From f1507f05799daf9eaee55109464faad4f358e832 Mon Sep 17 00:00:00 2001 From: Slawek Date: Fri, 6 Jun 2025 13:05:10 -0700 Subject: [PATCH 4/8] doc --- docs/docs/concepts/backends.md | 22 ++++++++++++++++++++++ docs/docs/reference/server/config.yml.md | 17 +++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/docs/docs/concepts/backends.md b/docs/docs/concepts/backends.md index e0f52e6e67..3a70c6cfea 100644 --- a/docs/docs/concepts/backends.md +++ b/docs/docs/concepts/backends.md @@ -913,6 +913,28 @@ projects: +### CloudRift + +Log into your [CloudRift :material-arrow-top-right-thin:{ .external }](https://www.cloudrift.ai/console/) console, click `API Keys` in the sidebar and click the button to create a new API key. + +Ensure you've created a project with CloudRift, + +Then proceed to configuring the backend. + +
+ +```yaml +projects: + - name: main + backends: + - type: cloudrift + creds: + type: api_key + api_key: rift_2prgY1d0laOrf2BblTwx2B2d1zcf1zIp4tZYpj5j88qmNgz38pxNlpX3vAo +``` + +
+ ## On-prem servers ### SSH fleets diff --git a/docs/docs/reference/server/config.yml.md b/docs/docs/reference/server/config.yml.md index a021c138a0..fbe378d8cd 100644 --- a/docs/docs/reference/server/config.yml.md +++ b/docs/docs/reference/server/config.yml.md @@ -315,6 +315,23 @@ to configure [backends](../../concepts/backends.md) and other [sever-level setti type: required: true +##### `projects[n].backends[type=cloudrift]` { #cloudrift data-toc-label="cloudrift" } + +#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftBackendConfigWithCreds + overrides: + show_root_heading: false + type: + required: true + item_id_prefix: cloudrift- + +###### `projects[n].backends[type=cloudrift].creds` { #cloudrift-creds data-toc-label="creds" } + +#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftAPIKeyCreds + overrides: + show_root_heading: false + type: + required: true + ### `encryption` { #encryption data-toc-label="encryption" } #SCHEMA# dstack._internal.server.services.config.EncryptionConfig From c91a183f6575a2a7981e14dbcb24757e627a5b4a Mon Sep 17 00:00:00 2001 From: Slawek Date: Fri, 6 Jun 2025 13:11:38 -0700 Subject: [PATCH 5/8] updated version --- src/dstack/_internal/core/backends/cloudrift/api_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/core/backends/cloudrift/api_client.py b/src/dstack/_internal/core/backends/cloudrift/api_client.py index edb3f60d39..8b3a45f3c0 100644 --- a/src/dstack/_internal/core/backends/cloudrift/api_client.py +++ b/src/dstack/_internal/core/backends/cloudrift/api_client.py @@ -13,7 +13,7 @@ CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai" -CLOUDRIFT_API_VERSION = "2025-03-21" +CLOUDRIFT_API_VERSION = "2025-05-29" class RiftClient: From 45f9f6e991011a3a09ebbb0e8665f5ad341ab6a6 Mon Sep 17 00:00:00 2001 From: Slawek Date: Fri, 6 Jun 2025 15:42:56 -0700 Subject: [PATCH 6/8] tests --- .../core/backends/cloudrift/__init__.py | 0 .../backends/cloudrift/test_configurator.py | 34 +++++++++++++++++++ .../_internal/server/routers/test_backends.py | 1 + 3 files changed, 35 insertions(+) create mode 100644 src/tests/_internal/core/backends/cloudrift/__init__.py create mode 100644 src/tests/_internal/core/backends/cloudrift/test_configurator.py diff --git a/src/tests/_internal/core/backends/cloudrift/__init__.py b/src/tests/_internal/core/backends/cloudrift/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/core/backends/cloudrift/test_configurator.py b/src/tests/_internal/core/backends/cloudrift/test_configurator.py new file mode 100644 index 0000000000..f12499d890 --- /dev/null +++ b/src/tests/_internal/core/backends/cloudrift/test_configurator.py @@ -0,0 +1,34 @@ +from unittest.mock import patch + +import pytest + +from dstack._internal.core.backends.cloudrift.configurator import ( + CloudRiftConfigurator, +) +from dstack._internal.core.backends.cloudrift.models import ( + CloudRiftBackendConfigWithCreds, + CloudRiftCreds, +) +from dstack._internal.core.errors import BackendInvalidCredentialsError + + +class TestDataCrunchConfigurator: + def test_validate_config_valid(self): + config = CloudRiftBackendConfigWithCreds(creds=CloudRiftCreds(api_key="valid")) + with patch( + "dstack._internal.core.backends.cloudrift.api_client.RiftClient.validate_api_key" + ) as validate_mock: + validate_mock.return_value = True + CloudRiftConfigurator().validate_config(config, default_creds_enabled=True) + + def test_validate_config_invalid(self): + config = CloudRiftBackendConfigWithCreds(creds=CloudRiftCreds(api_key="invalid")) + with ( + patch( + "dstack._internal.core.backends.cloudrift.api_client.RiftClient.validate_api_key" + ) as validate_mock, + pytest.raises(BackendInvalidCredentialsError) as exc_info, + ): + validate_mock.return_value = False + CloudRiftConfigurator().validate_config(config, default_creds_enabled=True) + assert exc_info.value.fields == [["creds", "api_key"]] diff --git a/src/tests/_internal/server/routers/test_backends.py b/src/tests/_internal/server/routers/test_backends.py index 569ab88ee2..6afe36c0c6 100644 --- a/src/tests/_internal/server/routers/test_backends.py +++ b/src/tests/_internal/server/routers/test_backends.py @@ -79,6 +79,7 @@ async def test_returns_backend_types(self, client: AsyncClient): assert response.json() == [ "aws", "azure", + "cloudrift", "cudo", "datacrunch", "gcp", From 33eabc229852c67b07f998396e9b9983cd018080 Mon Sep 17 00:00:00 2001 From: Slawek Date: Mon, 23 Jun 2025 14:26:53 -0700 Subject: [PATCH 7/8] test fix --- src/dstack/_internal/core/backends/cloudrift/api_client.py | 2 +- src/tests/_internal/server/routers/test_backends.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/core/backends/cloudrift/api_client.py b/src/dstack/_internal/core/backends/cloudrift/api_client.py index 8b3a45f3c0..00de2adb26 100644 --- a/src/dstack/_internal/core/backends/cloudrift/api_client.py +++ b/src/dstack/_internal/core/backends/cloudrift/api_client.py @@ -74,7 +74,7 @@ def get_vm_recipies(self) -> List[Dict]: return vm_recipes - def get_vm_image_url(self) -> str | None: + def get_vm_image_url(self) -> Optional[str]: recipes = self.get_vm_recipies() ubuntu_images = [] for recipe in recipes: diff --git a/src/tests/_internal/server/routers/test_backends.py b/src/tests/_internal/server/routers/test_backends.py index 569ab88ee2..6afe36c0c6 100644 --- a/src/tests/_internal/server/routers/test_backends.py +++ b/src/tests/_internal/server/routers/test_backends.py @@ -79,6 +79,7 @@ async def test_returns_backend_types(self, client: AsyncClient): assert response.json() == [ "aws", "azure", + "cloudrift", "cudo", "datacrunch", "gcp", From b4eeb7e4452bab54b63bb3fa86c92dc05b6ec724 Mon Sep 17 00:00:00 2001 From: Slawek Date: Mon, 23 Jun 2025 17:30:33 -0700 Subject: [PATCH 8/8] PR feedback --- docs/docs/concepts/backends.md | 4 +-- .../core/backends/cloudrift/api_client.py | 34 ++++--------------- .../core/backends/cloudrift/compute.py | 29 ++++++++++------ .../core/backends/cloudrift/configurator.py | 6 ---- 4 files changed, 27 insertions(+), 46 deletions(-) diff --git a/docs/docs/concepts/backends.md b/docs/docs/concepts/backends.md index 3a70c6cfea..21208d74b9 100644 --- a/docs/docs/concepts/backends.md +++ b/docs/docs/concepts/backends.md @@ -915,9 +915,9 @@ projects: ### CloudRift -Log into your [CloudRift :material-arrow-top-right-thin:{ .external }](https://www.cloudrift.ai/console/) console, click `API Keys` in the sidebar and click the button to create a new API key. +Log into your [CloudRift :material-arrow-top-right-thin:{ .external }](https://console.cloudrift.ai/) console, click `API Keys` in the sidebar and click the button to create a new API key. -Ensure you've created a project with CloudRift, +Ensure you've created a project with CloudRift. Then proceed to configuring the backend. diff --git a/src/dstack/_internal/core/backends/cloudrift/api_client.py b/src/dstack/_internal/core/backends/cloudrift/api_client.py index 00de2adb26..51bbfafd69 100644 --- a/src/dstack/_internal/core/backends/cloudrift/api_client.py +++ b/src/dstack/_internal/core/backends/cloudrift/api_client.py @@ -18,9 +18,7 @@ class RiftClient: def __init__(self, api_key: Optional[str] = None): - self.server_address = CLOUDRIFT_SERVER_ADDRESS self.public_api_root = os.path.join(CLOUDRIFT_SERVER_ADDRESS, "api/v1") - self.internal_api_root = os.path.join(CLOUDRIFT_SERVER_ADDRESS, "internal") self.api_key = api_key def validate_api_key(self) -> bool: @@ -31,7 +29,7 @@ def validate_api_key(self) -> bool: try: response = self._make_request("auth/me") if isinstance(response, dict): - return response.get("email", False) + return "email" in response return False except BackendInvalidCredentialsError: return False @@ -46,7 +44,7 @@ def get_instance_types(self) -> List[Dict]: return response_data.get("instance_types", []) return [] - def list_recipies(self) -> List[Dict]: + def list_recipes(self) -> List[Dict]: request_data = {} response_data = self._make_request("recipes/list", request_data) if isinstance(response_data, dict): @@ -58,12 +56,12 @@ def get_vm_recipies(self) -> List[Dict]: Retrieves a list of VM recipes from the CloudRift API. Returns a list of dictionaries containing recipe information. """ - recipe_group = self.list_recipies() + recipe_group = self.list_recipes() vm_recipes = [] for group in recipe_group: - tags = group.get("tags ", []) - has_vm = "vm" in tags - if group.get("name", "").lower() != "linux" and not has_vm: + tags = group.get("tags", []) + has_vm = "vm" in map(str.lower, tags) + if group.get("name", "").lower() != "linux" or not has_vm: continue recipes = group.get("recipes", []) @@ -108,7 +106,6 @@ def deploy_instance( request_data = { "config": { "VirtualMachine": { - # "cloudinit_url": "", "cloudinit_commands": cmd, "image_url": image_url, "ssh_key": {"PublicKeys": ssh_keys}, @@ -153,23 +150,6 @@ def get_instance_by_id(self, instance_id: str) -> Optional[Dict]: return None - def is_instance_ready(self, instance_id: str) -> bool: - """ - Checks if the instance with the given ID is ready. - Returns True if the instance is ready, False otherwise. - """ - instance_info = self.get_instance_by_id(instance_id) - if instance_info: - instance_type = instance_info.get("node_mode", "") - if instance_type == "VirtualMachine": - vms = instance_info.get("virtual_machines", []) - if len(vms) > 0: - vm_ready = vms[0].get("ready", False) - return vm_ready - else: - return instance_info.get("status", "") == "Active" - return False - def terminate_instance(self, instance_id: str) -> bool: request_data = {"selector": {"ById": [instance_id]}} logger.debug("Terminating instance with request data: %s", request_data) @@ -200,7 +180,7 @@ def _make_request( full_url, headers=headers, json={"version": version, "data": data}, - timeout=120, + timeout=15, **kwargs, ) diff --git a/src/dstack/_internal/core/backends/cloudrift/compute.py b/src/dstack/_internal/core/backends/cloudrift/compute.py index 38bef645dc..03d9fd74c6 100644 --- a/src/dstack/_internal/core/backends/cloudrift/compute.py +++ b/src/dstack/_internal/core/backends/cloudrift/compute.py @@ -16,6 +16,7 @@ InstanceOffer, InstanceOfferWithAvailability, ) +from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger @@ -63,9 +64,6 @@ def _get_offers_with_availability( availability_offers.append( InstanceOfferWithAvailability(**offer.dict(), availability=availability) ) - logger.debug( - f"Offer {offer.instance.name} in region {offer.region} has availability: {availability}" - ) return availability_offers @@ -73,6 +71,7 @@ def create_instance( self, instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration, + placement_group: Optional[PlacementGroup], ) -> JobProvisioningData: commands = get_shim_commands(authorized_keys=instance_config.get_public_keys()) startup_script = " ".join([" && ".join(commands)]) @@ -114,14 +113,22 @@ def update_provisioning_data( project_ssh_private_key: str, ): instance_info = self.client.get_instance_by_id(provisioning_data.instance_id) - if instance_info: - vms = instance_info.get("virtual_machines", []) - if len(vms) > 0: - vm_ready = vms[0].get("ready", False) - if vm_ready: - provisioning_data.hostname = instance_info.get("host_address", None) - - pass + + if not instance_info: + return + + instance_mode = instance_info.get("node_mode", "") + + if not instance_mode or instance_mode != "VirtualMachine": + return + + vms = instance_info.get("virtual_machines", []) + if len(vms) == 0: + return + + vm_ready = vms[0].get("ready", False) + if vm_ready: + provisioning_data.hostname = instance_info.get("host_address", None) def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None diff --git a/src/dstack/_internal/core/backends/cloudrift/configurator.py b/src/dstack/_internal/core/backends/cloudrift/configurator.py index 124fea2be9..62410b0ecd 100644 --- a/src/dstack/_internal/core/backends/cloudrift/configurator.py +++ b/src/dstack/_internal/core/backends/cloudrift/configurator.py @@ -20,9 +20,6 @@ BackendType, ) -# TODO: Add all supported regions and default regions -REGIONS = [] - class CloudRiftConfigurator(Configurator): TYPE = BackendType.CLOUDRIFT @@ -32,13 +29,10 @@ def validate_config( self, config: CloudRiftBackendConfigWithCreds, default_creds_enabled: bool ): self._validate_creds(config.creds) - # TODO: Validate additional config parameters if any def create_backend( self, project_name: str, config: CloudRiftBackendConfigWithCreds ) -> BackendRecord: - if config.regions is None: - config.regions = REGIONS return BackendRecord( config=CloudRiftStoredConfig( **CloudRiftBackendConfig.__response__.parse_obj(config).dict()