Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/docs/concepts/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,28 @@ projects:

</div>

### CloudRift

Log into your [CloudRift :material-arrow-top-right-thin:{ .external }](https://console.cloudrift.ai/) console, click `API Keys` in the sidebar and click the button to create a new API key.

Ensure you've created a project with CloudRift.

Then proceed to configuring the backend.

<div editor-title="~/.dstack/server/config.yml">

```yaml
projects:
- name: main
backends:
- type: cloudrift
creds:
type: api_key
api_key: rift_2prgY1d0laOrf2BblTwx2B2d1zcf1zIp4tZYpj5j88qmNgz38pxNlpX3vAo
```

</div>

## On-prem servers

### SSH fleets
Expand Down
17 changes: 17 additions & 0 deletions docs/docs/reference/server/config.yml.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,23 @@ to configure [backends](../../concepts/backends.md) and other [sever-level setti
type:
required: true

##### `projects[n].backends[type=cloudrift]` { #cloudrift data-toc-label="cloudrift" }

#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftBackendConfigWithCreds
overrides:
show_root_heading: false
type:
required: true
item_id_prefix: cloudrift-

###### `projects[n].backends[type=cloudrift].creds` { #cloudrift-creds data-toc-label="creds" }

#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftAPIKeyCreds
overrides:
show_root_heading: false
type:
required: true

### `encryption` { #encryption data-toc-label="encryption" }

#SCHEMA# dstack._internal.server.services.config.EncryptionConfig
Expand Down
Empty file.
208 changes: 208 additions & 0 deletions src/dstack/_internal/core/backends/cloudrift/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import os
import re
from typing import Any, Dict, List, Mapping, Optional, Union

import requests
from packaging import version
from requests import Response

from dstack._internal.core.errors import BackendError, BackendInvalidCredentialsError
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai"
CLOUDRIFT_API_VERSION = "2025-05-29"


class RiftClient:
def __init__(self, api_key: Optional[str] = None):
self.public_api_root = os.path.join(CLOUDRIFT_SERVER_ADDRESS, "api/v1")
self.api_key = api_key

def validate_api_key(self) -> bool:
"""
Validates the API key by making a request to the server.
Returns True if the API key is valid, False otherwise.
"""
try:
response = self._make_request("auth/me")
if isinstance(response, dict):
return "email" in response
return False
except BackendInvalidCredentialsError:
return False
except Exception as e:
logger.error(f"Error validating API key: {e}")
return False

def get_instance_types(self) -> List[Dict]:
request_data = {"selector": {"ByServiceAndLocation": {"services": ["vm"]}}}
response_data = self._make_request("instance-types/list", request_data)
if isinstance(response_data, dict):
return response_data.get("instance_types", [])
return []

def list_recipes(self) -> List[Dict]:
request_data = {}
response_data = self._make_request("recipes/list", request_data)
if isinstance(response_data, dict):
return response_data.get("groups", [])
return []

def get_vm_recipies(self) -> List[Dict]:
"""
Retrieves a list of VM recipes from the CloudRift API.
Returns a list of dictionaries containing recipe information.
"""
recipe_group = self.list_recipes()
vm_recipes = []
for group in recipe_group:
tags = group.get("tags", [])
has_vm = "vm" in map(str.lower, tags)
if group.get("name", "").lower() != "linux" or not has_vm:
continue

recipes = group.get("recipes", [])
for recipe in recipes:
details = recipe.get("details", {})
if details.get("VirtualMachine", False):
vm_recipes.append(recipe)

return vm_recipes

def get_vm_image_url(self) -> Optional[str]:
recipes = self.get_vm_recipies()
ubuntu_images = []
for recipe in recipes:
has_nvidia_driver = "nvidia-driver" in recipe.get("tags", [])
if not has_nvidia_driver:
continue

recipe_name = recipe.get("name", "")
if "Ubuntu" not in recipe_name:
continue

url = recipe["details"].get("VirtualMachine", {}).get("image_url", None)
version_match = re.search(r".* (\d+\.\d+)", recipe_name)
if url and version_match and version_match.group(1):
ubuntu_version = version.parse(version_match.group(1))
ubuntu_images.append((ubuntu_version, url))

ubuntu_images.sort(key=lambda x: x[0]) # Sort by version
if ubuntu_images:
return ubuntu_images[-1][1]

return None

def deploy_instance(
self, instance_type: str, region: str, ssh_keys: List[str], cmd: str
) -> List[str]:
image_url = self.get_vm_image_url()
if not image_url:
raise BackendError("No suitable VM image found.")

request_data = {
"config": {
"VirtualMachine": {
"cloudinit_commands": cmd,
"image_url": image_url,
"ssh_key": {"PublicKeys": ssh_keys},
}
},
"selector": {
"ByInstanceTypeAndLocation": {
"datacenters": [region],
"instance_type": instance_type,
}
},
"with_public_ip": True,
}
logger.debug("Deploying instance with request data: %s", request_data)

response_data = self._make_request("instances/rent", request_data)
if isinstance(response_data, dict):
return response_data.get("instance_ids", [])
return []

def list_instances(self, instance_ids: Optional[List[str]] = None) -> List[Dict]:
request_data = {
"selector": {
"ByStatus": ["Initializing", "Active", "Deactivating"],
}
}
logger.debug("Listing instances with request data: %s", request_data)
response_data = self._make_request("instances/list", request_data)
if isinstance(response_data, dict):
return response_data.get("instances", [])

return []
Comment on lines +128 to +140
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Unused

Suggested change
def list_instances(self, instance_ids: Optional[List[str]] = None) -> List[Dict]:
request_data = {
"selector": {
"ByStatus": ["Initializing", "Active", "Deactivating"],
}
}
logger.debug("Listing instances with request data: %s", request_data)
response_data = self._make_request("instances/list", request_data)
if isinstance(response_data, dict):
return response_data.get("instances", [])
return []


def get_instance_by_id(self, instance_id: str) -> Optional[Dict]:
request_data = {"selector": {"ById": [instance_id]}}
logger.debug("Getting instance with request data: %s", request_data)
response_data = self._make_request("instances/list", request_data)
if isinstance(response_data, dict):
instances = response_data.get("instances", [])
if isinstance(instances, list) and len(instances) > 0:
return instances[0]

return None

def terminate_instance(self, instance_id: str) -> bool:
request_data = {"selector": {"ById": [instance_id]}}
logger.debug("Terminating instance with request data: %s", request_data)
response_data = self._make_request("instances/terminate", request_data)
if isinstance(response_data, dict):
info = response_data.get("terminated", [])
return len(info) > 0

return False

def _make_request(
self,
endpoint: str,
data: Optional[Mapping[str, Any]] = None,
method: str = "POST",
**kwargs,
) -> Union[Mapping[str, Any], str, Response]:
headers = {}
if self.api_key is not None:
headers["X-API-Key"] = self.api_key

version = CLOUDRIFT_API_VERSION
full_url = f"{self.public_api_root}/{endpoint}"

try:
response = requests.request(
method,
full_url,
headers=headers,
json={"version": version, "data": data},
timeout=15,
**kwargs,
)

if not response.ok:
response.raise_for_status()
try:
response_json = response.json()
if isinstance(response_json, str):
return response_json
if version is not None and version < response_json["version"]:
logger.warning(
"The API version %s is lower than the server version %s. ",
version,
response_json["version"],
)
return response_json["data"]
except requests.exceptions.JSONDecodeError:
return response
except requests.HTTPError as e:
if e.response is not None and e.response.status_code in (
requests.codes.forbidden,
requests.codes.unauthorized,
):
raise BackendInvalidCredentialsError(e.response.text)
raise
16 changes: 16 additions & 0 deletions src/dstack/_internal/core/backends/cloudrift/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.cloudrift.compute import CloudRiftCompute
from dstack._internal.core.backends.cloudrift.models import CloudRiftConfig
from dstack._internal.core.models.backends.base import BackendType


class CloudRiftBackend(Backend):
TYPE = BackendType.CLOUDRIFT
COMPUTE_CLASS = CloudRiftCompute

def __init__(self, config: CloudRiftConfig):
self.config = config
self._compute = CloudRiftCompute(self.config)

def compute(self) -> CloudRiftCompute:
return self._compute
Loading