diff --git a/src/gpuhunt/_internal/constraints.py b/src/gpuhunt/_internal/constraints.py index ffd06a3..876f0be 100644 --- a/src/gpuhunt/_internal/constraints.py +++ b/src/gpuhunt/_internal/constraints.py @@ -254,6 +254,12 @@ def is_nvidia_superchip(gpu_name: str) -> bool: architecture=AMDArchitecture.CDNA3, device_ids=(0x74A5,), ), + AMDGPUInfo( + name="MI350X", + memory=288, + architecture=AMDArchitecture.CDNA4, + device_ids=(0x75A0,), + ), AMDGPUInfo( name="MI355X", memory=288, diff --git a/src/gpuhunt/providers/cloudrift.py b/src/gpuhunt/providers/cloudrift.py index 365481c..175c95b 100644 --- a/src/gpuhunt/providers/cloudrift.py +++ b/src/gpuhunt/providers/cloudrift.py @@ -5,6 +5,7 @@ import requests from gpuhunt import QueryFilter, RawCatalogItem +from gpuhunt._internal.models import AcceleratorVendor from gpuhunt.providers import AbstractProvider logger = logging.getLogger(__name__) @@ -33,13 +34,14 @@ def _get_instance_types(self): def generate_instances(instance) -> list[RawCatalogItem]: instance_gpu_brand = instance["brand_short"] - dstack_gpu_name = next( - iter(gpu_record[1] for gpu_record in GPU_MAP if gpu_record[0] in instance_gpu_brand), None + gpu_info = next( + (gpu_record for gpu_record in GPU_MAP if gpu_record[0] in instance_gpu_brand), None ) - if dstack_gpu_name is None: + if gpu_info is None: logger.warning(f"Failed to find GPU name matching '{instance_gpu_brand}'") return [] + _, dstack_gpu_name, gpu_vendor = gpu_info instance_types = [] for variant in instance["variants"]: for location, _count in variant["nodes_per_dc"].items(): @@ -54,6 +56,7 @@ def generate_instances(instance) -> list[RawCatalogItem]: gpu_count=variant["gpu_count"], gpu_name=dstack_gpu_name, gpu_memory=round(variant["vram"] / 1024**3), + gpu_vendor=gpu_vendor, ) instance_types.append(raw) @@ -61,9 +64,10 @@ def generate_instances(instance) -> list[RawCatalogItem]: GPU_MAP = [ - ("RTX 4090", "RTX4090"), - ("RTX 5090", "RTX5090"), - ("RTX PRO 6000", "RTXPRO6000"), + ("MI350X", "MI350X", AcceleratorVendor.AMD), + ("RTX 4090", "RTX4090", AcceleratorVendor.NVIDIA), + ("RTX 5090", "RTX5090", AcceleratorVendor.NVIDIA), + ("RTX PRO 6000", "RTXPRO6000", AcceleratorVendor.NVIDIA), ]