11from typing import Dict , List , Literal , Optional , Tuple
22
3- from sqlalchemy .ext .asyncio import AsyncSession
4-
53from dstack ._internal .core .backends .base .backend import Backend
4+ from dstack ._internal .core .errors import ServerClientError
5+ from dstack ._internal .core .models .backends .base import BackendType
66from dstack ._internal .core .models .instances import InstanceOfferWithAvailability
77from dstack ._internal .core .models .profiles import SpotPolicy
88from dstack ._internal .core .models .resources import Range
1515 ListGpusResponse ,
1616)
1717from dstack ._internal .server .services .offers import get_offers_by_requirements
18+ from dstack ._internal .utils .common import get_or_error
19+
20+
async def list_gpus_grouped(
    project: ProjectModel,
    run_spec: RunSpec,
    group_by: Optional[List[Literal["backend", "region", "count"]]] = None,
) -> ListGpusResponse:
    """Retrieves available GPU specifications based on a run spec, with optional grouping.

    Args:
        project: Project forwarded to the offers lookup.
        run_spec: Run spec forwarded to the offers lookup.
        group_by: Optional grouping keys. Any combination of "backend",
            "region" and "count" is accepted, except that "region" requires
            "backend". ``None`` or an empty list means no grouping.

    Returns:
        A ``ListGpusResponse`` whose ``gpus`` come from the grouping helper
        matching the requested combination of keys.

    Raises:
        ServerClientError: If "region" is requested without "backend".
    """
    group_by_set = set(group_by) if group_by else set()
    has_backend = "backend" in group_by_set
    has_region = "region" in group_by_set
    has_count = "count" in group_by_set

    # Validate the request before fetching offers so an invalid combination
    # is rejected without paying for the offers query.
    if has_region and not has_backend:
        raise ServerClientError("Cannot group by 'region' without also grouping by 'backend'")

    offers = await _get_gpu_offers(project=project, run_spec=run_spec)
    backend_gpus = _process_offers_into_backend_gpus(offers)

    # Determine grouping strategy based on combination.
    # After the check above, has_region implies has_backend.
    if has_backend:
        if has_region and has_count:
            gpus = _get_gpus_grouped_by_backend_region_and_count(backend_gpus)
        elif has_count:
            gpus = _get_gpus_grouped_by_backend_and_count(backend_gpus)
        elif has_region:
            gpus = _get_gpus_grouped_by_backend_and_region(backend_gpus)
        else:
            gpus = _get_gpus_grouped_by_backend(backend_gpus)
    elif has_count:
        gpus = _get_gpus_grouped_by_count(backend_gpus)
    else:
        gpus = _get_gpus_with_no_grouping(backend_gpus)

    return ListGpusResponse(gpus=gpus)
