Skip to content

Commit aa885be

Browse files
Add Crusoe Cloud provider (#211)
* Add Crusoe Cloud provider Implement an online provider for Crusoe Cloud that fetches real-time VM availability from the capacities API, combined with instance type specs and hardcoded per-GPU/per-vCPU pricing. Key design decisions: - Online provider: the capacities API provides real-time availability (quantity > 0), so dstack only tries actually-available instances - Pricing hardcoded from https://crusoe.ai/cloud/pricing since the API does not expose prices - HMAC-SHA256 authentication per Crusoe API docs - Requires CRUSOE_ACCESS_KEY, CRUSOE_SECRET_KEY, CRUSOE_PROJECT_ID Co-authored-by: Cursor <cursoragent@cursor.com> * Include disk_gb in provider_data for Crusoe offers This allows dstack to distinguish instance types with ephemeral NVMe storage (disk_gb > 0) from types that need a separate data disk (disk_gb == 0), enabling configurable disk sizes. Co-authored-by: Cursor <cursoragent@cursor.com> * Disable spot offers until API support is confirmed The Crusoe VM create API doesn't have an obvious field for requesting spot billing. Disable spot offers until we confirm how to provision spot instances. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 3f11826 commit aa885be

4 files changed

Lines changed: 255 additions & 1 deletion

File tree

src/gpuhunt/__main__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def main():
1414
"aws",
1515
"azure",
1616
"cloudrift",
17+
"crusoe",
1718
"cudo",
1819
"verda",
1920
"digitalocean",
@@ -40,6 +41,14 @@ def main():
4041
from gpuhunt.providers.azure import AzureProvider
4142

4243
provider = AzureProvider(os.getenv("AZURE_SUBSCRIPTION_ID"))
44+
elif args.provider == "crusoe":
45+
from gpuhunt.providers.crusoe import CrusoeProvider
46+
47+
provider = CrusoeProvider(
48+
access_key=os.getenv("CRUSOE_ACCESS_KEY"),
49+
secret_key=os.getenv("CRUSOE_SECRET_KEY"),
50+
project_id=os.getenv("CRUSOE_PROJECT_ID"),
51+
)
4352
elif args.provider == "cudo":
4453
from gpuhunt.providers.cudo import CudoProvider
4554

src/gpuhunt/_internal/catalog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
"runpod",
3434
"cloudrift",
3535
]
36-
ONLINE_PROVIDERS = ["cudo", "digitalocean", "hotaisle", "vastai", "vultr"]
36+
ONLINE_PROVIDERS = ["crusoe", "cudo", "digitalocean", "hotaisle", "vastai", "vultr"]
3737
RELOAD_INTERVAL = 15 * 60 # 15 minutes
3838

3939

