Skip to content

Commit f5f2f2e

Browse files
Bihan Rana
authored and committed
add hotaisle backend
1 parent b5f26c8 commit f5f2f2e

File tree

9 files changed

+459
-0
lines changed

9 files changed

+459
-0
lines changed

src/dstack/_internal/core/backends/configurators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@
5454
except ImportError:
5555
pass
5656

57+
# Register the Hotaisle configurator only when its module can be imported;
# an ImportError is ignored and the backend is simply left unregistered.
try:
    from dstack._internal.core.backends.hotaisle.configurator import (
        HotaisleConfigurator,
    )
except ImportError:
    pass
else:
    _CONFIGURATOR_CLASSES.append(HotaisleConfigurator)
65+
5766
try:
5867
from dstack._internal.core.backends.kubernetes.configurator import (
5968
KubernetesConfigurator,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Hotaisle backend for dstack
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from typing import Any, Dict, Optional
2+
3+
import requests
4+
5+
from dstack._internal.utils.logging import get_logger
6+
7+
API_URL = "https://admin.hotaisle.app/api"
8+
9+
logger = get_logger(__name__)
10+
11+
12+
class HotaisleAPIClient:
    """Small client for the Hotaisle admin REST API.

    Every request carries the API key in the ``Authorization`` header.
    Team-scoped endpoints are addressed using ``team_handle``.
    """

    def __init__(self, api_key: str, team_handle: str):
        self.api_key = api_key
        self.team_handle = team_handle

    def validate_api_key(self) -> bool:
        """Check that the API key works and that the user belongs to the team.

        Returns:
            True when the user endpoint answers successfully and lists
            ``team_handle``; False when the API answers 401/403 or the
            team check fails with ``ValueError``.

        Raises:
            requests.HTTPError: For any other HTTP error status.
        """
        try:
            self._validate_user_and_team()
        except requests.HTTPError as e:
            # Only authentication/authorization failures mean "invalid key";
            # everything else (5xx, 404, ...) is propagated unchanged.
            if e.response is not None and e.response.status_code in (401, 403):
                return False
            raise
        except ValueError:
            return False
        return True

    def _validate_user_and_team(self) -> None:
        """Fetch the current user and verify ``team_handle`` is among its teams.

        Raises:
            requests.HTTPError: If the user endpoint returns an error status.
            ValueError: If the user has no teams or ``team_handle`` is not
                one of them.
        """
        response = self._make_request("GET", f"{API_URL}/user/")
        response.raise_for_status()
        user_data = response.json()

        teams = user_data.get("teams", [])
        if not teams:
            raise ValueError("No Hotaisle teams found for this user")

        # `.get` keeps a team entry without a "handle" key from raising KeyError;
        # such an entry simply never matches.
        available_teams = {team.get("handle") for team in teams}
        if self.team_handle not in available_teams:
            raise ValueError(f"Hotaisle Team '{self.team_handle}' not found.")

    def upload_ssh_key(self, public_key: str) -> bool:
        """Register ``public_key`` as an authorized SSH key of the current user.

        A 409 response (key already exists) is treated as success.

        Returns:
            True when the key was accepted or already existed.

        Raises:
            requests.HTTPError: For any other HTTP error status.
        """
        response = self._make_request(
            "POST",
            f"{API_URL}/user/ssh_keys/",
            json={"authorized_key": public_key},
        )
        if response.status_code != 409:
            response.raise_for_status()
        return True

    def create_virtual_machine(
        self, vm_payload: Dict[str, Any], instance_name: str
    ) -> Dict[str, Any]:
        """Create a virtual machine in the configured team.

        Args:
            vm_payload: JSON body sent to the API as-is.
            instance_name: Accepted for interface compatibility; it is not
                sent to the API by this client.

        Returns:
            The decoded JSON body of the API response.

        Raises:
            requests.HTTPError: If the API returns an error status.
        """
        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/"
        response = self._make_request("POST", url, json=vm_payload)
        response.raise_for_status()
        return response.json()

    def get_vm_state(self, vm_name: str) -> str:
        """Return the ``state`` field reported by the API for ``vm_name``.

        Raises:
            requests.HTTPError: If the API returns an error status.
        """
        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/state/"
        response = self._make_request("GET", url)
        response.raise_for_status()
        return response.json()["state"]

    def terminate_virtual_machine(self, vm_name: str) -> bool:
        """Delete the virtual machine ``vm_name``.

        Returns:
            True for any successful response. Previously a successful status
            other than 204 fell through and returned ``None`` despite the
            ``bool`` return annotation.

        Raises:
            requests.HTTPError: If the API returns an error status.
        """
        url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/"
        response = self._make_request("DELETE", url)
        response.raise_for_status()
        return True

    def _make_request(
        self, method: str, url: str, json: Optional[Dict[str, Any]] = None, timeout: int = 30
    ) -> requests.Response:
        """Send an authenticated request and return the raw response.

        The response status is not checked here; callers decide how to
        handle error statuses.

        Args:
            method: HTTP method name.
            url: Absolute URL to call.
            json: Optional JSON body; when given, a JSON ``Content-Type``
                header is added.
            timeout: Timeout passed to ``requests.request``.
        """
        headers = {
            "accept": "application/json",
            "Authorization": self.api_key,
        }
        if json is not None:
            headers["Content-Type"] = "application/json"

        return requests.request(
            method=method,
            url=url,
            headers=headers,
            json=json,
            timeout=timeout,
        )
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from dstack._internal.core.backends.base.backend import Backend
2+
from dstack._internal.core.backends.hotaisle.compute import HotaisleCompute
3+
from dstack._internal.core.backends.hotaisle.models import HotaisleConfig
4+
from dstack._internal.core.models.backends.base import BackendType
5+
6+
7+
class HotaisleBackend(Backend):
    """Backend entry point that exposes the Hotaisle compute implementation."""

    TYPE = BackendType.HOTAISLE
    COMPUTE_CLASS = HotaisleCompute

    def __init__(self, config: HotaisleConfig):
        """Store ``config`` and build the compute object from it."""
        self.config = config
        self._compute = HotaisleCompute(config)

    def compute(self) -> HotaisleCompute:
        """Return the compute object created in ``__init__``."""
        return self._compute
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import shlex
2+
import subprocess
3+
import tempfile
4+
from threading import Thread
5+
from typing import List, Optional
6+
7+
import gpuhunt
8+
from gpuhunt.providers.hotaisle import HotAisleProvider
9+
10+
from dstack._internal.core.backends.base.compute import (
11+
Compute,
12+
ComputeWithCreateInstanceSupport,
13+
generate_unique_instance_name,
14+
get_shim_commands,
15+
)
16+
from dstack._internal.core.backends.base.offers import get_catalog_offers
17+
from dstack._internal.core.backends.hotaisle.api_client import HotaisleAPIClient
18+
from dstack._internal.core.backends.hotaisle.models import HotaisleConfig
19+
from dstack._internal.core.models.backends.base import BackendType
20+
from dstack._internal.core.models.instances import (
21+
InstanceAvailability,
22+
InstanceConfiguration,
23+
InstanceOfferWithAvailability,
24+
)
25+
from dstack._internal.core.models.placement import PlacementGroup
26+
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
27+
from dstack._internal.utils.logging import get_logger
28+
29+
logger = get_logger(__name__)
30+
31+
MAX_INSTANCE_NAME_LEN = 60
32+
33+
34+
class HotaisleCompute(
    ComputeWithCreateInstanceSupport,
    Compute,
):
    """Compute implementation that creates and manages Hotaisle virtual machines."""

    def __init__(self, config: HotaisleConfig):
        super().__init__()
        self.config = config
        self.api_client = HotaisleAPIClient(config.creds.api_key, config.team_handle)
        # Offers come from a gpuhunt catalog backed by the Hotaisle provider,
        # which is given the same credentials as the API client.
        self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
        self.catalog.add_provider(
            HotAisleProvider(api_key=config.creds.api_key, team_handle=config.team_handle)
        )

    def get_offers(
        self, requirements: Optional[Requirements] = None
    ) -> List[InstanceOfferWithAvailability]:
        """Return catalog offers matching ``requirements``, all marked AVAILABLE.

        Offers are restricted to ``config.regions`` when that is set.
        """
        catalog_offers = get_catalog_offers(
            backend=BackendType.HOTAISLE,
            locations=self.config.regions or None,
            requirements=requirements,
            catalog=self.catalog,
        )
        return [
            InstanceOfferWithAvailability(
                **offer.dict(), availability=InstanceAvailability.AVAILABLE
            )
            for offer in catalog_offers
        ]

    def get_payload_from_offer(self, instance_type) -> dict:
        """Build the VM creation payload for ``instance_type``.

        The CPU model and frequency are chosen from the CPU core count, and
        the GPU count is taken from the number of GPUs in the offer. All
        other fields are fixed values.
        """
        # Only two instance types are available in Hotaisle with CPUs: 8-core and 13-core. Other fields are
        # not configurable.
        cpu_cores = instance_type.resources.cpus
        if cpu_cores == 8:
            cpu_model = "Xeon Platinum 8462Y+"
            frequency = 2800000000
        else:  # cpu_cores == 13
            cpu_model = "Xeon Platinum 8470"
            frequency = 2000000000

        return {
            "cpu_cores": cpu_cores,
            "cpus": {
                "count": 1,
                "manufacturer": "Intel",
                "model": cpu_model,
                "cores": cpu_cores,
                "frequency": frequency,
            },
            "disk_capacity": 13194139533312,
            "ram_capacity": 240518168576,
            "gpus": [
                {
                    "count": len(instance_type.resources.gpus),
                    "manufacturer": "AMD",
                    "model": "MI300X",
                }
            ],
        }

    def create_instance(
        self,
        instance_offer: InstanceOfferWithAvailability,
        instance_config: InstanceConfiguration,
        placement_group: Optional[PlacementGroup],
    ) -> JobProvisioningData:
        """Upload the project SSH key and create a VM for ``instance_offer``.

        The returned provisioning data has no hostname yet. The VM's
        ``ip_address`` is kept in ``backend_data`` and is promoted to the
        hostname by ``update_provisioning_data`` once the VM is running.
        """
        instance_name = generate_unique_instance_name(
            instance_config, max_length=MAX_INSTANCE_NAME_LEN
        )
        # The first SSH key of the instance configuration is used as the project key.
        project_ssh_key = instance_config.ssh_keys[0]
        self.api_client.upload_ssh_key(project_ssh_key.public)
        vm_payload = self.get_payload_from_offer(instance_offer.instance)
        vm_data = self.api_client.create_virtual_machine(vm_payload, instance_name)
        return JobProvisioningData(
            backend=instance_offer.backend,
            instance_type=instance_offer.instance,
            instance_id=vm_data["name"],
            hostname=None,
            internal_ip=None,
            region=instance_offer.region,
            price=instance_offer.price,
            username="hotaisle",
            ssh_port=22,
            dockerized=True,
            ssh_proxy=None,
            backend_data=vm_data["ip_address"],
        )

    def update_provisioning_data(
        self,
        provisioning_data: JobProvisioningData,
        project_ssh_public_key: str,
        project_ssh_private_key: str,
    ):
        """Set the hostname once the VM is running and start the runner on it.

        Nothing happens until the API reports the VM as "running". The
        runner is launched in a background daemon thread, and only on the
        call that first resolves the hostname. Later calls, and calls that
        have no address in ``backend_data``, do not start another SSH session.
        """
        vm_state = self.api_client.get_vm_state(provisioning_data.instance_id)
        if vm_state != "running":
            return
        if provisioning_data.hostname is not None or not provisioning_data.backend_data:
            # Either the runner was already launched for this instance, or
            # there is no address to connect to yet.
            return

        provisioning_data.hostname = provisioning_data.backend_data
        commands = get_shim_commands(
            authorized_keys=[project_ssh_public_key],
            arch=provisioning_data.instance_type.resources.cpu_arch,
        )
        launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands))
        thread = Thread(
            target=_start_runner,
            kwargs={
                "hostname": provisioning_data.hostname,
                "project_ssh_private_key": project_ssh_private_key,
                "launch_command": launch_command,
            },
            daemon=True,
        )
        thread.start()

    def terminate_instance(
        self, instance_id: str, region: str, backend_data: Optional[str] = None
    ):
        """Delete the VM whose name is ``instance_id``.

        ``region`` and ``backend_data`` are not used.
        """
        self.api_client.terminate_virtual_machine(instance_id)
