22import subprocess
33import tempfile
44from threading import Thread
5- from typing import List , Optional
5+ from typing import Any , List , Optional
66
77import gpuhunt
88from gpuhunt .providers .hotaisle import HotAisleProvider
2222from dstack ._internal .core .models .instances import (
2323 InstanceAvailability ,
2424 InstanceConfiguration ,
25+ InstanceOffer ,
2526 InstanceOfferWithAvailability ,
2627)
2728from dstack ._internal .core .models .placement import PlacementGroup
3132logger = get_logger (__name__ )
3233
3334
34- INSTANCE_TYPE_SPECS = {
35- "1x MI300X 8x Xeon Platinum 8462Y+" : {
36- "cpu_model" : "Xeon Platinum 8462Y+" ,
37- "cpu_frequency" : 2800000000 ,
38- "cpu_manufacturer" : "Intel" ,
39- },
40- "1x MI300X 13x Xeon Platinum 8470" : {
41- "cpu_model" : "Xeon Platinum 8470" ,
42- "cpu_frequency" : 2000000000 ,
43- "cpu_manufacturer" : "Intel" ,
44- },
45- "2x MI300X 26x Xeon Platinum 8470" : {
46- "cpu_model" : "Xeon Platinum 8470" ,
47- "cpu_frequency" : 2000000000 ,
48- "cpu_manufacturer" : "Intel" ,
49- },
50- "2x MI300X 26x Xeon Platinum 8462Y+" : {
51- "cpu_model" : "Xeon Platinum 8462Y+" ,
52- "cpu_frequency" : 2800000000 ,
53- "cpu_manufacturer" : "Intel" ,
54- },
55- "4x MI300X 52x Xeon Platinum 8470" : {
56- "cpu_model" : "Xeon Platinum 8470" ,
57- "cpu_frequency" : 2000000000 ,
58- "cpu_manufacturer" : "Intel" ,
59- },
60- "4x MI300X 52x Xeon Platinum 8462Y+" : {
61- "cpu_model" : "Xeon Platinum 8462Y+" ,
62- "cpu_frequency" : 2800000000 ,
63- "cpu_manufacturer" : "Intel" ,
64- },
65- "8x MI300X 104x Xeon Platinum 8470" : {
66- "cpu_model" : "Xeon Platinum 8470" ,
67- "cpu_frequency" : 2000000000 ,
68- "cpu_manufacturer" : "Intel" ,
69- },
70- "8x MI300X 104x Xeon Platinum 8462Y+" : {
71- "cpu_model" : "Xeon Platinum 8462Y+" ,
72- "cpu_frequency" : 2800000000 ,
73- "cpu_manufacturer" : "Intel" ,
74- },
75- }
35+ SUPPORTED_GPUS = ["MI300X" ]
7636
7737
7838class HotAisleCompute (
@@ -95,45 +55,15 @@ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability
9555 backend = BackendType .HOTAISLE ,
9656 locations = self .config .regions or None ,
9757 catalog = self .catalog ,
58+ extra_filter = _supported_instances ,
9859 )
99- supported_offers = []
100- for offer in offers :
101- if offer .instance .name in INSTANCE_TYPE_SPECS :
102- supported_offers .append (
103- InstanceOfferWithAvailability (
104- ** offer .dict (), availability = InstanceAvailability .AVAILABLE
105- )
106- )
107- else :
108- logger .warning (
109- f"Skipping unsupported Hot Aisle instance type: { offer .instance .name } "
110- )
111- return supported_offers
112-
113- def get_payload_from_offer (self , instance_type ) -> dict :
114- instance_type_name = instance_type .name
115- cpu_specs = INSTANCE_TYPE_SPECS [instance_type_name ]
116- cpu_cores = instance_type .resources .cpus
117-
118- return {
119- "cpu_cores" : cpu_cores ,
120- "cpus" : {
121- "count" : 1 ,
122- "manufacturer" : cpu_specs ["cpu_manufacturer" ],
123- "model" : cpu_specs ["cpu_model" ],
124- "cores" : cpu_cores ,
125- "frequency" : cpu_specs ["cpu_frequency" ],
126- },
127- "disk_capacity" : instance_type .resources .disk .size_mib * 1024 ** 2 ,
128- "ram_capacity" : instance_type .resources .memory_mib * 1024 ** 2 ,
129- "gpus" : [
130- {
131- "count" : len (instance_type .resources .gpus ),
132- "manufacturer" : instance_type .resources .gpus [0 ].vendor ,
133- "model" : instance_type .resources .gpus [0 ].name ,
134- }
135- ],
136- }
60+ return [
61+ InstanceOfferWithAvailability (
62+ ** offer .dict (),
63+ availability = InstanceAvailability .AVAILABLE ,
64+ )
65+ for offer in offers
66+ ]
13767
13868 def create_instance (
13969 self ,
@@ -143,8 +73,10 @@ def create_instance(
14373 ) -> JobProvisioningData :
14474 project_ssh_key = instance_config .ssh_keys [0 ]
14575 self .api_client .upload_ssh_key (project_ssh_key .public )
146- vm_payload = self .get_payload_from_offer (instance_offer .instance )
147- vm_data = self .api_client .create_virtual_machine (vm_payload )
76+ offer_backend_data : HotAisleOfferBackendData = (
77+ HotAisleOfferBackendData .__response__ .parse_obj (instance_offer .backend_data )
78+ )
79+ vm_data = self .api_client .create_virtual_machine (offer_backend_data .vm_specs )
14880 return JobProvisioningData (
14981 backend = instance_offer .backend ,
15082 instance_type = instance_offer .instance ,
@@ -240,10 +172,20 @@ def _run_ssh_command(hostname: str, ssh_private_key: str, command: str):
240172 )
241173
242174
175+ def _supported_instances (offer : InstanceOffer ) -> bool :
176+ return len (offer .instance .resources .gpus ) > 0 and all (
177+ gpu .name in SUPPORTED_GPUS for gpu in offer .instance .resources .gpus
178+ )
179+
180+
243181class HotAisleInstanceBackendData (CoreModel ):
244182 ip_address : str
245183
246184 @classmethod
247185 def load (cls , raw : Optional [str ]) -> "HotAisleInstanceBackendData" :
248186 assert raw is not None
249187 return cls .__response__ .parse_raw (raw )
188+
189+
190+ class HotAisleOfferBackendData (CoreModel ):
191+ vm_specs : dict [str , Any ]
0 commit comments