diff --git a/src/gpuhunt/providers/nebius.py b/src/gpuhunt/providers/nebius.py index 38be2ef..288ce04 100644 --- a/src/gpuhunt/providers/nebius.py +++ b/src/gpuhunt/providers/nebius.py @@ -1,6 +1,7 @@ import logging import re -from typing import Optional +from dataclasses import dataclass +from typing import Optional, cast from nebius.aio.channel import Credentials from nebius.api.nebius.billing.v1alpha1 import ( @@ -26,11 +27,13 @@ TenantServiceClient, ) from nebius.sdk import SDK +from typing_extensions import TypedDict from gpuhunt._internal.constraints import find_accelerators from gpuhunt._internal.models import ( AcceleratorInfo, AcceleratorVendor, + JSONObject, QueryFilter, RawCatalogItem, ) @@ -40,6 +43,26 @@ TIMEOUT = 7 +@dataclass(frozen=True) +class InfinibandFabric: + name: str + platform: str + region: str + + +# https://docs.nebius.com/compute/clusters/gpu#fabrics +INFINIBAND_FABRICS = [ + InfinibandFabric("fabric-2", "gpu-h100-sxm", "eu-north1"), + InfinibandFabric("fabric-3", "gpu-h100-sxm", "eu-north1"), + InfinibandFabric("fabric-4", "gpu-h100-sxm", "eu-north1"), + InfinibandFabric("fabric-5", "gpu-h200-sxm", "eu-west1"), + InfinibandFabric("fabric-6", "gpu-h100-sxm", "eu-north1"), + InfinibandFabric("fabric-7", "gpu-h200-sxm", "eu-north1"), + InfinibandFabric("us-central1-a", "gpu-h200-sxm", "us-central1"), + InfinibandFabric("us-central1-b", "gpu-b200-sxm", "us-central1"), +] + + class NebiusProvider(AbstractProvider): NAME = "nebius" @@ -77,6 +100,10 @@ def get( return items +class NebiusCatalogItemProviderData(TypedDict): + fabrics: list[str] + + def get_sample_projects(sdk: SDK) -> dict[str, str]: """ Returns: @@ -141,6 +168,12 @@ def make_item( spot: bool, price: float, ) -> Optional[RawCatalogItem]: + fabrics = [] + if preset.allow_gpu_clustering: + fabrics = [ + f.name for f in INFINIBAND_FABRICS if f.platform == platform and f.region == region + ] + item = RawCatalogItem( instance_name=f"{platform} {preset.name}", location=region, @@ -153,6 +186,7 @@ def make_item( gpu_vendor=None, spot=spot, disk_size=None, + provider_data=cast(JSONObject, NebiusCatalogItemProviderData(fabrics=fabrics)), ) if preset.resources.gpu_count: diff --git a/src/integrity_tests/test_nebius.py b/src/integrity_tests/test_nebius.py index 308fd14..86f504c 100644 --- a/src/integrity_tests/test_nebius.py +++ b/src/integrity_tests/test_nebius.py @@ -1,4 +1,5 @@ import csv +import json from operator import itemgetter from pathlib import Path @@ -28,3 +29,41 @@ def test_spots_presented(data_rows: list[dict]): @pytest.mark.parametrize("location", ["eu-north1", "eu-west1"]) def test_location_present(location: str, data_rows: list[dict]): assert location in map(itemgetter("location"), data_rows) + + +def test_fabrics_unique(data_rows: list[dict]) -> None: + for row in data_rows: + fabrics = json.loads(row["provider_data"])["fabrics"] + assert len(fabrics) == len(set(fabrics)), f"Duplicate fabrics in row: {row}" + + +def test_fabrics_on_sample_offer(data_rows: list[dict]) -> None: + for row in data_rows: + if ( + row["instance_name"] == "gpu-h100-sxm 8gpu-128vcpu-1600gb" + and row["location"] == "eu-north1" + ): + break + else: + raise ValueError("Offer not found") + fabrics = set(json.loads(row["provider_data"])["fabrics"]) + expected_fabrics = { + "fabric-2", + "fabric-3", + "fabric-4", + "fabric-6", + } + missing_fabrics = expected_fabrics - fabrics + assert not missing_fabrics + + +def test_no_fabrics_on_sample_non_clustered_offer(data_rows: list[dict]) -> None: + for row in data_rows: + if ( + row["instance_name"] == "gpu-h100-sxm 1gpu-16vcpu-200gb" + and row["location"] == "eu-north1" + ): + break + else: + raise ValueError("Offer not found") + assert json.loads(row["provider_data"])["fabrics"] == []