Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/backends/configurators.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@
except ImportError:
pass

try:
from dstack._internal.core.backends.hotaisle.configurator import (
HotaisleConfigurator,
)

_CONFIGURATOR_CLASSES.append(HotaisleConfigurator)
except ImportError:
pass

try:
from dstack._internal.core.backends.kubernetes.configurator import (
KubernetesConfigurator,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/hotaisle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Hotaisle backend for dstack
104 changes: 104 additions & 0 deletions src/dstack/_internal/core/backends/hotaisle/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Any, Dict, Optional

import requests

from dstack._internal.utils.logging import get_logger

API_URL = "https://admin.hotaisle.app/api"

logger = get_logger(__name__)


class HotaisleAPIClient:
def __init__(self, api_key: str, team_handle: str):
self.api_key = api_key
self.team_handle = team_handle

def validate_api_key(self) -> bool:
try:
self._validate_user_and_team()
return True
except requests.HTTPError as e:
if e.response.status_code in [401, 403]:
return False
raise e
except ValueError:
return False

def _validate_user_and_team(self) -> None:
url = f"{API_URL}/user/"
response = self._make_request("GET", url)

if response.ok:
user_data = response.json()
else:
response.raise_for_status()

teams = user_data.get("teams", [])
if not teams:
raise ValueError("No Hotaisle teams found for this user")

available_teams = [team["handle"] for team in teams]
if self.team_handle not in available_teams:
raise ValueError(f"Hotaisle Team '{self.team_handle}' not found.")

def upload_ssh_key(self, public_key: str) -> bool:
url = f"{API_URL}/user/ssh_keys/"
payload = {"authorized_key": public_key}

response = self._make_request("POST", url, json=payload)

if response.status_code == 409:
return True # Key already exists - success
if not response.ok:
response.raise_for_status()
return True

def create_virtual_machine(
self, vm_payload: Dict[str, Any], instance_name: str
) -> Dict[str, Any]:
url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/"
response = self._make_request("POST", url, json=vm_payload)

if not response.ok:
response.raise_for_status()

vm_data = response.json()
return vm_data

def get_vm_state(self, vm_name: str) -> str:
url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/state/"
response = self._make_request("GET", url)

if not response.ok:
response.raise_for_status()

state_data = response.json()
return state_data["state"]

def terminate_virtual_machine(self, vm_name: str) -> bool:
url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/"
response = self._make_request("DELETE", url)

if response.status_code == 204:
return True
else:
response.raise_for_status()

def _make_request(
self, method: str, url: str, json: Optional[Dict[str, Any]] = None, timeout: int = 30
) -> requests.Response:
headers = {
"accept": "application/json",
"Authorization": self.api_key,
}
if json is not None:
headers["Content-Type"] = "application/json"

return requests.request(
method=method,
url=url,
headers=headers,
json=json,
timeout=timeout,
)
16 changes: 16 additions & 0 deletions src/dstack/_internal/core/backends/hotaisle/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.hotaisle.compute import HotaisleCompute
from dstack._internal.core.backends.hotaisle.models import HotaisleConfig
from dstack._internal.core.models.backends.base import BackendType


class HotaisleBackend(Backend):
TYPE = BackendType.HOTAISLE
COMPUTE_CLASS = HotaisleCompute

def __init__(self, config: HotaisleConfig):
self.config = config
self._compute = HotaisleCompute(self.config)

def compute(self) -> HotaisleCompute:
return self._compute
213 changes: 213 additions & 0 deletions src/dstack/_internal/core/backends/hotaisle/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import shlex
import subprocess
import tempfile
from threading import Thread
from typing import List, Optional

import gpuhunt
from gpuhunt.providers.hotaisle import HotAisleProvider

from dstack._internal.core.backends.base.compute import (
Compute,
ComputeWithCreateInstanceSupport,
generate_unique_instance_name,
get_shim_commands,
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.hotaisle.api_client import HotaisleAPIClient
from dstack._internal.core.backends.hotaisle.models import HotaisleConfig
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
InstanceOfferWithAvailability,
)
from dstack._internal.core.models.placement import PlacementGroup
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)

MAX_INSTANCE_NAME_LEN = 60


class HotaisleCompute(
ComputeWithCreateInstanceSupport,
Compute,
):
def __init__(self, config: HotaisleConfig):
super().__init__()
self.config = config
self.api_client = HotaisleAPIClient(config.creds.api_key, config.team_handle)
self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
self.catalog.add_provider(
HotAisleProvider(api_key=config.creds.api_key, team_handle=config.team_handle)
)

