Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/gpuhunt/providers/nebius.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
from typing import Optional
from dataclasses import dataclass
from typing import Optional, cast

from nebius.aio.channel import Credentials
from nebius.api.nebius.billing.v1alpha1 import (
Expand All @@ -26,11 +27,13 @@
TenantServiceClient,
)
from nebius.sdk import SDK
from typing_extensions import TypedDict

from gpuhunt._internal.constraints import find_accelerators
from gpuhunt._internal.models import (
AcceleratorInfo,
AcceleratorVendor,
JSONObject,
QueryFilter,
RawCatalogItem,
)
Expand All @@ -40,6 +43,26 @@
TIMEOUT = 7


@dataclass(frozen=True)
class InfinibandFabric:
    """Immutable record describing one InfiniBand fabric.

    Instances are hashable and compare by value because the dataclass is
    frozen; attributes cannot be reassigned after construction.
    """

    # Fabric identifier, e.g. "fabric-2".
    name: str
    # Platform the fabric belongs to, e.g. "gpu-h100-sxm".
    platform: str
    # Region the fabric is located in, e.g. "eu-north1".
    region: str


# Static table of known InfiniBand fabrics, each tied to the platform and
# region it serves. Source of truth:
# https://docs.nebius.com/compute/clusters/gpu#fabrics
INFINIBAND_FABRICS = [
    InfinibandFabric(name="fabric-2", platform="gpu-h100-sxm", region="eu-north1"),
    InfinibandFabric(name="fabric-3", platform="gpu-h100-sxm", region="eu-north1"),
    InfinibandFabric(name="fabric-4", platform="gpu-h100-sxm", region="eu-north1"),
    InfinibandFabric(name="fabric-5", platform="gpu-h200-sxm", region="eu-west1"),
    InfinibandFabric(name="fabric-6", platform="gpu-h100-sxm", region="eu-north1"),
    InfinibandFabric(name="fabric-7", platform="gpu-h200-sxm", region="eu-north1"),
    InfinibandFabric(name="us-central1-a", platform="gpu-h200-sxm", region="us-central1"),
    InfinibandFabric(name="us-central1-b", platform="gpu-b200-sxm", region="us-central1"),
]


class NebiusProvider(AbstractProvider):
NAME = "nebius"

Expand Down Expand Up @@ -77,6 +100,10 @@ def get(
return items


class NebiusCatalogItemProviderData(TypedDict):
    """Typed shape of the ``provider_data`` payload stored on Nebius catalog items."""

    # Names of the InfiniBand fabrics available to the offer; an empty list
    # when none apply.
    fabrics: list[str]


def get_sample_projects(sdk: SDK) -> dict[str, str]:
"""
Returns:
Expand Down Expand Up @@ -141,6 +168,12 @@ def make_item(
spot: bool,
price: float,
) -> Optional[RawCatalogItem]:
fabrics = []
if preset.allow_gpu_clustering:
fabrics = [
f.name for f in INFINIBAND_FABRICS if f.platform == platform and f.region == region
]

item = RawCatalogItem(
instance_name=f"{platform} {preset.name}",
location=region,
Expand All @@ -153,6 +186,7 @@ def make_item(
gpu_vendor=None,
spot=spot,
disk_size=None,
provider_data=cast(JSONObject, NebiusCatalogItemProviderData(fabrics=fabrics)),
)

if preset.resources.gpu_count:
Expand Down
39 changes: 39 additions & 0 deletions src/integrity_tests/test_nebius.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import json
from operator import itemgetter
from pathlib import Path

Expand Down Expand Up @@ -28,3 +29,41 @@ def test_spots_presented(data_rows: list[dict]):
@pytest.mark.parametrize("location", ["eu-north1", "eu-west1"])
def test_location_present(location: str, data_rows: list[dict]):
    """Each parametrized region must appear on at least one catalog row."""
    assert any(row["location"] == location for row in data_rows)


def test_fabrics_unique(data_rows: list[dict]) -> None:
for row in data_rows:
fabrics = json.loads(row["provider_data"])["fabrics"]
assert len(fabrics) == len(set(fabrics)), f"Duplicate fabrics in row: {row}"


def test_fabrics_on_sample_offer(data_rows: list[dict]) -> None:
for row in data_rows:
if (
row["instance_name"] == "gpu-h100-sxm 8gpu-128vcpu-1600gb"
and row["location"] == "eu-north1"
):
break
else:
raise ValueError("Offer not found")
fabrics = set(json.loads(row["provider_data"])["fabrics"])
expected_fabrics = {
"fabric-2",
"fabric-3",
"fabric-4",
"fabric-6",
}
missing_fabrics = expected_fabrics - fabrics
assert not missing_fabrics


def test_no_fabrics_on_sample_non_clustered_offer(data_rows: list[dict]) -> None:
for row in data_rows:
if (
row["instance_name"] == "gpu-h100-sxm 1gpu-16vcpu-200gb"
and row["location"] == "eu-north1"
):
break
else:
raise ValueError("Offer not found")
assert json.loads(row["provider_data"])["fabrics"] == []