diff --git a/src/gpuhunt/providers/vastai.py b/src/gpuhunt/providers/vastai.py index 19ab367..8469f7e 100644 --- a/src/gpuhunt/providers/vastai.py +++ b/src/gpuhunt/providers/vastai.py @@ -21,13 +21,20 @@ class VastAIProvider(AbstractProvider): NAME = "vastai" - def __init__(self, extra_filters: Optional[dict[str, dict[Operators, FilterValue]]] = None): + def __init__( + self, + extra_filters: Optional[dict[str, dict[Operators, FilterValue]]] = None, + community_cloud: bool = True, + ): self.extra_filters = extra_filters + self.community_cloud = community_cloud def get( self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True ) -> list[RawCatalogItem]: - filters: dict[str, Any] = self.make_filters(query_filter or QueryFilter()) + filters: dict[str, Any] = self.make_filters( + query_filter or QueryFilter(), community_cloud=self.community_cloud + ) if self.extra_filters: for key, constraints in self.extra_filters.items(): for op, value in constraints.items(): @@ -79,7 +86,9 @@ def get( return instance_offers @staticmethod - def make_filters(q: QueryFilter) -> dict[str, dict[Operators, FilterValue]]: + def make_filters( + q: QueryFilter, community_cloud: bool = True + ) -> dict[str, dict[Operators, FilterValue]]: filters = defaultdict(dict) if q.min_cpu is not None: filters["cpu_cores"]["gte"] = q.min_cpu @@ -120,6 +129,11 @@ def make_filters(q: QueryFilter) -> dict[str, dict[Operators, FilterValue]]: filters["compute_capability"]["gte"] = compute_cap(q.min_compute_capability) if q.max_compute_capability is not None: filters["compute_capability"]["lte"] = compute_cap(q.max_compute_capability) + # Datacenter offers map to Vast's "server cloud" scope. + # When community_cloud is enabled, keep scope unfiltered so both + # server and community offers are returned. + if not community_cloud: + filters["datacenter"]["eq"] = True filters["rentable"]["eq"] = True filters["rented"]["eq"] = False filters["order"] = [["score", "desc"]] @@ -128,6 +142,10 @@ def make_filters(q: QueryFilter) -> dict[str, dict[Operators, FilterValue]]: @staticmethod def satisfies_filters(offer: dict, filters: dict[str, dict[Operators, FilterValue]]) -> bool: for key in filters: + # `datacenter`/`external` are query scope controls. + # They don't map to offer fields with strict eq semantics. + if key in {"datacenter", "external"}: + continue if key not in offer: continue for op, value in filters[key].items(): diff --git a/src/tests/providers/test_vastai.py b/src/tests/providers/test_vastai.py new file mode 100644 index 0000000..ffb5f7c --- /dev/null +++ b/src/tests/providers/test_vastai.py @@ -0,0 +1,14 @@ +from gpuhunt._internal.models import QueryFilter +from gpuhunt.providers.vastai import VastAIProvider + + +def test_make_filters_defaults_to_datacenter_only(): + filters = VastAIProvider.make_filters(QueryFilter(), community_cloud=False) + assert filters["datacenter"]["eq"] is True + assert "external" not in filters + + +def test_make_filters_does_not_constrain_scope_when_community_cloud_enabled(): + filters = VastAIProvider.make_filters(QueryFilter(), community_cloud=True) + assert "datacenter" not in filters + assert "external" not in filters