Skip to content

Commit 1db1121

Browse files
committed
Fix services.gpus typing
1 parent b053cfa commit 1db1121

File tree

2 files changed

+49
-53
lines changed

2 files changed

+49
-53
lines changed

src/dstack/_internal/server/routers/gpus.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from typing import Tuple
22

33
from fastapi import APIRouter, Depends
4-
from sqlalchemy.ext.asyncio import AsyncSession
54

6-
from dstack._internal.server.db import get_session
75
from dstack._internal.server.models import ProjectModel, UserModel
86
from dstack._internal.server.schemas.gpus import ListGpusRequest, ListGpusResponse
97
from dstack._internal.server.security.permissions import ProjectMember
@@ -20,10 +18,7 @@
2018
@project_router.post("/list", response_model=ListGpusResponse, response_model_exclude_none=True)
2119
async def list_gpus(
2220
body: ListGpusRequest,
23-
session: AsyncSession = Depends(get_session),
2421
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
2522
) -> ListGpusResponse:
2623
_, project = user_project
27-
return await list_gpus_grouped(
28-
session=session, project=project, run_spec=body.run_spec, group_by=body.group_by
29-
)
24+
return await list_gpus_grouped(project=project, run_spec=body.run_spec, group_by=body.group_by)

src/dstack/_internal/server/services/gpus.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Dict, List, Literal, Optional, Tuple
22

3-
from sqlalchemy.ext.asyncio import AsyncSession
4-
53
from dstack._internal.core.backends.base.backend import Backend
4+
from dstack._internal.core.errors import ServerClientError
5+
from dstack._internal.core.models.backends.base import BackendType
66
from dstack._internal.core.models.instances import InstanceOfferWithAvailability
77
from dstack._internal.core.models.profiles import SpotPolicy
88
from dstack._internal.core.models.resources import Range
@@ -15,10 +15,43 @@
1515
ListGpusResponse,
1616
)
1717
from dstack._internal.server.services.offers import get_offers_by_requirements
18+
from dstack._internal.utils.common import get_or_error
19+
20+
21+
async def list_gpus_grouped(
22+
project: ProjectModel,
23+
run_spec: RunSpec,
24+
group_by: Optional[List[Literal["backend", "region", "count"]]] = None,
25+
) -> ListGpusResponse:
26+
"""Retrieves available GPU specifications based on a run spec, with optional grouping."""
27+
offers = await _get_gpu_offers(project=project, run_spec=run_spec)
28+
backend_gpus = _process_offers_into_backend_gpus(offers)
29+
group_by_set = set(group_by) if group_by else set()
30+
if "region" in group_by_set and "backend" not in group_by_set:
31+
raise ServerClientError("Cannot group by 'region' without also grouping by 'backend'")
32+
33+
# Determine grouping strategy based on combination
34+
has_backend = "backend" in group_by_set
35+
has_region = "region" in group_by_set
36+
has_count = "count" in group_by_set
37+
if has_backend and has_region and has_count:
38+
gpus = _get_gpus_grouped_by_backend_region_and_count(backend_gpus)
39+
elif has_backend and has_count:
40+
gpus = _get_gpus_grouped_by_backend_and_count(backend_gpus)
41+
elif has_backend and has_region:
42+
gpus = _get_gpus_grouped_by_backend_and_region(backend_gpus)
43+
elif has_backend:
44+
gpus = _get_gpus_grouped_by_backend(backend_gpus)
45+
elif has_count:
46+
gpus = _get_gpus_grouped_by_count(backend_gpus)
47+
else:
48+
gpus = _get_gpus_with_no_grouping(backend_gpus)
49+
50+
return ListGpusResponse(gpus=gpus)
1851

1952

2053
async def _get_gpu_offers(
21-
session: AsyncSession, project: ProjectModel, run_spec: RunSpec
54+
project: ProjectModel, run_spec: RunSpec
2255
) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
2356
"""Fetches all available instance offers that match the run spec's GPU requirements."""
2457
profile = run_spec.merged_profile
@@ -28,7 +61,6 @@ async def _get_gpu_offers(
2861
spot=get_policy_map(profile.spot_policy, default=SpotPolicy.AUTO),
2962
reservation=profile.reservation,
3063
)
31-
3264
return await get_offers_by_requirements(
3365
project=project,
3466
profile=profile,
@@ -45,7 +77,7 @@ def _process_offers_into_backend_gpus(
4577
offers: List[Tuple[Backend, InstanceOfferWithAvailability]],
4678
) -> List[BackendGpus]:
4779
"""Transforms raw offers into a structured list of BackendGpus, aggregating GPU info."""
48-
backend_data: Dict[str, Dict] = {}
80+
backend_data: Dict[BackendType, Dict] = {}
4981

5082
for backend, offer in offers:
5183
backend_type = backend.TYPE
@@ -111,7 +143,7 @@ def _process_offers_into_backend_gpus(
111143
return backend_gpus_list
112144

113145

114-
def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: str):
146+
def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: BackendType):
115147
"""Updates an existing GpuGroup with new data from another GPU offer."""
116148
spot_type: Literal["spot", "on-demand"] = "spot" if gpu.spot else "on-demand"
117149