def get_offers(
self, requirements: Optional[Requirements] = None
) -> List[InstanceOfferWithAvailability]:
offers = get_catalog_offers(
backend=BackendType.HOTAISLE,
locations=self.config.regions or None,
requirements=requirements,
catalog=self.catalog,
)
offers = [
InstanceOfferWithAvailability(
**offer.dict(), availability=InstanceAvailability.AVAILABLE
)
for offer in offers
]
return offers

def get_payload_from_offer(self, instance_type) -> dict:
# Only two instance types are available in Hotaisle with CPUs: 8-core and 13-core. Other fields are
# not configurable.
cpu_cores = instance_type.resources.cpus
if cpu_cores == 8:
cpu_model = "Xeon Platinum 8462Y+"
frequency = 2800000000
else: # cpu_cores == 13
cpu_model = "Xeon Platinum 8470"
frequency = 2000000000

return {
"cpu_cores": cpu_cores,
"cpus": {
"count": 1,
"manufacturer": "Intel",
"model": cpu_model,
"cores": cpu_cores,
"frequency": frequency,
},
"disk_capacity": 13194139533312,
"ram_capacity": 240518168576,
"gpus": [
{
"count": len(instance_type.resources.gpus),
"manufacturer": "AMD",
"model": "MI300X",
}
],
}

def create_instance(
self,
instance_offer: InstanceOfferWithAvailability,
instance_config: InstanceConfiguration,
placement_group: Optional[PlacementGroup],
) -> JobProvisioningData:
instance_name = generate_unique_instance_name(
instance_config, max_length=MAX_INSTANCE_NAME_LEN
)
project_ssh_key = instance_config.ssh_keys[0]
self.api_client.upload_ssh_key(project_ssh_key.public)
vm_payload = self.get_payload_from_offer(instance_offer.instance)
vm_data = self.api_client.create_virtual_machine(vm_payload, instance_name)
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=vm_data["name"],
hostname=None,
internal_ip=None,
region=instance_offer.region,
price=instance_offer.price,
username="hotaisle",
ssh_port=22,
dockerized=True,
ssh_proxy=None,
backend_data=vm_data["ip_address"],
)

def update_provisioning_data(
self,
provisioning_data: JobProvisioningData,
project_ssh_public_key: str,
project_ssh_private_key: str,
):
vm_state = self.api_client.get_vm_state(provisioning_data.instance_id)
if vm_state == "running":
if provisioning_data.hostname is None and provisioning_data.backend_data:
provisioning_data.hostname = provisioning_data.backend_data
commands = get_shim_commands(
authorized_keys=[project_ssh_public_key],
arch=provisioning_data.instance_type.resources.cpu_arch,
)
launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands))
thread = Thread(
target=_start_runner,
kwargs={
"hostname": provisioning_data.hostname,
"project_ssh_private_key": project_ssh_private_key,
"launch_command": launch_command,
},
daemon=True,
)
thread.start()

def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
):
vm_name = instance_id
self.api_client.terminate_virtual_machine(vm_name)


def _start_runner(
hostname: str,
project_ssh_private_key: str,
launch_command: str,
):
_setup_instance(
hostname=hostname,
ssh_private_key=project_ssh_private_key,
)
_launch_runner(
hostname=hostname,
ssh_private_key=project_ssh_private_key,
launch_command=launch_command,
)


def _setup_instance(
hostname: str,
ssh_private_key: str,
):
setup_commands = ("sudo apt-get update",)
_run_ssh_command(
hostname=hostname, ssh_private_key=ssh_private_key, command=" && ".join(setup_commands)
)


def _launch_runner(
hostname: str,
ssh_private_key: str,
launch_command: str,
):
daemonized_command = f"{launch_command.rstrip('&')} >/tmp/dstack-shim.log 2>&1 & disown"
_run_ssh_command(
hostname=hostname,
ssh_private_key=ssh_private_key,
command=daemonized_command,
)


def _run_ssh_command(hostname: str, ssh_private_key: str, command: str):
with tempfile.NamedTemporaryFile("w+", 0o600) as f:
f.write(ssh_private_key)
f.flush()
subprocess.run(
[
"ssh",
"-F",
"none",
"-o",
"StrictHostKeyChecking=no",
"-i",
f.name,
f"hotaisle@{hostname}",
command,
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
Loading
Loading