1851
1952
2053async def _get_gpu_offers (
21- session : AsyncSession , project : ProjectModel , run_spec : RunSpec
54+ project : ProjectModel , run_spec : RunSpec
2255) -> List [Tuple [Backend , InstanceOfferWithAvailability ]]:
2356 """Fetches all available instance offers that match the run spec's GPU requirements."""
2457 profile = run_spec .merged_profile
@@ -28,7 +61,6 @@ async def _get_gpu_offers(
2861 spot = get_policy_map (profile .spot_policy , default = SpotPolicy .AUTO ),
2962 reservation = profile .reservation ,
3063 )
31-
3264 return await get_offers_by_requirements (
3365 project = project ,
3466 profile = profile ,
@@ -45,7 +77,7 @@ def _process_offers_into_backend_gpus(
4577 offers : List [Tuple [Backend , InstanceOfferWithAvailability ]],
4678) -> List [BackendGpus ]:
4779 """Transforms raw offers into a structured list of BackendGpus, aggregating GPU info."""
48- backend_data : Dict [str , Dict ] = {}
80+ backend_data : Dict [BackendType , Dict ] = {}
4981
5082 for backend , offer in offers :
5183 backend_type = backend .TYPE
@@ -111,7 +143,7 @@ def _process_offers_into_backend_gpus(
111143 return backend_gpus_list
112144
113145
114- def _update_gpu_group (row : GpuGroup , gpu : BackendGpu , backend_type : str ):
146+ def _update_gpu_group (row : GpuGroup , gpu : BackendGpu , backend_type : BackendType ):
115147 """Updates an existing GpuGroup with new data from another GPU offer."""
116148 spot_type : Literal ["spot" , "on-demand" ] = "spot" if gpu .spot else "on-demand"
117149
@@ -122,6 +154,12 @@ def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: str):
122154 if row .backends and backend_type not in row .backends :
123155 row .backends .append (backend_type )
124156
157+ # FIXME: Consider using non-optional range
158+ assert row .count .min is not None
159+ assert row .count .max is not None
160+ assert row .price .min is not None
161+ assert row .price .max is not None
162+
125163 row .count .min = min (row .count .min , gpu .count )
126164 row .count .max = max (row .count .max , gpu .count )
127165 per_gpu_price = gpu .price / gpu .count
@@ -194,7 +232,7 @@ def _get_gpus_grouped_by_backend(backend_gpus: List[BackendGpus]) -> List[GpuGro
194232 not any (av .is_available () for av in g .availability ),
195233 g .price .min ,
196234 g .price .max ,
197- g .backend .value ,
235+ get_or_error ( g .backend ) .value ,
198236 g .name ,
199237 g .memory_mib ,
200238 ),
@@ -229,7 +267,7 @@ def _get_gpus_grouped_by_backend_and_region(backend_gpus: List[BackendGpus]) ->
229267 not any (av .is_available () for av in g .availability ),
230268 g .price .min ,
231269 g .price .max ,
232- g .backend .value ,
270+ get_or_error ( g .backend ) .value ,
233271 g .region ,
234272 g .name ,
235273 g .memory_mib ,
@@ -299,7 +337,7 @@ def _get_gpus_grouped_by_backend_and_count(backend_gpus: List[BackendGpus]) -> L
299337 not any (av .is_available () for av in g .availability ),
300338 g .price .min ,
301339 g .price .max ,
302- g .backend .value ,
340+ get_or_error ( g .backend ) .value ,
303341 g .count .min ,
304342 g .name ,
305343 g .memory_mib ,
@@ -344,47 +382,10 @@ def _get_gpus_grouped_by_backend_region_and_count(
344382 not any (av .is_available () for av in g .availability ),
345383 g .price .min ,
346384 g .price .max ,
347- g .backend .value ,
385+ get_or_error ( g .backend ) .value ,
348386 g .region ,
349387 g .count .min ,
350388 g .name ,
351389 g .memory_mib ,
352390 ),
353391 )
354-
355-
356- async def list_gpus_grouped (
357- session : AsyncSession ,
358- project : ProjectModel ,
359- run_spec : RunSpec ,
360- group_by : Optional [List [Literal ["backend" , "region" , "count" ]]] = None ,
361- ) -> ListGpusResponse :
362- """Retrieves available GPU specifications based on a run spec, with optional grouping."""
363- offers = await _get_gpu_offers (session , project , run_spec )
364- backend_gpus = _process_offers_into_backend_gpus (offers )
365-
366- group_by_set = set (group_by ) if group_by else set ()
367-
368- if "region" in group_by_set and "backend" not in group_by_set :
369- from dstack ._internal .core .errors import ServerClientError
370-
371- raise ServerClientError ("Cannot group by 'region' without also grouping by 'backend'" )
372-
373- # Determine grouping strategy based on combination
374- has_backend = "backend" in group_by_set
375- has_region = "region" in group_by_set
376- has_count = "count" in group_by_set
377- if has_backend and has_region and has_count :
378- gpus = _get_gpus_grouped_by_backend_region_and_count (backend_gpus )
379- elif has_backend and has_count :
380- gpus = _get_gpus_grouped_by_backend_and_count (backend_gpus )
381- elif has_backend and has_region :
382- gpus = _get_gpus_grouped_by_backend_and_region (backend_gpus )
383- elif has_backend :
384- gpus = _get_gpus_grouped_by_backend (backend_gpus )
385- elif has_count :
386- gpus = _get_gpus_grouped_by_count (backend_gpus )
387- else :
388- gpus = _get_gpus_with_no_grouping (backend_gpus )
389-
390- return ListGpusResponse (gpus = gpus )
0 commit comments