Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9e2a6b0
Cache GCP offers with availability
r4victor Sep 10, 2025
b183ae7
refactor: update get_offers method signature to remove optional requi…
r4victor Sep 10, 2025
a3b5136
Introduce ComputeWithAllOffersCached
r4victor Sep 10, 2025
191a408
feat: migrate AWSCompute to use ComputeWithAllOffersCached with reser…
r4victor Sep 10, 2025
f3ceb96
refactor: update compute classes to use flexible requirements filtering
r4victor Sep 10, 2025
ee49d7d
Cache AWS offers with availability
r4victor Sep 10, 2025
dbbf3dc
refactor: migrate AzureCompute to use ComputeWithAllOffersCached
r4victor Sep 10, 2025
43a8b63
refactor: migrate CloudriftCompute to use ComputeWithAllOffersCached
r4victor Sep 10, 2025
fa6d39b
refactor: migrate DatacrunchCompute to use ComputeWithAllOffersCached
r4victor Sep 10, 2025
693a33f
fix missing Compute
r4victor Sep 10, 2025
aa3e6ac
Migrate all backends to ComputeWithAllOffersCached
r4victor Sep 10, 2025
29a0fbc
refactor: inherit from ComputeWithAllOffersCached and update get_offe…
r4victor Sep 10, 2025
16ba873
Move by requirements cache to ComputeWithFilteredOffersCached
r4victor Sep 11, 2025
c64e01e
Implement get_offers_modifier for AWS
r4victor Sep 11, 2025
cadd0f1
Implement get_offers_modifier for all backends with CONFIGURABLE_DISK…
r4victor Sep 11, 2025
044c14d
Fix backend offers
r4victor Sep 11, 2025
03d15b3
Fix nebius
r4victor Sep 11, 2025
469a9e2
Fix oci
r4victor Sep 11, 2025
a8babca
Use ComputeWithAllOffersCached for kubernetes
r4victor Sep 11, 2025
9707064
Cache AWS.get_offers_post_filter
r4victor Sep 11, 2025
9d64349
Update template
r4victor Sep 11, 2025
7a5de8f
Fix tests
r4victor Sep 11, 2025
e14c6fb
Lint
r4victor Sep 11, 2025
d6dece8
Merge branch 'master' into pr_offers_with_availability_cache
r4victor Sep 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 51 additions & 36 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import boto3
import botocore.client
Expand All @@ -18,6 +18,7 @@
)
from dstack._internal.core.backends.base.compute import (
Compute,
ComputeWithAllOffersCached,
ComputeWithCreateInstanceSupport,
ComputeWithGatewaySupport,
ComputeWithMultinodeSupport,
Expand All @@ -32,7 +33,7 @@
get_user_data,
merge_tags,
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
from dstack._internal.core.errors import (
ComputeError,
NoCapacityError,
Expand Down Expand Up @@ -87,6 +88,7 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):


class AWSCompute(
ComputeWithAllOffersCached,
ComputeWithCreateInstanceSupport,
ComputeWithMultinodeSupport,
ComputeWithReservationSupport,
Expand All @@ -109,6 +111,8 @@ def __init__(self, config: AWSConfig):
# Caches to avoid redundant API calls when provisioning many instances
# get_offers is already cached but we still cache its sub-functions
# with more aggressive/longer caches.
self._offers_post_filter_cache_lock = threading.Lock()
self._offers_post_filter_cache = TTLCache(maxsize=10, ttl=180)
self._get_regions_to_quotas_cache_lock = threading.Lock()
self._get_regions_to_quotas_execution_lock = threading.Lock()
self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300)
Expand All @@ -125,43 +129,11 @@ def __init__(self, config: AWSConfig):
self._get_image_id_and_username_cache_lock = threading.Lock()
self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600)

def get_offers(
self, requirements: Optional[Requirements] = None
) -> List[InstanceOfferWithAvailability]:
filter = _supported_instances
if requirements and requirements.reservation:
region_to_reservation = {}
for region in self.config.regions:
reservation = aws_resources.get_reservation(
ec2_client=self.session.client("ec2", region_name=region),
reservation_id=requirements.reservation,
instance_count=1,
)
if reservation is not None:
region_to_reservation[region] = reservation

def _supported_instances_with_reservation(offer: InstanceOffer) -> bool:
# Filter: only instance types supported by dstack
if not _supported_instances(offer):
return False
# Filter: Spot instances can't be used with reservations
if offer.instance.resources.spot:
return False
region = offer.region
reservation = region_to_reservation.get(region)
# Filter: only instance types matching the capacity reservation
if not bool(reservation and offer.instance.name == reservation["InstanceType"]):
return False
return True

filter = _supported_instances_with_reservation

