Skip to content

Commit 85dd21e

Browse files
authored
Optimize per fleet offers (#3316)
* WIP: Support lazy max_offers for get_offers_by_requirements
* Implement get_offers iterator for ComputeWithAllOffersCached
* Refetch backend offers without limit to return all offers for the optimal fleet
* Drop tensordock compute
* Update get_offers() signatures
* Fix get_backend_offers()
* Replace yield with iter for LocalCompute
* Fix var capture by generator expression
* Fix exclude_not_available ignored
* Drop tensordock configurator import
1 parent 6dd1057 commit 85dd21e

File tree

9 files changed

+232
-299
lines changed

9 files changed

+232
-299
lines changed

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import string
55
import threading
66
from abc import ABC, abstractmethod
7-
from collections.abc import Iterable
7+
from collections.abc import Iterable, Iterator
88
from enum import Enum
99
from functools import lru_cache
1010
from pathlib import Path
@@ -95,11 +95,12 @@ class Compute(ABC):
9595
"""
9696

9797
@abstractmethod
98-
def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
98+
def get_offers(self, requirements: Requirements) -> Iterator[InstanceOfferWithAvailability]:
9999
"""
100100
Returns offers with availability matching `requirements`.
101-
If the provider is added to gpuhunt, typically gets offers using `base.offers.get_catalog_offers()`
102-
and extends them with availability info.
101+
If the provider is added to gpuhunt, typically gets offers using
102+
`base.offers.get_catalog_offers()` and extends them with availability info.
103+
It is called from async code in executor. It can block on call but not between yields.
103104
"""
104105
pass
105106

@@ -190,13 +191,13 @@ def get_offers_post_filter(
190191
"""
191192
return None
192193

193-
def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
194-
offers = self._get_all_offers_with_availability_cached()
195-
offers = self.__apply_modifiers(offers, self.get_offers_modifiers(requirements))
194+
def get_offers(self, requirements: Requirements) -> Iterator[InstanceOfferWithAvailability]:
195+
cached_offers = self._get_all_offers_with_availability_cached()
196+
offers = self.__apply_modifiers(cached_offers, self.get_offers_modifiers(requirements))
196197
offers = filter_offers_by_requirements(offers, requirements)
197198
post_filter = self.get_offers_post_filter(requirements)
198199
if post_filter is not None:
199-
offers = [o for o in offers if post_filter(o)]
200+
offers = (o for o in offers if post_filter(o))
200201
return offers
201202

202203
@cachedmethod(
@@ -209,16 +210,14 @@ def _get_all_offers_with_availability_cached(self) -> List[InstanceOfferWithAvai
209210
@staticmethod
210211
def __apply_modifiers(
211212
offers: Iterable[InstanceOfferWithAvailability], modifiers: Iterable[OfferModifier]
212-
) -> list[InstanceOfferWithAvailability]:
213-
modified_offers = []
213+
) -> Iterator[InstanceOfferWithAvailability]:
214214
for offer in offers:
215215
for modifier in modifiers:
216216
offer = modifier(offer)
217217
if offer is None:
218218
break
219219
else:
220-
modified_offers.append(offer)
221-
return modified_offers
220+
yield offer
222221

223222

224223
class ComputeWithFilteredOffersCached(ABC):
@@ -242,8 +241,8 @@ def get_offers_by_requirements(
242241
"""
243242
pass
244243

245-
def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
246-
return self._get_offers_cached(requirements)
244+
def get_offers(self, requirements: Requirements) -> Iterator[InstanceOfferWithAvailability]:
245+
return iter(self._get_offers_cached(requirements))
247246

248247
def _get_offers_cached_key(self, requirements: Requirements) -> int:
249248
# Requirements is not hashable, so we use a hack to get arguments hash

src/dstack/_internal/core/backends/base/offers.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterable, Iterator
12
from dataclasses import asdict
23
from typing import Callable, List, Optional, TypeVar
34

@@ -174,16 +175,14 @@ def requirements_to_query_filter(req: Optional[Requirements]) -> gpuhunt.QueryFi
174175

175176

176177
def filter_offers_by_requirements(
177-
offers: List[InstanceOfferT],
178+
offers: Iterable[InstanceOfferT],
178179
requirements: Optional[Requirements],
179-
) -> List[InstanceOfferT]:
180+
) -> Iterator[InstanceOfferT]:
180181
query_filter = requirements_to_query_filter(requirements)
181-
filtered_offers = []
182182
for offer in offers:
183183
catalog_item = offer_to_catalog_item(offer)
184184
if gpuhunt.matches(catalog_item, q=query_filter):
185-
filtered_offers.append(offer)
186-
return filtered_offers
185+
yield offer
187186

188187

189188
def choose_disk_size_mib(

src/dstack/_internal/core/backends/configurators.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,6 @@
119119
except ImportError:
120120
pass
121121

122-
try:
123-
from dstack._internal.core.backends.tensordock.configurator import (
124-
TensorDockConfigurator,
125-
)
126-
127-
_CONFIGURATOR_CLASSES.append(TensorDockConfigurator)
128-
except ImportError:
129-
pass
130122

131123
try:
132124
from dstack._internal.core.backends.vastai.configurator import VastAIConfigurator

src/dstack/_internal/core/backends/local/compute.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterator
12
from typing import List, Optional
23

34
from dstack._internal.core.backends.base.compute import (
@@ -18,7 +19,12 @@
1819
)
1920
from dstack._internal.core.models.placement import PlacementGroup
2021
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
21-
from dstack._internal.core.models.volumes import Volume, VolumeProvisioningData
22+
from dstack._internal.core.models.volumes import (
23+
Volume,
24+
VolumeAttachmentData,
25+
VolumeProvisioningData,
26+
)
27+
from dstack._internal.utils.common import get_or_error
2228
from dstack._internal.utils.logging import get_logger
2329

2430
logger = get_logger(__name__)
@@ -30,20 +36,22 @@ class LocalCompute(
3036
ComputeWithVolumeSupport,
3137
Compute,
3238
):
33-
def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
34-
return [
35-
InstanceOfferWithAvailability(
36-
backend=BackendType.LOCAL,
37-
instance=InstanceType(
38-
name="local",
39-
resources=Resources(cpus=4, memory_mib=8192, gpus=[], spot=False),
40-
),
41-
region="local",
42-
price=0.00,
43-
availability=InstanceAvailability.AVAILABLE,
44-
instance_runtime=InstanceRuntime.RUNNER,
45-
)
46-
]
39+
def get_offers(self, requirements: Requirements) -> Iterator[InstanceOfferWithAvailability]:
40+
return iter(
41+
[
42+
InstanceOfferWithAvailability(
43+
backend=BackendType.LOCAL,
44+
instance=InstanceType(
45+
name="local",
46+
resources=Resources(cpus=4, memory_mib=8192, gpus=[], spot=False),
47+
),
48+
region="local",
49+
price=0.00,
50+
availability=InstanceAvailability.AVAILABLE,
51+
instance_runtime=InstanceRuntime.RUNNER,
52+
)
53+
]
54+
)
4755

4856
def terminate_instance(
4957
self, instance_id: str, region: str, backend_data: Optional[str] = None
@@ -98,7 +106,7 @@ def run_job(
98106

99107
def register_volume(self, volume: Volume) -> VolumeProvisioningData:
100108
return VolumeProvisioningData(
101-
volume_id=volume.volume_id,
109+
volume_id=get_or_error(volume.volume_id),
102110
size_gb=volume.configuration.size_gb,
103111
)
104112

@@ -111,8 +119,10 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
111119
def delete_volume(self, volume: Volume):
112120
pass
113121

114-
def attach_volume(self, volume: Volume, provisioning_data: JobProvisioningData):
115-
pass
122+
def attach_volume(
123+
self, volume: Volume, provisioning_data: JobProvisioningData
124+
) -> VolumeAttachmentData:
125+
return VolumeAttachmentData(device_name=None)
116126

117127
def detach_volume(
118128
self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False

src/dstack/_internal/core/backends/template/compute.py.jinja

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterator
12
from typing import List, Optional
23

34
from dstack._internal.core.backends.base.backend import Compute
@@ -47,7 +48,7 @@ class {{ backend_name }}Compute(
4748

4849
def get_offers(
4950
self, requirements: Requirements
50-
) -> List[InstanceOfferWithAvailability]:
51+
) -> Iterator[InstanceOfferWithAvailability]:
5152
# If the provider is added to gpuhunt, you'd typically get offers
5253
# using `get_catalog_offers()` and extend them with availability info.
5354
offers = get_catalog_offers(
@@ -57,13 +58,13 @@ class {{ backend_name }}Compute(
5758
# configurable_disk_size=..., TODO: set in case of boot volume size limits
5859
)
5960
# TODO: Add availability info to offers
60-
return [
61+
return (
6162
InstanceOfferWithAvailability(
6263
**offer.dict(),
6364
availability=InstanceAvailability.UNKNOWN,
6465
)
6566
for offer in offers
66-
]
67+
)
6768

6869
def create_instance(
6970
self,

src/dstack/_internal/core/backends/tensordock/compute.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

src/dstack/_internal/server/services/backends/__init__.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import heapq
3+
from collections.abc import Iterable, Iterator
34
from typing import Callable, Coroutine, Dict, List, Optional, Tuple
45
from uuid import UUID
56

@@ -338,12 +339,23 @@ async def get_project_backend_model_by_type_or_error(
338339
return backend_model
339340

340341

341-
async def get_instance_offers(
342-
backends: List[Backend], requirements: Requirements, exclude_not_available: bool = False
343-
) -> List[Tuple[Backend, InstanceOfferWithAvailability]]:
342+
async def get_backend_offers(
343+
backends: List[Backend],
344+
requirements: Requirements,
345+
exclude_not_available: bool = False,
346+
) -> Iterator[Tuple[Backend, InstanceOfferWithAvailability]]:
344347
"""
345-
Returns list of instances satisfying minimal resource requirements sorted by price
348+
Yields backend offers satisfying `requirements` sorted by price.
346349
"""
350+
351+
def get_filtered_offers_with_backends(
352+
backend: Backend,
353+
offers: Iterable[InstanceOfferWithAvailability],
354+
) -> Iterator[Tuple[Backend, InstanceOfferWithAvailability]]:
355+
for offer in offers:
356+
if not exclude_not_available or offer.availability.is_available():
357+
yield (backend, offer)
358+
347359
logger.info("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
348360
tasks = [run_async(backend.compute().get_offers, requirements) for backend in backends]
349361
offers_by_backend = []
@@ -362,17 +374,10 @@ async def get_instance_offers(
362374
exc_info=result,
363375
)
364376
continue
365-
offers_by_backend.append(
366-
[
367-
(backend, offer)
368-
for offer in result
369-
if not exclude_not_available or offer.availability.is_available()
370-
]
371-
)
372-
# Merge preserving order for every backend
377+
offers_by_backend.append(get_filtered_offers_with_backends(backend, result))
378+
# Merge preserving order for every backend.
373379
offers = heapq.merge(*offers_by_backend, key=lambda i: i[1].price)
374-
# Put NOT_AVAILABLE, NO_QUOTA, and BUSY instances at the end, do not sort by price
375-
return sorted(offers, key=lambda i: not i[1].availability.is_available())
380+
return offers
376381

377382

378383
def check_backend_type_available(backend_type: BackendType):

0 commit comments

Comments (0)