diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index b18825543c..0982e146a8 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -1,6 +1,6 @@ import threading from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import boto3 import botocore.client @@ -18,6 +18,7 @@ ) from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithGatewaySupport, ComputeWithMultinodeSupport, @@ -32,7 +33,7 @@ get_user_data, merge_tags, ) -from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier from dstack._internal.core.errors import ( ComputeError, NoCapacityError, @@ -87,6 +88,7 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs): class AWSCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, ComputeWithReservationSupport, @@ -109,6 +111,8 @@ def __init__(self, config: AWSConfig): # Caches to avoid redundant API calls when provisioning many instances # get_offers is already cached but we still cache its sub-functions # with more aggressive/longer caches. + self._offers_post_filter_cache_lock = threading.Lock() + self._offers_post_filter_cache = TTLCache(maxsize=10, ttl=180) self._get_regions_to_quotas_cache_lock = threading.Lock() self._get_regions_to_quotas_execution_lock = threading.Lock() self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300) @@ -125,43 +129,11 @@ def __init__(self, config: AWSConfig): self._get_image_id_and_username_cache_lock = threading.Lock() self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: - filter = _supported_instances - if requirements and requirements.reservation: - region_to_reservation = {} - for region in self.config.regions: - reservation = aws_resources.get_reservation( - ec2_client=self.session.client("ec2", region_name=region), - reservation_id=requirements.reservation, - instance_count=1, - ) - if reservation is not None: - region_to_reservation[region] = reservation - - def _supported_instances_with_reservation(offer: InstanceOffer) -> bool: - # Filter: only instance types supported by dstack - if not _supported_instances(offer): - return False - # Filter: Spot instances can't be used with reservations - if offer.instance.resources.spot: - return False - region = offer.region - reservation = region_to_reservation.get(region) - # Filter: only instance types matching the capacity reservation - if not bool(reservation and offer.instance.name == reservation["InstanceType"]): - return False - return True - - filter = _supported_instances_with_reservation - + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.AWS, locations=self.config.regions, - requirements=requirements, - configurable_disk_size=CONFIGURABLE_DISK_SIZE, - extra_filter=filter, + extra_filter=_supported_instances, ) regions = list(set(i.region for i in offers)) with self._get_regions_to_quotas_execution_lock: @@ -185,6 +157,49 @@ def _supported_instances_with_reservation(offer: InstanceOffer) -> bool: ) return availability_offers + def get_offers_modifier( + self, requirements: Requirements + ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]: + return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements) + + def _get_offers_cached_key(self, requirements: Requirements) -> int: + # Requirements is not hashable, so we use a hack to get arguments hash + return hash(requirements.json()) + + @cachedmethod( + cache=lambda self: self._offers_post_filter_cache, + key=_get_offers_cached_key, + lock=lambda self: self._offers_post_filter_cache_lock, + ) + def get_offers_post_filter( + self, requirements: Requirements + ) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]: + if requirements.reservation: + region_to_reservation = {} + for region in get_or_error(self.config.regions): + reservation = aws_resources.get_reservation( + ec2_client=self.session.client("ec2", region_name=region), + reservation_id=requirements.reservation, + instance_count=1, + ) + if reservation is not None: + region_to_reservation[region] = reservation + + def reservation_filter(offer: InstanceOfferWithAvailability) -> bool: + # Filter: Spot instances can't be used with reservations + if offer.instance.resources.spot: + return False + region = offer.region + reservation = region_to_reservation.get(region) + # Filter: only instance types matching the capacity reservation + if not bool(reservation and offer.instance.name == reservation["InstanceType"]): + return False + return True + + return reservation_filter + + return None + def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ) -> None: diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 6847e7912f..13f619be84 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -2,7 +2,7 @@ import enum import re from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple from azure.core.credentials import TokenCredential from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError @@ -39,6 +39,7 @@ from dstack._internal.core.backends.azure.models import AzureConfig from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithGatewaySupport, ComputeWithMultinodeSupport, @@ -48,7 +49,7 @@ get_user_data, merge_tags, ) -from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier from dstack._internal.core.errors import ComputeError, NoCapacityError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.gateways import ( @@ -73,6 +74,7 @@ class AzureCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, ComputeWithGatewaySupport, @@ -89,14 +91,10 @@ def __init__(self, config: AzureConfig, credential: TokenCredential): credential=credential, subscription_id=config.subscription_id ) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.AZURE, locations=self.config.regions, - requirements=requirements, - configurable_disk_size=CONFIGURABLE_DISK_SIZE, extra_filter=_supported_instances, ) offers_with_availability = _get_offers_with_availability( @@ -106,6 +104,11 @@ def get_offers( ) return offers_with_availability + def get_offers_modifier( + self, requirements: Requirements + ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]: + return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements) + def create_instance( self, instance_offer: InstanceOfferWithAvailability, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index d68ce78b7c..bba603f901 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -7,7 +7,7 @@ from collections.abc import Iterable from functools import lru_cache from pathlib import Path -from typing import Dict, List, Literal, Optional +from typing import Callable, Dict, List, Literal, Optional import git import requests @@ -15,6 +15,7 @@ from cachetools import TTLCache, cachedmethod from dstack._internal import settings +from dstack._internal.core.backends.base.offers import filter_offers_by_requirements from dstack._internal.core.consts import ( DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT, @@ -57,14 +58,8 @@ class Compute(ABC): If a compute supports additional features, it must also subclass `ComputeWith*` classes. """ - def __init__(self): - self._offers_cache_lock = threading.Lock() - self._offers_cache = TTLCache(maxsize=10, ttl=180) - @abstractmethod - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]: """ Returns offers with availability matching `requirements`. If the provider is added to gpuhunt, typically gets offers using `base.offers.get_catalog_offers()` @@ -121,10 +116,97 @@ def update_provisioning_data( """ pass - def _get_offers_cached_key(self, requirements: Optional[Requirements] = None) -> int: + +class ComputeWithAllOffersCached(ABC): + """ + Provides common `get_offers()` implementation for backends + whose offers do not depend on requirements. + It caches all offers with availability and post-filters by requirements. + """ + + def __init__(self) -> None: + super().__init__() + self._offers_cache_lock = threading.Lock() + self._offers_cache = TTLCache(maxsize=1, ttl=180) + + @abstractmethod + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: + """ + Returns all backend offers with availability. + """ + pass + + def get_offers_modifier( + self, requirements: Requirements + ) -> Optional[ + Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]] + ]: + """ + Returns a modifier function that modifies offers before they are filtered by requirements. + Can return `None` to exclude the offer. + E.g. can be used to set appropriate disk size based on requirements. + """ + return None + + def get_offers_post_filter( + self, requirements: Requirements + ) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]: + """ + Returns a filter function to apply to offers based on requirements. + This allows backends to implement custom post-filtering logic for specific requirements. + """ + return None + + def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]: + offers = self._get_all_offers_with_availability_cached() + modifier = self.get_offers_modifier(requirements) + if modifier is not None: + modified_offers = [] + for o in offers: + modified_offer = modifier(o) + if modified_offer is not None: + modified_offers.append(modified_offer) + offers = modified_offers + offers = filter_offers_by_requirements(offers, requirements) + post_filter = self.get_offers_post_filter(requirements) + if post_filter is not None: + offers = [o for o in offers if post_filter(o)] + return offers + + @cachedmethod( + cache=lambda self: self._offers_cache, + lock=lambda self: self._offers_cache_lock, + ) + def _get_all_offers_with_availability_cached(self) -> List[InstanceOfferWithAvailability]: + return self.get_all_offers_with_availability() + + +class ComputeWithFilteredOffersCached(ABC): + """ + Provides common `get_offers()` implementation for backends + whose offers depend on requirements. + It caches offers using requirements as key. + """ + + def __init__(self) -> None: + super().__init__() + self._offers_cache_lock = threading.Lock() + self._offers_cache = TTLCache(maxsize=10, ttl=180) + + @abstractmethod + def get_offers_by_requirements( + self, requirements: Requirements + ) -> List[InstanceOfferWithAvailability]: + """ + Returns backend offers with availability matching requirements. + """ + pass + + def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]: + return self._get_offers_cached(requirements) + + def _get_offers_cached_key(self, requirements: Requirements) -> int: # Requirements is not hashable, so we use a hack to get arguments hash - if requirements is None: - return hash(None) return hash(requirements.json()) @cachedmethod( @@ -132,10 +214,10 @@ def _get_offers_cached_key(self, requirements: Optional[Requirements] = None) -> key=_get_offers_cached_key, lock=lambda self: self._offers_cache_lock, ) - def get_offers_cached( - self, requirements: Optional[Requirements] = None + def _get_offers_cached( + self, requirements: Requirements ) -> List[InstanceOfferWithAvailability]: - return self.get_offers(requirements) + return self.get_offers_by_requirements(requirements) class ComputeWithCreateInstanceSupport(ABC): diff --git a/src/dstack/_internal/core/backends/base/offers.py b/src/dstack/_internal/core/backends/base/offers.py index d3d004172b..41367ac952 100644 --- a/src/dstack/_internal/core/backends/base/offers.py +++ b/src/dstack/_internal/core/backends/base/offers.py @@ -1,5 +1,5 @@ from dataclasses import asdict -from typing import Callable, List, Optional +from typing import Callable, List, Optional, TypeVar import gpuhunt from pydantic import parse_obj_as @@ -9,11 +9,13 @@ Disk, Gpu, InstanceOffer, + InstanceOfferWithAvailability, InstanceType, Resources, ) from dstack._internal.core.models.resources import DEFAULT_DISK, CPUSpec, Memory, Range from dstack._internal.core.models.runs import Requirements +from dstack._internal.utils.common import get_or_error # Offers not supported by all dstack versions are hidden behind one or more flags. # This list enables the flags that are currently supported. @@ -163,9 +165,13 @@ def requirements_to_query_filter(req: Optional[Requirements]) -> gpuhunt.QueryFi return q -def match_requirements( - offers: List[InstanceOffer], requirements: Optional[Requirements] -) -> List[InstanceOffer]: +InstanceOfferT = TypeVar("InstanceOfferT", InstanceOffer, InstanceOfferWithAvailability) + + +def filter_offers_by_requirements( + offers: List[InstanceOfferT], + requirements: Optional[Requirements], +) -> List[InstanceOfferT]: query_filter = requirements_to_query_filter(requirements) filtered_offers = [] for offer in offers: @@ -190,3 +196,27 @@ def choose_disk_size_mib( disk_size_gib = disk_size_range.min return round(disk_size_gib * 1024) + + +def get_offers_disk_modifier( + configurable_disk_size: Range[Memory], requirements: Requirements +) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]: + """ + Returns a func that modifies offers disk by setting min value that satisfies both + `configurable_disk_size` and `requirements`. + """ + + def modifier(offer: InstanceOfferWithAvailability) -> Optional[InstanceOfferWithAvailability]: + requirements_disk_range = DEFAULT_DISK.size + if requirements.resources.disk is not None: + requirements_disk_range = requirements.resources.disk.size + disk_size_range = requirements_disk_range.intersect(configurable_disk_size) + if disk_size_range is None: + return None + offer_copy = offer.copy(deep=True) + offer_copy.instance.resources.disk = Disk( + size_mib=get_or_error(disk_size_range.min) * 1024 + ) + return offer_copy + + return modifier diff --git a/src/dstack/_internal/core/backends/cloudrift/compute.py b/src/dstack/_internal/core/backends/cloudrift/compute.py index 03d9fd74c6..21b6016e76 100644 --- a/src/dstack/_internal/core/backends/cloudrift/compute.py +++ b/src/dstack/_internal/core/backends/cloudrift/compute.py @@ -1,7 +1,8 @@ from typing import Dict, List, Optional -from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + Compute, + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, get_shim_commands, ) @@ -17,13 +18,14 @@ InstanceOfferWithAvailability, ) from dstack._internal.core.models.placement import PlacementGroup -from dstack._internal.core.models.runs import JobProvisioningData, Requirements +from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) class CloudRiftCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, Compute, ): @@ -32,15 +34,11 @@ def __init__(self, config: CloudRiftConfig): self.config = config self.client = RiftClient(self.config.creds.api_key) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.CLOUDRIFT, locations=self.config.regions or None, - requirements=requirements, ) - offers_with_availabilities = self._get_offers_with_availability(offers) return offers_with_availabilities diff --git a/src/dstack/_internal/core/backends/cudo/compute.py b/src/dstack/_internal/core/backends/cudo/compute.py index 4da43b6b2a..23a8721fa6 100644 --- a/src/dstack/_internal/core/backends/cudo/compute.py +++ b/src/dstack/_internal/core/backends/cudo/compute.py @@ -5,6 +5,7 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( ComputeWithCreateInstanceSupport, + ComputeWithFilteredOffersCached, generate_unique_instance_name, get_shim_commands, ) @@ -29,6 +30,7 @@ class CudoCompute( + ComputeWithFilteredOffersCached, ComputeWithCreateInstanceSupport, Compute, ): @@ -37,8 +39,8 @@ def __init__(self, config: CudoConfig): self.config = config self.api_client = CudoApiClient(config.creds.api_key) - def get_offers( - self, requirements: Optional[Requirements] = None + def get_offers_by_requirements( + self, requirements: Requirements ) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.CUDO, diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 7410fe6742..afc4bf8511 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional from datacrunch import DataCrunchClient from datacrunch.exceptions import APIException @@ -6,11 +6,12 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, generate_unique_instance_name, get_shim_commands, ) -from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier from dstack._internal.core.backends.datacrunch.models import DataCrunchConfig from dstack._internal.core.errors import NoCapacityError from dstack._internal.core.models.backends.base import BackendType @@ -36,6 +37,7 @@ class DataCrunchCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, Compute, ): @@ -47,18 +49,19 @@ def __init__(self, config: DataCrunchConfig): client_secret=self.config.creds.client_secret, ) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.DATACRUNCH, locations=self.config.regions, - requirements=requirements, - configurable_disk_size=CONFIGURABLE_DISK_SIZE, ) offers_with_availability = self._get_offers_with_availability(offers) return offers_with_availability + def get_offers_modifier( + self, requirements: Requirements + ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]: + return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements) + def _get_offers_with_availability( self, offers: List[InstanceOffer] ) -> List[InstanceOfferWithAvailability]: diff --git a/src/dstack/_internal/core/backends/digitalocean_base/compute.py b/src/dstack/_internal/core/backends/digitalocean_base/compute.py index d8eb878ba1..cc338df053 100644 --- a/src/dstack/_internal/core/backends/digitalocean_base/compute.py +++ b/src/dstack/_internal/core/backends/digitalocean_base/compute.py @@ -5,6 +5,7 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, generate_unique_instance_name, get_user_data, @@ -20,7 +21,7 @@ InstanceOfferWithAvailability, ) from dstack._internal.core.models.placement import PlacementGroup -from dstack._internal.core.models.runs import JobProvisioningData, Requirements +from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -37,6 +38,7 @@ class BaseDigitalOceanCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, Compute, ): @@ -50,13 +52,10 @@ def __init__(self, config: BaseDigitalOceanConfig, api_url: str, type: BackendTy DigitalOceanProvider(api_key=config.creds.api_key, api_url=api_url) ) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=self.BACKEND_TYPE, locations=self.config.regions, - requirements=requirements, catalog=self.catalog, ) return [ diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 506308ef4a..820205360b 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -17,6 +17,7 @@ from dstack import version from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithGatewaySupport, ComputeWithMultinodeSupport, @@ -31,7 +32,10 @@ get_user_data, merge_tags, ) -from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.base.offers import ( + get_catalog_offers, + get_offers_disk_modifier, +) from dstack._internal.core.backends.gcp.features import tcpx as tcpx_features from dstack._internal.core.backends.gcp.models import GCPConfig from dstack._internal.core.errors import ( @@ -82,6 +86,7 @@ class GCPVolumeDiskBackendData(CoreModel): class GCPCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, ComputeWithPlacementGroupSupport, @@ -107,14 +112,10 @@ def __init__(self, config: GCPConfig): self._extra_subnets_cache_lock = threading.Lock() self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: regions = get_or_error(self.config.regions) offers = get_catalog_offers( backend=BackendType.GCP, - requirements=requirements, - configurable_disk_size=CONFIGURABLE_DISK_SIZE, extra_filter=_supported_instances_and_zones(regions), ) quotas: Dict[str, Dict[str, float]] = defaultdict(dict) @@ -142,9 +143,13 @@ def get_offers( offer_keys_to_offers[key] = offer_with_availability offers_with_availability.append(offer_with_availability) offers_with_availability[-1].region = region - return offers_with_availability + def get_offers_modifier( + self, requirements: Requirements + ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]: + return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements) + def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ) -> None: diff --git a/src/dstack/_internal/core/backends/hotaisle/compute.py b/src/dstack/_internal/core/backends/hotaisle/compute.py index 8aa83b88ca..47e7526d3d 100644 --- a/src/dstack/_internal/core/backends/hotaisle/compute.py +++ b/src/dstack/_internal/core/backends/hotaisle/compute.py @@ -9,6 +9,7 @@ from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, get_shim_commands, ) @@ -23,7 +24,7 @@ InstanceOfferWithAvailability, ) from dstack._internal.core.models.placement import PlacementGroup -from dstack._internal.core.models.runs import JobProvisioningData, Requirements +from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -44,6 +45,7 @@ class HotAisleCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, Compute, ): @@ -56,16 +58,12 @@ def __init__(self, config: HotAisleConfig): HotAisleProvider(api_key=config.creds.api_key, team_handle=config.team_handle) ) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.HOTAISLE, locations=self.config.regions or None, - requirements=requirements, catalog=self.catalog, ) - supported_offers = [] for offer in offers: if offer.instance.name in INSTANCE_TYPE_SPECS: @@ -78,7 +76,6 @@ def get_offers( logger.warning( f"Skipping unsupported Hot Aisle instance type: {offer.instance.name}" ) - return supported_offers def get_payload_from_offer(self, instance_type) -> dict: diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index b5213c74d3..8307c7672c 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -9,13 +9,14 @@ from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithFilteredOffersCached, ComputeWithGatewaySupport, generate_unique_gateway_instance_name, generate_unique_instance_name_for_job, get_docker_commands, get_dstack_gateway_commands, ) -from dstack._internal.core.backends.base.offers import match_requirements +from dstack._internal.core.backends.base.offers import filter_offers_by_requirements from dstack._internal.core.backends.kubernetes.models import ( KubernetesConfig, KubernetesNetworkingConfig, @@ -58,6 +59,7 @@ class KubernetesCompute( + ComputeWithFilteredOffersCached, ComputeWithGatewaySupport, Compute, ): @@ -70,8 +72,8 @@ def __init__(self, config: KubernetesConfig): self.networking_config = networking_config self.api = get_api_from_config_data(config.kubeconfig.data) - def get_offers( - self, requirements: Optional[Requirements] = None + def get_offers_by_requirements( + self, requirements: Requirements ) -> List[InstanceOfferWithAvailability]: nodes = self.api.list_node() instance_offers = [] @@ -99,7 +101,7 @@ def get_offers( availability=InstanceAvailability.AVAILABLE, instance_runtime=InstanceRuntime.RUNNER, ) - instance_offers.extend(match_requirements([instance_offer], requirements)) + instance_offers.extend(filter_offers_by_requirements([instance_offer], requirements)) return instance_offers def run_job( diff --git a/src/dstack/_internal/core/backends/lambdalabs/compute.py b/src/dstack/_internal/core/backends/lambdalabs/compute.py index aead3e1eb0..d460300725 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/compute.py +++ b/src/dstack/_internal/core/backends/lambdalabs/compute.py @@ -7,6 +7,7 @@ from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, generate_unique_instance_name, get_shim_commands, @@ -22,12 +23,13 @@ InstanceOfferWithAvailability, ) from dstack._internal.core.models.placement import PlacementGroup -from dstack._internal.core.models.runs import JobProvisioningData, Requirements +from dstack._internal.core.models.runs import JobProvisioningData MAX_INSTANCE_NAME_LEN = 60 class LambdaCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, Compute, ): @@ -36,13 +38,10 @@ def __init__(self, config: LambdaConfig): self.config = config self.api_client = LambdaAPIClient(config.creds.api_key) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.LAMBDA, locations=self.config.regions or None, - requirements=requirements, ) offers_with_availability = self._get_offers_with_availability(offers) return offers_with_availability diff --git a/src/dstack/_internal/core/backends/local/compute.py b/src/dstack/_internal/core/backends/local/compute.py index 7f9e257f35..125d74b4c7 100644 --- a/src/dstack/_internal/core/backends/local/compute.py +++ b/src/dstack/_internal/core/backends/local/compute.py @@ -28,9 +28,7 @@ class LocalCompute( ComputeWithVolumeSupport, Compute, ): - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]: return [ InstanceOfferWithAvailability( backend=BackendType.LOCAL, diff --git a/src/dstack/_internal/core/backends/nebius/compute.py b/src/dstack/_internal/core/backends/nebius/compute.py index 36131f5972..9e6b399a4b 100644 --- a/src/dstack/_internal/core/backends/nebius/compute.py +++ b/src/dstack/_internal/core/backends/nebius/compute.py @@ -3,7 +3,7 @@ import shlex import time from functools import cached_property -from typing import List, Optional +from typing import Callable, List, Optional from nebius.aio.operation import Operation as SDKOperation from nebius.aio.service_error import RequestError, StatusCode @@ -12,13 +12,14 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, ComputeWithPlacementGroupSupport, generate_unique_instance_name, get_user_data, ) -from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier from dstack._internal.core.backends.nebius import resources from dstack._internal.core.backends.nebius.fabrics import get_suitable_infiniband_fabrics from dstack._internal.core.backends.nebius.models import NebiusConfig, NebiusServiceAccountCreds @@ -76,6 +77,7 @@ class NebiusCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, ComputeWithPlacementGroupSupport, @@ -106,15 +108,11 @@ def _get_subnet_id(self, region: str) -> str: ).metadata.id return self._subnet_id_cache[region] - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.NEBIUS, locations=list(self._region_to_project_id), - requirements=requirements, extra_filter=_supported_instances, - configurable_disk_size=CONFIGURABLE_DISK_SIZE, ) return [ InstanceOfferWithAvailability( @@ -124,6 +122,11 @@ def get_offers( for offer in offers ] + def get_offers_modifier( + self, requirements: Requirements + ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]: + return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements) + def create_instance( self, instance_offer: InstanceOfferWithAvailability, diff --git a/src/dstack/_internal/core/backends/oci/compute.py b/src/dstack/_internal/core/backends/oci/compute.py index 00c097bc59..eaf87603be 100644 --- a/src/dstack/_internal/core/backends/oci/compute.py +++ b/src/dstack/_internal/core/backends/oci/compute.py @@ -1,17 +1,18 @@ from concurrent.futures import ThreadPoolExecutor from functools import cached_property -from typing import List, Optional +from typing import Callable, List, Optional import oci from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, generate_unique_instance_name, get_user_data, ) -from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier from dstack._internal.core.backends.oci import resources from dstack._internal.core.backends.oci.models import OCIConfig from dstack._internal.core.backends.oci.region import make_region_clients_map @@ -47,6 +48,7 @@ class OCICompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, Compute, @@ -60,14 +62,10 @@ def __init__(self, config: OCIConfig): def shapes_quota(self) -> resources.ShapesQuota: return resources.ShapesQuota.load(self.regions, self.config.compartment_id) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.OCI, locations=self.config.regions, - requirements=requirements, - configurable_disk_size=CONFIGURABLE_DISK_SIZE, extra_filter=_supported_instances, ) @@ -96,6 +94,11 @@ def get_offers( return offers_with_availability + def get_offers_modifier( + self, requirements: Requirements + ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]: + return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements) + def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ) -> None: diff --git a/src/dstack/_internal/core/backends/runpod/compute.py b/src/dstack/_internal/core/backends/runpod/compute.py index eb52b4eec8..9b7fa6e652 100644 --- a/src/dstack/_internal/core/backends/runpod/compute.py +++ b/src/dstack/_internal/core/backends/runpod/compute.py @@ -1,17 +1,18 @@ import json import uuid from datetime import timedelta -from typing import List, Optional +from typing import Callable, List, Optional from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithAllOffersCached, ComputeWithVolumeSupport, generate_unique_instance_name, generate_unique_volume_name, get_docker_commands, get_job_instance_name, ) -from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier from dstack._internal.core.backends.runpod.api_client import RunpodApiClient from dstack._internal.core.backends.runpod.models import RunpodConfig from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT @@ -27,6 +28,7 @@ InstanceOfferWithAvailability, SSHKey, ) +from dstack._internal.core.models.resources import Memory, Range from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume, VolumeProvisioningData from dstack._internal.utils.common import get_current_datetime @@ -39,8 +41,12 @@ CONTAINER_REGISTRY_AUTH_CLEANUP_INTERVAL = 60 * 60 * 24 # 24 hour +# RunPod does not seem to have any limits on the disk size. +CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("1GB"), max=None) + class RunpodCompute( + ComputeWithAllOffersCached, ComputeWithVolumeSupport, Compute, ): @@ -51,13 +57,11 @@ def __init__(self, config: RunpodConfig): self.config = config self.api_client = RunpodApiClient(config.creds.api_key) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.RUNPOD, locations=self.config.regions or None, - requirements=requirements, + requirements=None, extra_filter=lambda o: _is_secure_cloud(o.region) or self.config.allow_community_cloud, ) offers = [ @@ -68,6 +72,11 @@ def get_offers( ] return offers + def get_offers_modifier( + self, requirements: Requirements + ) -> Callable[[InstanceOfferWithAvailability], Optional[InstanceOfferWithAvailability]]: + return get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements) + def run_job( self, run: Run, diff --git a/src/dstack/_internal/core/backends/template/compute.py.jinja b/src/dstack/_internal/core/backends/template/compute.py.jinja index 51ffbfdd53..8eb95e32d4 100644 --- a/src/dstack/_internal/core/backends/template/compute.py.jinja +++ b/src/dstack/_internal/core/backends/template/compute.py.jinja @@ -2,6 +2,7 @@ from typing import List, Optional from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithGatewaySupport, ComputeWithMultinodeSupport, @@ -28,6 +29,7 @@ logger = get_logger(__name__) class {{ backend_name }}Compute( # TODO: Choose ComputeWith* classes to extend and implement + # ComputeWithAllOffersCached, # ComputeWithCreateInstanceSupport, # ComputeWithMultinodeSupport, # ComputeWithReservationSupport, @@ -42,7 +44,7 @@ class {{ backend_name }}Compute( self.config = config def get_offers( - self, requirements: Optional[Requirements] = None + self, requirements: Requirements ) -> List[InstanceOfferWithAvailability]: # If the provider is added to gpuhunt, you'd typically get offers # using `get_catalog_offers()` and extend them with availability info. diff --git a/src/dstack/_internal/core/backends/tensordock/compute.py b/src/dstack/_internal/core/backends/tensordock/compute.py new file mode 100644 index 0000000000..44daa1e7e3 --- /dev/null +++ b/src/dstack/_internal/core/backends/tensordock/compute.py @@ -0,0 +1,120 @@ +import json +from typing import List, Optional + +import requests + +from dstack._internal.core.backends.base.backend import Compute +from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, + generate_unique_instance_name, + get_shim_commands, +) +from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.tensordock.api_client import TensorDockAPIClient +from dstack._internal.core.backends.tensordock.models import TensorDockConfig +from dstack._internal.core.errors import NoCapacityError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceConfiguration, + InstanceOfferWithAvailability, +) +from dstack._internal.core.models.placement import PlacementGroup +from dstack._internal.core.models.runs import JobProvisioningData, Requirements +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +# Undocumented but names of len 60 work +MAX_INSTANCE_NAME_LEN = 60 + + +class TensorDockCompute( + ComputeWithCreateInstanceSupport, + Compute, +): + def __init__(self, config: TensorDockConfig): + super().__init__() + self.config = config + self.api_client = TensorDockAPIClient(config.creds.api_key, config.creds.api_token) + + def get_offers(self, requirements: Requirements) -> List[InstanceOfferWithAvailability]: + offers = get_catalog_offers( + backend=BackendType.TENSORDOCK, + requirements=requirements, + ) + offers = [ + InstanceOfferWithAvailability( + **offer.dict(), availability=InstanceAvailability.AVAILABLE + ) + for offer in offers + ] + return offers + + def create_instance( + self, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + placement_group: Optional[PlacementGroup], + ) -> JobProvisioningData: + instance_name = generate_unique_instance_name( + instance_config, max_length=MAX_INSTANCE_NAME_LEN + ) + commands = get_shim_commands(authorized_keys=instance_config.get_public_keys()) + try: + resp = self.api_client.deploy_single( + instance_name=instance_name, + instance=instance_offer.instance, + cloudinit={ + "ssh_pwauth": False, # disable password auth + "users": [ + "default", + { + "name": "user", + "ssh_authorized_keys": instance_config.get_public_keys(), + }, + ], + "runcmd": [ + ["sh", "-c", " && ".join(commands)], + ], + "write_files": [ + { + "path": "/etc/docker/daemon.json", + "content": json.dumps( + { + "runtimes": { + "nvidia": { + "path": "nvidia-container-runtime", + "runtimeArgs": [], + } + }, + "exec-opts": ["native.cgroupdriver=cgroupfs"], + } + ), + } + ], + }, + ) + except requests.HTTPError as e: + logger.warning("Got error from tensordock: %s", e) + raise NoCapacityError() + return JobProvisioningData( + backend=instance_offer.backend, + instance_type=instance_offer.instance, + instance_id=resp["server"], + hostname=resp["ip"], + internal_ip=None, + region=instance_offer.region, + price=instance_offer.price, + username="user", + ssh_port={int(v): int(k) for k, v in resp["port_forwards"].items()}[22], + dockerized=True, + ssh_proxy=None, + backend_data=None, + ) + + def terminate_instance( + self, instance_id: str, region: str, backend_data: Optional[str] = None + ): + self.api_client.delete_single_if_exists(instance_id) diff --git a/src/dstack/_internal/core/backends/vastai/compute.py b/src/dstack/_internal/core/backends/vastai/compute.py index e18f8e1313..86391cc093 100644 --- a/src/dstack/_internal/core/backends/vastai/compute.py +++ b/src/dstack/_internal/core/backends/vastai/compute.py @@ -5,6 +5,7 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithFilteredOffersCached, generate_unique_instance_name_for_job, get_docker_commands, ) @@ -30,7 +31,10 @@ MAX_INSTANCE_NAME_LEN = 60 -class VastAICompute(Compute): +class VastAICompute( + ComputeWithFilteredOffersCached, + Compute, +): def __init__(self, config: VastAIConfig): super().__init__() self.config = config @@ -49,8 +53,8 @@ def __init__(self, config: VastAIConfig): ) ) - def get_offers( - self, requirements: Optional[Requirements] = None + def get_offers_by_requirements( + self, requirements: Requirements ) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.VASTAI, diff --git a/src/dstack/_internal/core/backends/vultr/compute.py b/src/dstack/_internal/core/backends/vultr/compute.py index a6b102b71f..016d0a8c50 100644 --- a/src/dstack/_internal/core/backends/vultr/compute.py +++ b/src/dstack/_internal/core/backends/vultr/compute.py @@ -6,6 +6,7 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, generate_unique_instance_name, @@ -23,7 +24,7 @@ InstanceOfferWithAvailability, ) from dstack._internal.core.models.placement import PlacementGroup -from dstack._internal.core.models.runs import JobProvisioningData, Requirements +from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -32,6 +33,7 @@ class VultrCompute( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, Compute, @@ -41,12 +43,10 @@ def __init__(self, config: VultrConfig): self.config = config self.api_client = VultrApiClient(config.creds.api_key) - def get_offers( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.VULTR, - requirements=requirements, + requirements=None, locations=self.config.regions or None, extra_filter=_supported_instances, ) diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 7613d75550..38350d9ca7 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -345,7 +345,7 @@ async def get_instance_offers( Returns list of instances satisfying minimal resource requirements sorted by price """ logger.info("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends]) - tasks = [run_async(backend.compute().get_offers_cached, requirements) for backend in backends] + tasks = [run_async(backend.compute().get_offers, requirements) for backend in backends] offers_by_backend = [] for backend, result in zip(backends, await asyncio.gather(*tasks, return_exceptions=True)): if isinstance(result, BackendError): diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/tasks/test_process_instances.py index 8255007073..c1983fbed6 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/tasks/test_process_instances.py @@ -729,7 +729,7 @@ async def test_creates_instance( availability=InstanceAvailability.AVAILABLE, ) backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.get_offers_cached.return_value = [offer] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.create_instance.return_value = JobProvisioningData( backend=offer.backend, instance_type=offer.instance, @@ -762,13 +762,13 @@ async def test_tries_second_offer_if_first_fails(self, session: AsyncSession, er aws_mock.TYPE = BackendType.AWS offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) - aws_mock.compute.return_value.get_offers_cached.return_value = [offer] + aws_mock.compute.return_value.get_offers.return_value = [offer] aws_mock.compute.return_value.create_instance.side_effect = err gcp_mock = Mock() gcp_mock.TYPE = BackendType.GCP offer = get_instance_offer_with_availability(backend=BackendType.GCP, price=2.0) gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) - gcp_mock.compute.return_value.get_offers_cached.return_value = [offer] + gcp_mock.compute.return_value.get_offers.return_value = [offer] gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( backend=offer.backend, region=offer.region, price=offer.price ) @@ -791,7 +791,7 @@ async def test_fails_if_all_offers_fail(self, session: AsyncSession, err: Except aws_mock.TYPE = BackendType.AWS offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) - aws_mock.compute.return_value.get_offers_cached.return_value = [offer] + aws_mock.compute.return_value.get_offers.return_value = [offer] aws_mock.compute.return_value.create_instance.side_effect = err with patch("dstack._internal.server.services.backends.get_project_backends") as m: m.return_value = [aws_mock] @@ -903,7 +903,7 @@ async def test_create_placement_group_if_placement_cluster( backend_mock = Mock() backend_mock.TYPE = BackendType.AWS backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.get_offers_cached.return_value = [ + backend_mock.compute.return_value.get_offers.return_value = [ get_instance_offer_with_availability() ] backend_mock.compute.return_value.create_instance.return_value = ( @@ -951,7 +951,7 @@ async def test_reuses_placement_group_between_offers_if_the_group_is_suitable( backend_mock = Mock() backend_mock.TYPE = BackendType.AWS backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.get_offers_cached.return_value = [ + backend_mock.compute.return_value.get_offers.return_value = [ get_instance_offer_with_availability(instance_type="bad-offer-1"), get_instance_offer_with_availability(instance_type="bad-offer-2"), get_instance_offer_with_availability(instance_type="good-offer"), @@ -1010,7 +1010,7 @@ async def test_handles_create_placement_group_errors( backend_mock = Mock() backend_mock.TYPE = BackendType.AWS backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.get_offers_cached.return_value = [ + backend_mock.compute.return_value.get_offers.return_value = [ get_instance_offer_with_availability(instance_type="bad-offer"), get_instance_offer_with_availability(instance_type="good-offer"), ] diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 109dd4f2e8..868bfb6355 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -125,11 +125,11 @@ async def test_provisions_job( backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = backend - backend_mock.compute.return_value.get_offers_cached.return_value = [offer] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data() await process_submitted_jobs() m.assert_called_once() - backend_mock.compute.return_value.get_offers_cached.assert_called_once() + backend_mock.compute.return_value.get_offers.assert_called_once() backend_mock.compute.return_value.run_job.assert_called_once() await session.refresh(job) @@ -172,13 +172,13 @@ async def test_fails_job_when_privileged_true_and_no_offers_with_create_instance backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = BackendType.RUNPOD - backend_mock.compute.return_value.get_offers_cached.return_value = [offer] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data() with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: datetime_mock.return_value = datetime(2023, 1, 2, 3, 30, 0, tzinfo=timezone.utc) await process_submitted_jobs() m.assert_called_once() - backend_mock.compute.return_value.get_offers_cached.assert_not_called() + backend_mock.compute.return_value.get_offers.assert_not_called() backend_mock.compute.return_value.run_job.assert_not_called() await session.refresh(job) @@ -222,13 +222,13 @@ async def test_fails_job_when_instance_mounts_and_no_offers_with_create_instance backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = BackendType.RUNPOD - backend_mock.compute.return_value.get_offers_cached.return_value = [offer] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data() with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: datetime_mock.return_value = datetime(2023, 1, 2, 3, 30, 0, tzinfo=timezone.utc) await process_submitted_jobs() m.assert_called_once() - backend_mock.compute.return_value.get_offers_cached.assert_not_called() + backend_mock.compute.return_value.get_offers.assert_not_called() backend_mock.compute.return_value.run_job.assert_not_called() await session.refresh(job) @@ -274,7 +274,7 @@ async def test_provisions_job_with_optional_instance_volume_not_attached( backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = BackendType.RUNPOD - backend_mock.compute.return_value.get_offers_cached.return_value = [offer] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data() await process_submitted_jobs() @@ -693,11 +693,11 @@ async def test_creates_new_instance_in_existing_non_empty_fleet( backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = BackendType.AWS - backend_mock.compute.return_value.get_offers_cached.return_value = [offer] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data() await process_submitted_jobs() m.assert_called_once() - backend_mock.compute.return_value.get_offers_cached.assert_called_once() + backend_mock.compute.return_value.get_offers.assert_called_once() backend_mock.compute.return_value.run_job.assert_called_once() await session.refresh(job) @@ -884,11 +884,11 @@ async def test_creates_new_instance_in_existing_empty_fleet( backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = BackendType.AWS - backend_mock.compute.return_value.get_offers_cached.return_value = [offer] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data() await process_submitted_jobs() m.assert_called_once() - backend_mock.compute.return_value.get_offers_cached.assert_called_once() + backend_mock.compute.return_value.get_offers.assert_called_once() backend_mock.compute.return_value.run_job.assert_called_once() await session.refresh(job) diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 33fc73e019..934f333b63 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1065,13 +1065,13 @@ async def test_returns_create_plan_for_new_fleet( backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = BackendType.AWS - backend_mock.compute.return_value.get_offers_cached.return_value = offers + backend_mock.compute.return_value.get_offers.return_value = offers response = await client.post( f"/api/project/{project.name}/fleets/get_plan", headers=get_auth_headers(user.token), json={"spec": spec.dict()}, ) - backend_mock.compute.return_value.get_offers_cached.assert_called_once() + backend_mock.compute.return_value.get_offers.assert_called_once() assert response.status_code == 200 assert response.json() == { diff --git a/src/tests/_internal/server/routers/test_gpus.py b/src/tests/_internal/server/routers/test_gpus.py index 8116e2ceba..d07a92bb2f 100644 --- a/src/tests/_internal/server/routers/test_gpus.py +++ b/src/tests/_internal/server/routers/test_gpus.py @@ -84,7 +84,7 @@ def create_mock_backends_with_offers( for backend_type, offers in offers_by_backend.items(): backend_mock = Mock() backend_mock.TYPE = backend_type - backend_mock.compute.return_value.get_offers_cached.return_value = offers + backend_mock.compute.return_value.get_offers.return_value = offers mocked_backends.append(backend_mock) return mocked_backends @@ -161,7 +161,7 @@ async def test_returns_empty_gpus_when_no_offers( with patch("dstack._internal.server.services.backends.get_project_backends") as m: backend_mock_aws = Mock() backend_mock_aws.TYPE = BackendType.AWS - backend_mock_aws.compute.return_value.get_offers_cached.return_value = [] + backend_mock_aws.compute.return_value.get_offers.return_value = [] m.return_value = [backend_mock_aws] response = await client.post( @@ -310,7 +310,7 @@ async def test_exact_aggregation_values( with patch("dstack._internal.server.services.backends.get_project_backends") as m: backend_mock_aws = Mock() backend_mock_aws.TYPE = BackendType.AWS - backend_mock_aws.compute.return_value.get_offers_cached.return_value = [ + backend_mock_aws.compute.return_value.get_offers.return_value = [ offer_t4_spot, offer_t4_ondemand, offer_t4_quota, @@ -319,7 +319,7 @@ async def test_exact_aggregation_values( backend_mock_runpod = Mock() backend_mock_runpod.TYPE = BackendType.RUNPOD - backend_mock_runpod.compute.return_value.get_offers_cached.return_value = [ + backend_mock_runpod.compute.return_value.get_offers.return_value = [ offer_runpod_rtx_east, offer_runpod_rtx_eu, offer_runpod_t4_east, diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index efd571ef10..b087be8a99 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -997,12 +997,10 @@ async def test_returns_run_plan_privileged_false( with patch("dstack._internal.server.services.backends.get_project_backends") as m: backend_mock_aws = Mock() backend_mock_aws.TYPE = BackendType.AWS - backend_mock_aws.compute.return_value.get_offers_cached.return_value = [offer_aws] + backend_mock_aws.compute.return_value.get_offers.return_value = [offer_aws] backend_mock_runpod = Mock() backend_mock_runpod.TYPE = BackendType.RUNPOD - backend_mock_runpod.compute.return_value.get_offers_cached.return_value = [ - offer_runpod - ] + backend_mock_runpod.compute.return_value.get_offers.return_value = [offer_runpod] m.return_value = [backend_mock_aws, backend_mock_runpod] response = await client.post( f"/api/project/{project.name}/runs/get_plan", @@ -1059,12 +1057,10 @@ async def test_returns_run_plan_privileged_true( with patch("dstack._internal.server.services.backends.get_project_backends") as m: backend_mock_aws = Mock() backend_mock_aws.TYPE = BackendType.AWS - backend_mock_aws.compute.return_value.get_offers_cached.return_value = [offer_aws] + backend_mock_aws.compute.return_value.get_offers.return_value = [offer_aws] backend_mock_runpod = Mock() backend_mock_runpod.TYPE = BackendType.RUNPOD - backend_mock_runpod.compute.return_value.get_offers_cached.return_value = [ - offer_runpod - ] + backend_mock_runpod.compute.return_value.get_offers.return_value = [offer_runpod] m.return_value = [backend_mock_aws, backend_mock_runpod] response = await client.post( f"/api/project/{project.name}/runs/get_plan", @@ -1121,12 +1117,10 @@ async def test_returns_run_plan_docker_true( with patch("dstack._internal.server.services.backends.get_project_backends") as m: backend_mock_aws = Mock() backend_mock_aws.TYPE = BackendType.AWS - backend_mock_aws.compute.return_value.get_offers_cached.return_value = [offer_aws] + backend_mock_aws.compute.return_value.get_offers.return_value = [offer_aws] backend_mock_runpod = Mock() backend_mock_runpod.TYPE = BackendType.RUNPOD - backend_mock_runpod.compute.return_value.get_offers_cached.return_value = [ - offer_runpod - ] + backend_mock_runpod.compute.return_value.get_offers.return_value = [offer_runpod] m.return_value = [backend_mock_aws, backend_mock_runpod] response = await client.post( f"/api/project/{project.name}/runs/get_plan", @@ -1183,12 +1177,10 @@ async def test_returns_run_plan_instance_volumes( with patch("dstack._internal.server.services.backends.get_project_backends") as m: backend_mock_aws = Mock() backend_mock_aws.TYPE = BackendType.AWS - backend_mock_aws.compute.return_value.get_offers_cached.return_value = [offer_aws] + backend_mock_aws.compute.return_value.get_offers.return_value = [offer_aws] backend_mock_runpod = Mock() backend_mock_runpod.TYPE = BackendType.RUNPOD - backend_mock_runpod.compute.return_value.get_offers_cached.return_value = [ - offer_runpod - ] + backend_mock_runpod.compute.return_value.get_offers.return_value = [offer_runpod] m.return_value = [backend_mock_aws, backend_mock_runpod] response = await client.post( f"/api/project/{project.name}/runs/get_plan", diff --git a/src/tests/_internal/server/services/test_offers.py b/src/tests/_internal/server/services/test_offers.py index 8c97a0e4fd..3e67bc7c3f 100644 --- a/src/tests/_internal/server/services/test_offers.py +++ b/src/tests/_internal/server/services/test_offers.py @@ -23,13 +23,11 @@ async def test_returns_all_offers(self): aws_backend_mock = Mock() aws_backend_mock.TYPE = BackendType.AWS aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS) - aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer] + aws_backend_mock.compute.return_value.get_offers.return_value = [aws_offer] runpod_backend_mock = Mock() runpod_backend_mock.TYPE = BackendType.RUNPOD runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD) - runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [ - runpod_offer - ] + runpod_backend_mock.compute.return_value.get_offers.return_value = [runpod_offer] m.return_value = [aws_backend_mock, runpod_backend_mock] res = await get_offers_by_requirements( project=Mock(), @@ -47,13 +45,11 @@ async def test_returns_multinode_offers(self): aws_backend_mock = Mock() aws_backend_mock.TYPE = BackendType.AWS aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS) - aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer] + aws_backend_mock.compute.return_value.get_offers.return_value = [aws_offer] runpod_backend_mock = Mock() runpod_backend_mock.TYPE = BackendType.RUNPOD runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD) - runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [ - runpod_offer - ] + runpod_backend_mock.compute.return_value.get_offers.return_value = [runpod_offer] m.return_value = [aws_backend_mock, runpod_backend_mock] res = await get_offers_by_requirements( project=Mock(), @@ -72,7 +68,7 @@ async def test_returns_volume_offers(self): aws_backend_mock = Mock() aws_backend_mock.TYPE = BackendType.AWS aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS) - aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer] + aws_backend_mock.compute.return_value.get_offers.return_value = [aws_offer] runpod_backend_mock = Mock() runpod_backend_mock.TYPE = BackendType.RUNPOD runpod_offer1 = get_instance_offer_with_availability( @@ -81,7 +77,7 @@ async def test_returns_volume_offers(self): runpod_offer2 = get_instance_offer_with_availability( backend=BackendType.RUNPOD, region="us" ) - runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [ + runpod_backend_mock.compute.return_value.get_offers.return_value = [ runpod_offer1, runpod_offer2, ] @@ -124,7 +120,7 @@ async def test_returns_az_offers(self): aws_offer4 = get_instance_offer_with_availability( backend=BackendType.AWS, availability_zones=None ) - aws_backend_mock.compute.return_value.get_offers_cached.return_value = [ + aws_backend_mock.compute.return_value.get_offers.return_value = [ aws_offer1, aws_offer2, aws_offer3, @@ -148,13 +144,11 @@ async def test_returns_no_offers_for_multinode_instance_mounts_and_non_multinode aws_backend_mock = Mock() aws_backend_mock.TYPE = BackendType.AWS aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS) - aws_backend_mock.compute.return_value.get_offers_cached.return_value = [aws_offer] + aws_backend_mock.compute.return_value.get_offers.return_value = [aws_offer] runpod_backend_mock = Mock() runpod_backend_mock.TYPE = BackendType.RUNPOD runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD) - runpod_backend_mock.compute.return_value.get_offers_cached.return_value = [ - runpod_offer - ] + runpod_backend_mock.compute.return_value.get_offers.return_value = [runpod_offer] m.return_value = [aws_backend_mock, runpod_backend_mock] res = await get_offers_by_requirements( project=Mock(),