@@ -122,6 +154,12 @@ def _update_gpu_group(row: GpuGroup, gpu: BackendGpu, backend_type: str):
122154
if row.backends and backend_type not in row.backends:
123155
row.backends.append(backend_type)
124156

157+
# FIXME: Consider using non-optional range
158+
assert row.count.min is not None
159+
assert row.count.max is not None
160+
assert row.price.min is not None
161+
assert row.price.max is not None
162+
125163
row.count.min = min(row.count.min, gpu.count)
126164
row.count.max = max(row.count.max, gpu.count)
127165
per_gpu_price = gpu.price / gpu.count
@@ -194,7 +232,7 @@ def _get_gpus_grouped_by_backend(backend_gpus: List[BackendGpus]) -> List[GpuGro
194232
not any(av.is_available() for av in g.availability),
195233
g.price.min,
196234
g.price.max,
197-
g.backend.value,
235+
get_or_error(g.backend).value,
198236
g.name,
199237
g.memory_mib,
200238
),
@@ -229,7 +267,7 @@ def _get_gpus_grouped_by_backend_and_region(backend_gpus: List[BackendGpus]) ->
229267
not any(av.is_available() for av in g.availability),
230268
g.price.min,
231269
g.price.max,
232-
g.backend.value,
270+
get_or_error(g.backend).value,
233271
g.region,
234272
g.name,
235273
g.memory_mib,
@@ -299,7 +337,7 @@ def _get_gpus_grouped_by_backend_and_count(backend_gpus: List[BackendGpus]) -> L
299337
not any(av.is_available() for av in g.availability),
300338
g.price.min,
301339
g.price.max,
302-
g.backend.value,
340+
get_or_error(g.backend).value,
303341
g.count.min,
304342
g.name,
305343
g.memory_mib,
@@ -344,47 +382,10 @@ def _get_gpus_grouped_by_backend_region_and_count(
344382
not any(av.is_available() for av in g.availability),
345383
g.price.min,
346384
g.price.max,
347-
g.backend.value,
385+
get_or_error(g.backend).value,
348386
g.region,
349387
g.count.min,
350388
g.name,
351389
g.memory_mib,
352390
),
353391
)
354-
355-
356-
async def list_gpus_grouped(
357-
session: AsyncSession,
358-
project: ProjectModel,
359-
run_spec: RunSpec,
360-
group_by: Optional[List[Literal["backend", "region", "count"]]] = None,
361-
) -> ListGpusResponse:
362-
"""Retrieves available GPU specifications based on a run spec, with optional grouping."""
363-
offers = await _get_gpu_offers(session, project, run_spec)
364-
backend_gpus = _process_offers_into_backend_gpus(offers)
365-
366-
group_by_set = set(group_by) if group_by else set()
367-
368-
if "region" in group_by_set and "backend" not in group_by_set:
369-
from dstack._internal.core.errors import ServerClientError
370-
371-
raise ServerClientError("Cannot group by 'region' without also grouping by 'backend'")
372-
373-
# Determine grouping strategy based on combination
374-
has_backend = "backend" in group_by_set
375-
has_region = "region" in group_by_set
376-
has_count = "count" in group_by_set
377-
if has_backend and has_region and has_count:
378-
gpus = _get_gpus_grouped_by_backend_region_and_count(backend_gpus)
379-
elif has_backend and has_count:
380-
gpus = _get_gpus_grouped_by_backend_and_count(backend_gpus)
381-
elif has_backend and has_region:
382-
gpus = _get_gpus_grouped_by_backend_and_region(backend_gpus)
383-
elif has_backend:
384-
gpus = _get_gpus_grouped_by_backend(backend_gpus)
385-
elif has_count:
386-
gpus = _get_gpus_grouped_by_count(backend_gpus)
387-
else:
388-
gpus = _get_gpus_with_no_grouping(backend_gpus)
389-
390-
return ListGpusResponse(gpus=gpus)

0 commit comments

Comments (0)