src/gpuhunt/_internal/default.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def default_catalog() -> Catalog:
2121
for module, provider in [
2222
("gpuhunt.providers.vastai", "VastAIProvider"),
2323
("gpuhunt.providers.cudo", "CudoProvider"),
24+
("gpuhunt.providers.crusoe", "CrusoeProvider"),
2425
("gpuhunt.providers.vultr", "VultrProvider"),
2526
("gpuhunt.providers.hotaisle", "HotAisleProvider"),
2627
("gpuhunt.providers.digitalocean", "DigitalOceanProvider"),

src/gpuhunt/providers/crusoe.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
import base64
2+
import copy
3+
import datetime
4+
import hashlib
5+
import hmac
6+
import logging
7+
import os
8+
from collections import defaultdict
9+
from typing import Optional
10+
11+
import requests
12+
13+
from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem
14+
from gpuhunt.providers import AbstractProvider
15+
16+
logger = logging.getLogger(__name__)
17+
18+
API_URL = "https://api.crusoecloud.com"
19+
API_VERSION = "/v1alpha5"
20+
SIGNATURE_VERSION = "1.0"
21+
TIMEOUT = 30
22+
23+
GPU_TYPE_MAP: dict[str, tuple[str, AcceleratorVendor, float]] = {
24+
# gpu_type -> (gpuhunt_name, vendor, vram_gb)
25+
"A100-PCIe-40GB": ("A100", AcceleratorVendor.NVIDIA, 40),
26+
"A100-PCIe-80GB": ("A100", AcceleratorVendor.NVIDIA, 80),
27+
"A100-SXM-80GB": ("A100", AcceleratorVendor.NVIDIA, 80),
28+
"H100-SXM-80GB": ("H100", AcceleratorVendor.NVIDIA, 80),
29+
"L40S-48GB": ("L40S", AcceleratorVendor.NVIDIA, 48),
30+
"A40-PCIe-48GB": ("A40", AcceleratorVendor.NVIDIA, 48),
31+
"MI300X-192GB": ("MI300X", AcceleratorVendor.AMD, 192),
32+
# TODO: The following GPUs are listed on https://crusoe.ai/cloud/pricing but not yet
33+
# returned by the instance types API. Add them once Crusoe exposes them:
34+
# - H200 141GB ($4.29/GPU-hr on-demand, spot: contact sales)
35+
# - GB200 186GB (contact sales)
36+
# - B200 180GB (contact sales)
37+
# - MI355X 288GB ($3.45 listed but not confirmed; also missing from KNOWN_AMD_GPUS)
38+
}
39+
40+
# Per-GPU-hour pricing from https://crusoe.ai/cloud/pricing
41+
GPU_PRICING: dict[str, tuple[float, Optional[float]]] = {
42+
# gpu_type -> (on_demand_per_gpu_hr, spot_per_gpu_hr or None)
43+
"A100-PCIe-40GB": (1.45, 1.00),
44+
"A100-PCIe-80GB": (1.65, 1.20),
45+
"A100-SXM-80GB": (1.95, 1.30),
46+
"H100-SXM-80GB": (3.90, 1.60),
47+
"L40S-48GB": (1.00, 0.50),
48+
"A40-PCIe-48GB": (0.90, 0.40),
49+
"MI300X-192GB": (3.45, 0.95),
50+
}
51+
52+
# Per-vCPU-hour pricing from https://crusoe.ai/cloud/pricing
53+
CPU_PRICING: dict[str, float] = {
54+
# product_name prefix -> per_vcpu_hr
55+
"c1a": 0.04,
56+
"s1a": 0.09,
57+
}
58+
59+
60+
class CrusoeProvider(AbstractProvider):
61+
NAME = "crusoe"
62+
63+
def __init__(
64+
self,
65+
access_key: Optional[str] = None,
66+
secret_key: Optional[str] = None,
67+
project_id: Optional[str] = None,
68+
):
69+
self.access_key = access_key or os.getenv("CRUSOE_ACCESS_KEY")
70+
self.secret_key = secret_key or os.getenv("CRUSOE_SECRET_KEY")
71+
self.project_id = project_id or os.getenv("CRUSOE_PROJECT_ID")
72+
73+
if not self.access_key:
74+
raise ValueError("Set the CRUSOE_ACCESS_KEY environment variable.")
75+
if not self.secret_key:
76+
raise ValueError("Set the CRUSOE_SECRET_KEY environment variable.")
77+
if not self.project_id:
78+
raise ValueError("Set the CRUSOE_PROJECT_ID environment variable.")
79+
80+
def get(
81+
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
82+
) -> list[RawCatalogItem]:
83+
instance_types = self._get_instance_types()
84+
type_specs = {t["product_name"]: t for t in instance_types}
85+
86+
# Note: capacities reflect hardware availability, not project quotas.
87+
# Quota enforcement should be done on the dstack side via
88+
# GET /projects/{project_id}/quotas, which returns per-instance-family
89+
# quota with max/used/available fields.
90+
capacities = self._get_capacities()
91+
available = _get_available_type_locations(capacities)
92+
93+
offers = []
94+
for product_name, locations in available.items():
95+
spec = type_specs.get(product_name)
96+
if spec is None:
97+
logger.warning("Capacity for unknown instance type %s, skipping", product_name)
98+
continue
99+
100+
items = _make_catalog_items(product_name, spec, locations)
101+
offers.extend(items)
102+
103+
return sorted(offers, key=lambda i: i.price)
104+
105+
def _get_instance_types(self) -> list[dict]:
106+
resp = self._request("GET", f"/projects/{self.project_id}/compute/vms/types")
107+
resp.raise_for_status()
108+
return resp.json()["items"]
109+
110+
def _get_capacities(self) -> list[dict]:
111+
resp = self._request("GET", "/capacities")
112+
resp.raise_for_status()
113+
return resp.json()["items"]
114+
115+
def _request(self, method: str, path: str, params: Optional[dict] = None) -> requests.Response:
116+
dt = str(datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0))
117+
dt = dt.replace(" ", "T")
118+
119+
query_string = ""
120+
if params:
121+
query_string = "&".join(f"{k}={v}" for k, v in sorted(params.items()))
122+
123+
payload = f"{API_VERSION}{path}\n{query_string}\n{method}\n{dt}\n"
124+
125+
decoded_secret = base64.urlsafe_b64decode(
126+
self.secret_key + "=" * (-len(self.secret_key) % 4)
127+
)
128+
sig = hmac.new(decoded_secret, msg=payload.encode("ascii"), digestmod=hashlib.sha256)
129+
encoded_sig = base64.urlsafe_b64encode(sig.digest()).decode("ascii").rstrip("=")
130+
131+
headers = {
132+
"X-Crusoe-Timestamp": dt,
133+
"Authorization": f"Bearer {SIGNATURE_VERSION}:{self.access_key}:{encoded_sig}",
134+
}
135+
136+
url = f"{API_URL}{API_VERSION}{path}"
137+
return requests.request(method, url, headers=headers, params=params, timeout=TIMEOUT)
138+
139+
140+
def _get_available_type_locations(capacities: list[dict]) -> dict[str, list[str]]:
141+
best_qty: dict[tuple[str, str], int] = defaultdict(int)
142+
for cap in capacities:
143+
key = (cap["type"], cap["location"])
144+
best_qty[key] = max(best_qty[key], cap["quantity"])
145+
146+
result: dict[str, list[str]] = defaultdict(list)
147+
for (instance_type, location), qty in best_qty.items():
148+
if qty > 0:
149+
result[instance_type].append(location)
150+
return dict(result)
151+
152+
153+
def _make_catalog_items(
154+
product_name: str, spec: dict, locations: list[str]
155+
) -> list[RawCatalogItem]:
156+
gpu_type = spec.get("gpu_type", "")
157+
num_gpu = spec.get("num_gpu", 0)
158+
159+
if num_gpu > 0 and gpu_type:
160+
return _make_gpu_items(product_name, spec, gpu_type, locations)
161+
else:
162+
return _make_cpu_items(product_name, spec, locations)
163+
164+
165+
def _make_gpu_items(
166+
product_name: str, spec: dict, gpu_type: str, locations: list[str]
167+
) -> list[RawCatalogItem]:
168+
gpu_info = GPU_TYPE_MAP.get(gpu_type)
169+
if gpu_info is None:
170+
logger.warning("Unknown GPU type %s for %s, skipping", gpu_type, product_name)
171+
return []
172+
173+
pricing = GPU_PRICING.get(gpu_type)
174+
if pricing is None:
175+
logger.warning("No pricing for GPU type %s (%s), skipping", gpu_type, product_name)
176+
return []
177+
178+
gpu_name, gpu_vendor, gpu_memory = gpu_info
179+
on_demand_per_gpu, spot_per_gpu = pricing
180+
num_gpu = spec["num_gpu"]
181+
182+
template = RawCatalogItem(
183+
instance_name=product_name,
184+
location=None,
185+
price=None,
186+
cpu=spec["cpu_cores"],
187+
memory=float(spec["memory_gb"]),
188+
gpu_vendor=gpu_vendor.value,
189+
gpu_count=num_gpu,
190+
gpu_name=gpu_name,
191+
gpu_memory=gpu_memory,
192+
spot=None,
193+
disk_size=float(spec["disk_gb"]) if spec.get("disk_gb") else None,
194+
# disk_gb: ephemeral NVMe size in GB (0 = no ephemeral disk).
195+
# Used by dstack to decide whether to create a persistent data disk.
196+
provider_data={"disk_gb": spec.get("disk_gb", 0)},
197+
)
198+
199+
items = []
200+
for location in locations:
201+
on_demand = copy.deepcopy(template)
202+
on_demand.location = location
203+
on_demand.spot = False
204+
on_demand.price = round(num_gpu * on_demand_per_gpu, 2)
205+
items.append(on_demand)
206+
207+
# TODO: Enable spot offers once we confirm how to request spot billing
208+
# via the VM create API (POST /v1alpha5/projects/{pid}/compute/vms/instances).
209+
# The API schema doesn't have an obvious spot/billing_type field.
210+
211+
return items
212+
213+
214+
def _make_cpu_items(product_name: str, spec: dict, locations: list[str]) -> list[RawCatalogItem]:
215+
prefix = product_name.split(".")[0]
216+
per_vcpu = CPU_PRICING.get(prefix)
217+
if per_vcpu is None:
218+
logger.warning("No pricing for CPU prefix %s (%s), skipping", prefix, product_name)
219+
return []
220+
221+
cpu_cores = spec["cpu_cores"]
222+
template = RawCatalogItem(
223+
instance_name=product_name,
224+
location=None,
225+
price=None,
226+
cpu=cpu_cores,
227+
memory=float(spec["memory_gb"]),
228+
gpu_vendor=None,
229+
gpu_count=0,
230+
gpu_name=None,
231+
gpu_memory=None,
232+
spot=False,
233+
disk_size=float(spec["disk_gb"]) if spec.get("disk_gb") else None,
234+
provider_data={"disk_gb": spec.get("disk_gb", 0)},
235+
)
236+
237+
items = []
238+
for location in locations:
239+
item = copy.deepcopy(template)
240+
item.location = location
241+
item.price = round(cpu_cores * per_vcpu, 2)
242+
items.append(item)
243+
244+
return items

0 commit comments

Comments
 (0)