def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
offers = get_catalog_offers(
backend=BackendType.AWS,
locations=self.config.regions,
requirements=requirements,
configurable_disk_size=CONFIGURABLE_DISK_SIZE,
extra_filter=filter,
extra_filter=_supported_instances,
)
regions = list(set(i.region for i in offers))
with self._get_regions_to_quotas_execution_lock:
Expand All @@ -185,6 +157,49 @@ def _supported_instances_with_reservation(offer: InstanceOffer) -> bool:
)
return availability_offers

def get_offers_modifier(
    self, requirements: Requirements
) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
    """Build a per-offer modifier that fits each offer's disk size to `requirements`.

    Applied to cached offers before they are filtered by requirements, so the
    disk size reflects what would actually be provisioned.
    """
    modifier = get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
    return modifier

def _get_offers_cached_key(self, requirements: Requirements) -> int:
    """Derive a cache key for `requirements`.

    `Requirements` instances are not hashable, so hash their JSON
    serialization instead.
    """
    serialized = requirements.json()
    return hash(serialized)

@cachedmethod(
    cache=lambda self: self._offers_post_filter_cache,
    key=_get_offers_cached_key,
    lock=lambda self: self._offers_post_filter_cache_lock,
)
def get_offers_post_filter(
    self, requirements: Requirements
) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]:
    """
    Return a post-filter predicate for offers, or `None` when no extra
    filtering is needed.

    Only builds a filter when `requirements.reservation` is set: offers are
    then restricted to non-spot instances whose type matches the capacity
    reservation in the regions where that reservation was found.
    The result (including the closure and its region map) is cached via
    `_offers_post_filter_cache`, keyed on the serialized requirements, so
    repeated calls with the same requirements skip the per-region EC2
    reservation lookups.
    """
    if requirements.reservation:
        # Collect the reservation per region; regions where
        # get_reservation() returns None are simply left out of the map.
        region_to_reservation = {}
        for region in get_or_error(self.config.regions):
            reservation = aws_resources.get_reservation(
                ec2_client=self.session.client("ec2", region_name=region),
                reservation_id=requirements.reservation,
                instance_count=1,
            )
            if reservation is not None:
                region_to_reservation[region] = reservation

        def reservation_filter(offer: InstanceOfferWithAvailability) -> bool:
            # Filter: Spot instances can't be used with reservations
            if offer.instance.resources.spot:
                return False
            region = offer.region
            reservation = region_to_reservation.get(region)
            # Filter: only instance types matching the capacity reservation
            # (offers in regions without the reservation are rejected too,
            # since `reservation` is None there).
            if not bool(reservation and offer.instance.name == reservation["InstanceType"]):
                return False
            return True

        return reservation_filter

    return None

def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
) -> None:
Expand Down
17 changes: 10 additions & 7 deletions src/dstack/_internal/core/backends/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import enum
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