154+
155+
156+
def _start_runner(
    hostname: str,
    project_ssh_private_key: str,
    launch_command: str,
):
    """Run the instance setup over SSH, then run ``launch_command`` on the host."""
    ssh_target = {"hostname": hostname, "ssh_private_key": project_ssh_private_key}
    _setup_instance(**ssh_target)
    _launch_runner(launch_command=launch_command, **ssh_target)
170+
171+
172+
def _setup_instance(
    hostname: str,
    ssh_private_key: str,
):
    """Run the setup commands on ``hostname`` in a single SSH session."""
    setup_commands = ("sudo apt-get update",)
    # Chain with "&&" so a failing command stops the ones after it.
    setup_script = " && ".join(setup_commands)
    _run_ssh_command(
        hostname=hostname,
        ssh_private_key=ssh_private_key,
        command=setup_script,
    )
180+
181+
182+
def _launch_runner(
    hostname: str,
    ssh_private_key: str,
    launch_command: str,
):
    """Execute ``launch_command`` on ``hostname`` over SSH."""
    ssh_call = {
        "hostname": hostname,
        "ssh_private_key": ssh_private_key,
        "command": launch_command,
    }
    _run_ssh_command(**ssh_call)
192+
193+
194+
def _run_ssh_command(hostname: str, ssh_private_key: str, command: str):
    """Run ``command`` on ``hostname`` as the ``hotaisle`` user via the ``ssh`` CLI.

    The private key is written to a temporary file that exists only for the
    duration of the call. The remote command's output is discarded. This is
    best-effort: a non-zero exit status is logged and never raised.

    Args:
        hostname: Host name or address to connect to.
        ssh_private_key: Private key contents used for authentication.
        command: Command line to execute on the remote host.
    """
    # The second positional parameter of NamedTemporaryFile is `buffering`, not
    # a permission mode, so the former `0o600` argument only set a buffer size.
    # The file is already created with owner-only access (same rules as
    # mkstemp), so no extra argument is needed.
    with tempfile.NamedTemporaryFile("w+") as key_file:
        key_file.write(ssh_private_key)
        # Make sure the key is on disk before ssh reads it.
        key_file.flush()
        result = subprocess.run(
            [
                "ssh",
                "-F",
                "none",
                "-o",
                "StrictHostKeyChecking=no",
                "-i",
                key_file.name,
                f"hotaisle@{hostname}",
                command,
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    if result.returncode != 0:
        # Output is discarded, so the exit status is the only failure signal.
        logger.warning(
            "SSH command on %s exited with status %s", hostname, result.returncode
        )

0 commit comments

Comments
 (0)