From 7e3bb09d40a825ff373ecbbc7cdfbf5738828cdb Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 28 Oct 2025 17:25:14 +0100 Subject: [PATCH] [Nebius]: Add fabrics list to provider_data Associate each Nebius catalog item with a list of InfiniBand fabrics that it supports. The list of fabrics and their details is hardcoded until Nebius exposes it in the API. Previously, the list of fabrics was hardcoded in dstack. Moving it to gpuhunt allows to add new fabrics without a dstack release. --- src/gpuhunt/providers/nebius.py | 36 ++++++++++++++++++++++++++- src/integrity_tests/test_nebius.py | 39 ++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/src/gpuhunt/providers/nebius.py b/src/gpuhunt/providers/nebius.py index 38be2ef..288ce04 100644 --- a/src/gpuhunt/providers/nebius.py +++ b/src/gpuhunt/providers/nebius.py @@ -1,6 +1,7 @@ import logging import re -from typing import Optional +from dataclasses import dataclass +from typing import Optional, cast from nebius.aio.channel import Credentials from nebius.api.nebius.billing.v1alpha1 import ( @@ -26,11 +27,13 @@ TenantServiceClient, ) from nebius.sdk import SDK +from typing_extensions import TypedDict from gpuhunt._internal.constraints import find_accelerators from gpuhunt._internal.models import ( AcceleratorInfo, AcceleratorVendor, + JSONObject, QueryFilter, RawCatalogItem, ) @@ -40,6 +43,26 @@ TIMEOUT = 7 +@dataclass(frozen=True) +class InfinibandFabric: + name: str + platform: str + region: str + + +# https://docs.nebius.com/compute/clusters/gpu#fabrics +INFINIBAND_FABRICS = [ + InfinibandFabric("fabric-2", "gpu-h100-sxm", "eu-north1"), + InfinibandFabric("fabric-3", "gpu-h100-sxm", "eu-north1"), + InfinibandFabric("fabric-4", "gpu-h100-sxm", "eu-north1"), + InfinibandFabric("fabric-5", "gpu-h200-sxm", "eu-west1"), + InfinibandFabric("fabric-6", "gpu-h100-sxm", "eu-north1"), + InfinibandFabric("fabric-7", "gpu-h200-sxm", "eu-north1"), + InfinibandFabric("us-central1-a", "gpu-h200-sxm", "us-central1"), + InfinibandFabric("us-central1-b", "gpu-b200-sxm", "us-central1"), +] + + class NebiusProvider(AbstractProvider): NAME = "nebius" @@ -77,6 +100,10 @@ def get( return items +class NebiusCatalogItemProviderData(TypedDict): + fabrics: list[str] + + def get_sample_projects(sdk: SDK) -> dict[str, str]: """ Returns: @@ -141,6 +168,12 @@ def make_item( spot: bool, price: float, ) -> Optional[RawCatalogItem]: + fabrics = [] + if preset.allow_gpu_clustering: + fabrics = [ + f.name for f in INFINIBAND_FABRICS if f.platform == platform and f.region == region + ] + item = RawCatalogItem( instance_name=f"{platform} {preset.name}", location=region, @@ -153,6 +186,7 @@ def make_item( gpu_vendor=None, spot=spot, disk_size=None, + provider_data=cast(JSONObject, NebiusCatalogItemProviderData(fabrics=fabrics)), ) if preset.resources.gpu_count: diff --git a/src/integrity_tests/test_nebius.py b/src/integrity_tests/test_nebius.py index 308fd14..86f504c 100644 --- a/src/integrity_tests/test_nebius.py +++ b/src/integrity_tests/test_nebius.py @@ -1,4 +1,5 @@ import csv +import json from operator import itemgetter from pathlib import Path @@ -28,3 +29,41 @@ def test_spots_presented(data_rows: list[dict]): @pytest.mark.parametrize("location", ["eu-north1", "eu-west1"]) def test_location_present(location: str, data_rows: list[dict]): assert location in map(itemgetter("location"), data_rows) + + +def test_fabrics_unique(data_rows: list[dict]) -> None: + for row in data_rows: + fabrics = json.loads(row["provider_data"])["fabrics"] + assert len(fabrics) == len(set(fabrics)), f"Duplicate fabrics in row: {row}" + + +def test_fabrics_on_sample_offer(data_rows: list[dict]) -> None: + for row in data_rows: + if ( + row["instance_name"] == "gpu-h100-sxm 8gpu-128vcpu-1600gb" + and row["location"] == "eu-north1" + ): + break + else: + raise ValueError("Offer not found") + fabrics = set(json.loads(row["provider_data"])["fabrics"]) + expected_fabrics = { + "fabric-2", + "fabric-3", + "fabric-4", + "fabric-6", + } + missing_fabrics = expected_fabrics - fabrics + assert not missing_fabrics + + +def test_no_fabrics_on_sample_non_clustered_offer(data_rows: list[dict]) -> None: + for row in data_rows: + if ( + row["instance_name"] == "gpu-h100-sxm 1gpu-16vcpu-200gb" + and row["location"] == "eu-north1" + ): + break + else: + raise ValueError("Offer not found") + assert json.loads(row["provider_data"])["fabrics"] == []