from azure.core.credentials import TokenCredential
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
Expand Down Expand Up @@ -39,6 +39,7 @@
from dstack._internal.core.backends.azure.models import AzureConfig
from dstack._internal.core.backends.base.compute import (
Compute,
ComputeWithAllOffersCached,
ComputeWithCreateInstanceSupport,
ComputeWithGatewaySupport,
ComputeWithMultinodeSupport,
Expand All @@ -48,7 +49,7 @@
get_user_data,
merge_tags,
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
from dstack._internal.core.errors import ComputeError, NoCapacityError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.gateways import (
Expand All @@ -73,6 +74,7 @@


class AzureCompute(
ComputeWithAllOffersCached,
ComputeWithCreateInstanceSupport,
ComputeWithMultinodeSupport,
ComputeWithGatewaySupport,
Expand All @@ -89,14 +91,10 @@ def __init__(self, config: AzureConfig, credential: TokenCredential):
credential=credential, subscription_id=config.subscription_id
)

def get_offers(
self, requirements: Optional[Requirements] = None
) -> List[InstanceOfferWithAvailability]:
def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
offers = get_catalog_offers(
backend=BackendType.AZURE,
locations=self.config.regions,
requirements=requirements,
configurable_disk_size=CONFIGURABLE_DISK_SIZE,
extra_filter=_supported_instances,
)
offers_with_availability = _get_offers_with_availability(
Expand All @@ -106,6 +104,11 @@ def get_offers(
)
return offers_with_availability

def get_offers_modifier(
    self, requirements: Requirements
) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]:
    """Return a modifier that resizes each offer's disk to match `requirements`,
    constrained to the backend's configurable disk size range."""
    disk_modifier = get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)
    return disk_modifier

def create_instance(
self,
instance_offer: InstanceOfferWithAvailability,
Expand Down
110 changes: 96 additions & 14 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from collections.abc import Iterable
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Literal, Optional
from typing import Callable, Dict, List, Literal, Optional

import git
import requests
import yaml
from cachetools import TTLCache, cachedmethod

from dstack._internal import settings
from dstack._internal.core.backends.base.offers import filter_offers_by_requirements
from dstack._internal.core.consts import (
DSTACK_RUNNER_HTTP_PORT,
DSTACK_RUNNER_SSH_PORT,
Expand Down Expand Up @@ -57,14 +58,8 @@ class Compute(ABC):
If a compute supports additional features, it must also subclass `ComputeWith*` classes.
"""

def __init__(self):
self._offers_cache_lock = threading.Lock()
self._offers_cache = TTLCache(maxsize=10, ttl=180)

@abstractmethod
def get_offers(
self, requirements: Optional[Requirements] = None
) -> List[InstanceOfferWithAvailability]:
def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
"""
Returns offers with availability matching `requirements`.
If the provider is added to gpuhunt, typically gets offers using `base.offers.get_catalog_offers()`
Expand Down Expand Up @@ -121,21 +116,108 @@ def update_provisioning_data(
"""
pass

def _get_offers_cached_key(self, requirements: Optional[Requirements] = None) -> int:

class ComputeWithAllOffersCached(ABC):
    """
    Mixin supplying a common `get_offers()` implementation for backends whose
    full offer list does not depend on the requested requirements.

    The complete availability-annotated offer list is cached (TTL-bound) and
    each `get_offers()` call narrows it in three stages:
    modifier -> requirements filter -> post-filter.
    """

    def __init__(self) -> None:
        super().__init__()
        self._offers_cache_lock = threading.Lock()
        # A single slot suffices: the cached value takes no arguments.
        self._offers_cache = TTLCache(maxsize=1, ttl=180)

    @abstractmethod
    def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
        """
        Fetch every offer the backend provides, annotated with availability.
        """
        pass

    def get_offers_modifier(
        self, requirements: Requirements
    ) -> Optional[
        Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]
    ]:
        """
        Optional per-offer transformation applied before requirements filtering.
        The returned callable may return `None` to drop an offer.
        E.g. used to set an appropriate disk size based on requirements.
        """
        return None

    def get_offers_post_filter(
        self, requirements: Requirements
    ) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]:
        """
        Optional predicate applied after requirements filtering, allowing
        backends to implement custom requirement-specific narrowing.
        """
        return None

    def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
        candidates = self._get_all_offers_with_availability_cached()
        modify = self.get_offers_modifier(requirements)
        if modify is not None:
            # Apply the modifier, discarding offers it maps to None.
            transformed = [modify(offer) for offer in candidates]
            candidates = [offer for offer in transformed if offer is not None]
        candidates = filter_offers_by_requirements(candidates, requirements)
        accept = self.get_offers_post_filter(requirements)
        if accept is None:
            return candidates
        return [offer for offer in candidates if accept(offer)]

    @cachedmethod(
        cache=lambda self: self._offers_cache,
        lock=lambda self: self._offers_cache_lock,
    )
    def _get_all_offers_with_availability_cached(self) -> List[InstanceOfferWithAvailability]:
        return self.get_all_offers_with_availability()


class ComputeWithFilteredOffersCached(ABC):
    """
    Provides common `get_offers()` implementation for backends
    whose offers depend on requirements.
    It caches offers using requirements as key.
    """

    # NOTE(review): the merged diff view interleaved deleted lines (the old
    # Optional[Requirements] handling and the old `get_offers_cached` name)
    # with their replacements, producing invalid code; this is the new-side
    # reconstruction, consistent with the non-optional `Requirements`
    # signatures used throughout the rest of the change.

    def __init__(self) -> None:
        super().__init__()
        self._offers_cache_lock = threading.Lock()
        # Up to 10 distinct requirements are cached for 180s each.
        self._offers_cache = TTLCache(maxsize=10, ttl=180)

    @abstractmethod
    def get_offers_by_requirements(
        self, requirements: Requirements
    ) -> List[InstanceOfferWithAvailability]:
        """
        Returns backend offers with availability matching requirements.
        """
        pass

    def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]:
        return self._get_offers_cached(requirements)

    def _get_offers_cached_key(self, requirements: Requirements) -> int:
        # Requirements is not hashable, so we use a hack to get arguments hash
        return hash(requirements.json())

    @cachedmethod(
        cache=lambda self: self._offers_cache,
        key=_get_offers_cached_key,
        lock=lambda self: self._offers_cache_lock,
    )
    def _get_offers_cached(
        self, requirements: Requirements
    ) -> List[InstanceOfferWithAvailability]:
        return self.get_offers_by_requirements(requirements)


class ComputeWithCreateInstanceSupport(ABC):
Expand Down
Loading
Loading