diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 79da7e41cb..e63a67557e 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -397,7 +397,7 @@ class ProfileProps(CoreModel): Field( description="The name of the profile that can be passed as `--profile` to `dstack apply`" ), - ] + ] = "" default: Annotated[ bool, Field(description="If set to true, `dstack apply` will use this profile by default.") ] = False diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index c85715f0e7..a51863a8a0 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -53,6 +53,7 @@ from dstack._internal.server.services.backends import get_project_backend_by_type_or_error from dstack._internal.server.services.fleets import ( fleet_model_to_fleet, + get_fleet_requirements, ) from dstack._internal.server.services.instances import ( filter_pool_instances, @@ -71,6 +72,10 @@ from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.offers import get_offers_by_requirements +from dstack._internal.server.services.requirements.combine import ( + combine_fleet_and_run_profiles, + combine_fleet_and_run_requirements, +) from dstack._internal.server.services.runs import ( check_run_spec_requires_instance_mounts, run_model_to_run, @@ -646,6 +651,8 @@ async def _run_job_on_new_instance( ) -> Optional[Tuple[JobProvisioningData, InstanceOfferWithAvailability]]: if volumes is None: volumes = [] + profile = run.run_spec.merged_profile + requirements = job.job_spec.requirements fleet = None if fleet_model is not None: fleet = fleet_model_to_fleet(fleet_model) @@ -654,13 +661,26 @@ async def _run_job_on_new_instance( "%s: cannot fit new instance into fleet %s", fmt(job_model), fleet_model.name ) return None + profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, profile) + if profile is None: + logger.debug("%s: cannot combine fleet %s profile", fmt(job_model), fleet_model.name) + return None + fleet_requirements = get_fleet_requirements(fleet.spec) + requirements = combine_fleet_and_run_requirements(fleet_requirements, requirements) + if requirements is None: + logger.debug( + "%s: cannot combine fleet %s requirements", fmt(job_model), fleet_model.name + ) + return None + # TODO: Respect fleet provisioning properties such as tags + multinode = job.job_spec.jobs_per_replica > 1 or ( fleet is not None and fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER ) offers = await get_offers_by_requirements( project=project, - profile=run.run_spec.merged_profile, - requirements=job.job_spec.requirements, + profile=profile, + requirements=requirements, exclude_not_available=True, multinode=multinode, master_job_provisioning_data=master_job_provisioning_data, diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 4f2b64cc5a..33257d3a9c 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -279,7 +279,7 @@ async def get_plan( offers_with_backends = await get_create_instance_offers( project=project, profile=effective_spec.merged_profile, - requirements=_get_fleet_requirements(effective_spec), + requirements=get_fleet_requirements(effective_spec), fleet_spec=effective_spec, blocks=effective_spec.configuration.blocks, ) @@ -458,7 +458,7 @@ async def create_fleet_instance_model( instance_num: int, ) -> InstanceModel: profile = spec.merged_profile - requirements = _get_fleet_requirements(spec) + requirements = get_fleet_requirements(spec) instance_model = await instances_services.create_instance_model( session=session, project=project, @@ -644,6 +644,17 @@ def is_fleet_empty(fleet_model: FleetModel) -> bool: return len(active_instances) == 0 +def get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements: + profile = fleet_spec.merged_profile + requirements = Requirements( + resources=fleet_spec.configuration.resources or ResourcesSpec(), + max_price=profile.max_price, + spot=get_policy_map(profile.spot_policy, default=SpotPolicy.ONDEMAND), + reservation=fleet_spec.configuration.reservation, + ) + return requirements + + async def _create_fleet( session: AsyncSession, project: ProjectModel, @@ -1004,17 +1015,6 @@ def _terminate_fleet_instances(fleet_model: FleetModel, instance_nums: Optional[ instance.status = InstanceStatus.TERMINATING -def _get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements: - profile = fleet_spec.merged_profile - requirements = Requirements( - resources=fleet_spec.configuration.resources or ResourcesSpec(), - max_price=profile.max_price, - spot=get_policy_map(profile.spot_policy, default=SpotPolicy.ONDEMAND), - reservation=fleet_spec.configuration.reservation, - ) - return requirements - - def _get_next_instance_num(instance_nums: set[int]) -> int: if not instance_nums: return 0 diff --git a/src/dstack/_internal/server/services/requirements/__init__.py b/src/dstack/_internal/server/services/requirements/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/dstack/_internal/server/services/requirements/combine.py b/src/dstack/_internal/server/services/requirements/combine.py new file mode 100644 index 0000000000..a5830ddef7 --- /dev/null +++ b/src/dstack/_internal/server/services/requirements/combine.py @@ -0,0 +1,259 @@ +from typing import Callable, Optional, Protocol, TypeVar + +from pydantic import BaseModel +from typing_extensions import Self + +from dstack._internal.core.models.profiles import Profile, SpotPolicy +from dstack._internal.core.models.resources import ( + CPUSpec, + DiskSpec, + GPUSpec, + Memory, + Range, + ResourcesSpec, +) +from dstack._internal.core.models.runs import Requirements +from dstack._internal.utils.typing import SupportsRichComparison + + +class CombineError(ValueError): + pass + + +def combine_fleet_and_run_profiles( + fleet_profile: Profile, run_profile: Profile +) -> Optional[Profile]: + """ + Combines fleet and run profile parameters that affect offer selection or provisioning. + """ + try: + return Profile( + backends=_intersect_lists_optional(fleet_profile.backends, run_profile.backends), + regions=_intersect_lists_optional(fleet_profile.regions, run_profile.regions), + availability_zones=_intersect_lists_optional( + fleet_profile.availability_zones, run_profile.availability_zones + ), + instance_types=_intersect_lists_optional( + fleet_profile.instance_types, run_profile.instance_types + ), + reservation=_get_single_value_optional( + fleet_profile.reservation, run_profile.reservation + ), + spot_policy=_combine_spot_policy_optional( + fleet_profile.spot_policy, run_profile.spot_policy + ), + max_price=_get_min_optional(fleet_profile.max_price, run_profile.max_price), + idle_duration=_combine_idle_duration_optional( + fleet_profile.idle_duration, run_profile.idle_duration + ), + tags=_combine_tags_optional(fleet_profile.tags, run_profile.tags), + ) + except CombineError: + return None + + +def combine_fleet_and_run_requirements( + fleet_requirements: Requirements, run_requirements: Requirements +) -> Optional[Requirements]: + try: + return Requirements( + resources=_combine_resources(fleet_requirements.resources, run_requirements.resources), + max_price=_get_min_optional(fleet_requirements.max_price, run_requirements.max_price), + spot=_combine_spot_optional(fleet_requirements.spot, run_requirements.spot), + reservation=_get_single_value_optional( + fleet_requirements.reservation, run_requirements.reservation + ), + ) + except CombineError: + return None + + +_T = TypeVar("_T") +_ModelT = TypeVar("_ModelT", bound=BaseModel) +_CompT = TypeVar("_CompT", bound=SupportsRichComparison) + + +class _SupportsCopy(Protocol): + def copy(self) -> Self: ... + + +_CopyT = TypeVar("_CopyT", bound=_SupportsCopy) + + +def _intersect_lists_optional( + list1: Optional[list[_T]], list2: Optional[list[_T]] +) -> Optional[list[_T]]: + if list1 is None: + if list2 is None: + return None + return list2.copy() + if list2 is None: + return list1.copy() + return [x for x in list1 if x in list2] + + +def _get_min(value1: _CompT, value2: _CompT) -> _CompT: + return min(value1, value2) + + +def _get_min_optional(value1: Optional[_CompT], value2: Optional[_CompT]) -> Optional[_CompT]: + return _combine_optional(value1, value2, _get_min) + + +def _get_single_value(value1: _T, value2: _T) -> _T: + if value1 == value2: + return value1 + raise CombineError(f"Values {value1} and {value2} cannot be combined") + + +def _get_single_value_optional(value1: Optional[_T], value2: Optional[_T]) -> Optional[_T]: + return _combine_optional(value1, value2, _get_single_value) + + +def _combine_spot_policy(value1: SpotPolicy, value2: SpotPolicy) -> SpotPolicy: + if value1 == SpotPolicy.AUTO: + return value2 + if value2 == SpotPolicy.AUTO: + return value1 + if value1 == value2: + return value1 + raise CombineError(f"spot_policy values {value1} and {value2} cannot be combined") + + +def _combine_spot_policy_optional( + value1: Optional[SpotPolicy], value2: Optional[SpotPolicy] +) -> Optional[SpotPolicy]: + return _combine_optional(value1, value2, _combine_spot_policy) + + +def _combine_idle_duration(value1: int, value2: int) -> int: + if value1 < 0 and value2 >= 0 or value2 < 0 and value1 >= 0: + raise CombineError(f"idle_duration values {value1} and {value2} cannot be combined") + return min(value1, value2) + + +def _combine_idle_duration_optional(value1: Optional[int], value2: Optional[int]) -> Optional[int]: + return _combine_optional(value1, value2, _combine_idle_duration) + + +def _combine_tags_optional( + value1: Optional[dict[str, str]], value2: Optional[dict[str, str]] +) -> Optional[dict[str, str]]: + return _combine_copy_optional(value1, value2, _combine_tags) + + +def _combine_tags(value1: dict[str, str], value2: dict[str, str]) -> dict[str, str]: + return value1 | value2 + + +def _combine_resources(value1: ResourcesSpec, value2: ResourcesSpec) -> ResourcesSpec: + return ResourcesSpec( + cpu=_combine_cpu(value1.cpu, value2.cpu), # type: ignore[attr-defined] + memory=_combine_memory(value1.memory, value2.memory), + shm_size=_combine_shm_size_optional(value1.shm_size, value2.shm_size), + gpu=_combine_gpu_optional(value1.gpu, value2.gpu), + disk=_combine_disk_optional(value1.disk, value2.disk), + ) + + +def _combine_cpu(value1: CPUSpec, value2: CPUSpec) -> CPUSpec: + return CPUSpec( + arch=_get_single_value_optional(value1.arch, value2.arch), + count=_combine_range(value1.count, value2.count), + ) + + +def _combine_memory(value1: Range[Memory], value2: Range[Memory]) -> Range[Memory]: + return _combine_range(value1, value2) + + +def _combine_shm_size_optional( + value1: Optional[Memory], value2: Optional[Memory] +) -> Optional[Memory]: + return _get_min_optional(value1, value2) + + +def _combine_gpu(value1: GPUSpec, value2: GPUSpec) -> GPUSpec: + return GPUSpec( + vendor=_get_single_value_optional(value1.vendor, value2.vendor), + name=_intersect_lists_optional(value1.name, value2.name), + count=_combine_range(value1.count, value2.count), + memory=_combine_range_optional(value1.memory, value2.memory), + total_memory=_combine_range_optional(value1.total_memory, value2.total_memory), + compute_capability=_get_min_optional(value1.compute_capability, value2.compute_capability), + ) + + +def _combine_gpu_optional( + value1: Optional[GPUSpec], value2: Optional[GPUSpec] +) -> Optional[GPUSpec]: + return _combine_models_optional(value1, value2, _combine_gpu) + + +def _combine_disk(value1: DiskSpec, value2: DiskSpec) -> DiskSpec: + return DiskSpec(size=_combine_range(value1.size, value2.size)) + + +def _combine_disk_optional( + value1: Optional[DiskSpec], value2: Optional[DiskSpec] +) -> Optional[DiskSpec]: + return _combine_models_optional(value1, value2, _combine_disk) + + +def _combine_spot(value1: bool, value2: bool) -> bool: + if value1 != value2: + raise CombineError(f"spot values {value1} and {value2} cannot be combined") + return value1 + + +def _combine_spot_optional(value1: Optional[bool], value2: Optional[bool]) -> Optional[bool]: + return _combine_optional(value1, value2, _combine_spot) + + +def _combine_range(value1: Range, value2: Range) -> Range: + res = value1.intersect(value2) + if res is None: + raise CombineError(f"Ranges {value1} and {value2} cannot be combined") + return res + + +def _combine_range_optional(value1: Optional[Range], value2: Optional[Range]) -> Optional[Range]: + return _combine_models_optional(value1, value2, _combine_range) + + +def _combine_optional( + value1: Optional[_T], value2: Optional[_T], combiner: Callable[[_T, _T], _T] +) -> Optional[_T]: + if value1 is None: + return value2 + if value2 is None: + return value1 + return combiner(value1, value2) + + +def _combine_models_optional( + value1: Optional[_ModelT], + value2: Optional[_ModelT], + combiner: Callable[[_ModelT, _ModelT], _ModelT], +) -> Optional[_ModelT]: + if value1 is None: + if value2 is not None: + return value2.copy(deep=True) + return None + if value2 is None: + return value1.copy(deep=True) + return combiner(value1, value2) + + +def _combine_copy_optional( + value1: Optional[_CopyT], + value2: Optional[_CopyT], + combiner: Callable[[_CopyT, _CopyT], _CopyT], +) -> Optional[_CopyT]: + if value1 is None: + if value2 is not None: + return value2.copy() + return None + if value2 is None: + return value1.copy() + return combiner(value1, value2) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index bf3d772df8..6745deac71 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -573,7 +573,7 @@ def get_fleet_spec(conf: Optional[FleetConfiguration] = None) -> FleetSpec: return FleetSpec( configuration=conf, configuration_path="fleet.dstack.yml", - profile=Profile(name=""), + profile=Profile(), ) diff --git a/src/dstack/_internal/utils/typing.py b/src/dstack/_internal/utils/typing.py new file mode 100644 index 0000000000..024464a0c3 --- /dev/null +++ b/src/dstack/_internal/utils/typing.py @@ -0,0 +1,14 @@ +from typing import Any, Protocol, TypeVar, Union + +_T_contra = TypeVar("_T_contra", contravariant=True) + + +class SupportsDunderLT(Protocol[_T_contra]): + def __lt__(self, other: _T_contra, /) -> bool: ... + + +class SupportsDunderGT(Protocol[_T_contra]): + def __gt__(self, other: _T_contra, /) -> bool: ... + + +SupportsRichComparison = Union[SupportsDunderLT[Any], SupportsDunderGT[Any]] diff --git a/src/tests/_internal/server/services/requirements/__init__.py b/src/tests/_internal/server/services/requirements/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/services/requirements/test_combine.py b/src/tests/_internal/server/services/requirements/test_combine.py new file mode 100644 index 0000000000..48e6b27a68 --- /dev/null +++ b/src/tests/_internal/server/services/requirements/test_combine.py @@ -0,0 +1,404 @@ +from typing import Optional + +import gpuhunt +import pytest + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.profiles import SpotPolicy +from dstack._internal.core.models.resources import ( + ComputeCapability, + CPUSpec, + DiskSpec, + GPUSpec, + Memory, + Range, + ResourcesSpec, +) +from dstack._internal.core.models.runs import Requirements +from dstack._internal.server.services.requirements.combine import ( + CombineError, + Profile, + _combine_cpu, + _combine_gpu_optional, + _combine_idle_duration_optional, + _combine_resources, + _combine_spot_policy_optional, + _intersect_lists_optional, + combine_fleet_and_run_profiles, + combine_fleet_and_run_requirements, +) + + +class TestCombineFleetAndRunProfiles: + def test_returns_the_same_profile_if_profiles_identical(self): + profile = Profile( + backends=[BackendType.AWS], + regions=["us-west2"], + availability_zones=None, + instance_types=None, + reservation="r-12345", + spot_policy=SpotPolicy.AUTO, + idle_duration=3600, + tags={"tag": "value"}, + ) + assert combine_fleet_and_run_profiles(profile, profile) == profile + + @pytest.mark.parametrize( + argnames=["fleet_profile", "run_profile", "expected_profile"], + argvalues=[ + pytest.param( + Profile(), + Profile(), + Profile(), + id="empty_profile", + ), + pytest.param( + Profile( + backends=[BackendType.AWS, BackendType.GCP], + regions=["eu-west1", "europe-west-4"], + instance_types=["instance1"], + reservation="r-1", + spot_policy=SpotPolicy.AUTO, + idle_duration=3600, + tags={"tag1": "value1"}, + ), + Profile( + backends=[BackendType.GCP, BackendType.RUNPOD], + regions=["eu-west2", "europe-west-4"], + instance_types=["instance2"], + reservation="r-1", + spot_policy=SpotPolicy.SPOT, + idle_duration=7200, + tags={"tag2": "value2"}, + ), + Profile( + backends=[BackendType.GCP], + regions=["europe-west-4"], + instance_types=[], + reservation="r-1", + spot_policy=SpotPolicy.SPOT, + idle_duration=3600, + tags={"tag1": "value1", "tag2": "value2"}, + ), + id="compatible_profiles", + ), + pytest.param( + Profile( + spot_policy=SpotPolicy.SPOT, + ), + Profile( + spot_policy=SpotPolicy.ONDEMAND, + ), + None, + id="incompatible_profiles", + ), + ], + ) + def test_combines_profiles( + self, + fleet_profile: Profile, + run_profile: Profile, + expected_profile: Optional[Profile], + ): + assert combine_fleet_and_run_profiles(fleet_profile, run_profile) == expected_profile + + +class TestCombineFleetAndRunRequirements: + def test_returns_the_same_requirements_if_requirements_identical(self): + requirements = Requirements( + resources=ResourcesSpec(gpu=GPUSpec(count=Range(min=2, max=None))), + max_price=100, + spot=False, + reservation="r-1", + ) + assert combine_fleet_and_run_requirements(requirements, requirements) == requirements + + @pytest.mark.parametrize( + argnames=["fleet_requirements", "run_requirements", "expected_requirements"], + argvalues=[ + pytest.param( + Requirements( + resources=ResourcesSpec(gpu=GPUSpec(count=Range(min=1, max=3))), + max_price=100, + spot=False, + ), + Requirements( + resources=ResourcesSpec(gpu=GPUSpec(count=Range(min=3, max=4))), + max_price=50, + spot=None, + ), + Requirements( + resources=ResourcesSpec(gpu=GPUSpec(count=Range(min=3, max=3))), + max_price=50, + spot=False, + ), + id="compatible_requirements", + ), + pytest.param( + Requirements( + resources=ResourcesSpec(gpu=GPUSpec(count=Range(min=1, max=2))), + ), + Requirements(resources=ResourcesSpec(gpu=GPUSpec(count=Range(min=3, max=4)))), + None, + id="incompatible_requirements", + ), + ], + ) + def test_combines_requirements( + self, + fleet_requirements: Requirements, + run_requirements: Requirements, + expected_requirements: Optional[Requirements], + ): + assert ( + combine_fleet_and_run_requirements(fleet_requirements, run_requirements) + == expected_requirements + ) + + +class TestIntersectLists: + def test_both_none_returns_none(self): + assert _intersect_lists_optional(None, None) is None + + def test_first_none_returns_copy_of_second(self): + list2 = ["a", "b", "c"] + result = _intersect_lists_optional(None, list2) + assert result == list2 + assert result is not list2 # Should be a copy + + def test_second_none_returns_copy_of_first(self): + list1 = ["x", "y", "z"] + result = _intersect_lists_optional(list1, None) + assert result == list1 + assert result is not list1 # Should be a copy + + def test_intersection_of_overlapping_lists(self): + list1 = ["a", "b", "c", "d"] + list2 = ["b", "c", "e", "f"] + result = _intersect_lists_optional(list1, list2) + assert result == ["b", "c"] + + def test_intersection_of_non_overlapping_lists(self): + list1 = ["a", "b"] + list2 = ["c", "d"] + result = _intersect_lists_optional(list1, list2) + assert result == [] + + def test_intersection_preserves_order_from_first_list(self): + list1 = ["c", "a", "b"] + list2 = ["a", "b", "c"] + result = _intersect_lists_optional(list1, list2) + assert result == ["c", "a", "b"] + + def test_intersection_with_duplicates(self): + list1 = ["a", "b", "a", "c"] + list2 = ["a", "c", "d"] + result = _intersect_lists_optional(list1, list2) + assert result == ["a", "a", "c"] + + +class TestCombineIdleDuration: + def test_both_none_returns_none(self): + assert _combine_idle_duration_optional(None, None) is None + + def test_first_none_returns_second(self): + assert _combine_idle_duration_optional(None, 3600) == 3600 + + def test_second_none_returns_first(self): + assert _combine_idle_duration_optional(7200, None) == 7200 + + def test_both_positive_returns_minimum(self): + assert _combine_idle_duration_optional(3600, 7200) == 3600 + assert _combine_idle_duration_optional(7200, 3600) == 3600 + + def test_both_negative_returns_minimum(self): + assert _combine_idle_duration_optional(-1, -2) == -2 + assert _combine_idle_duration_optional(-2, -1) == -2 + + def test_both_zero_returns_zero(self): + assert _combine_idle_duration_optional(0, 0) == 0 + + def test_positive_and_negative_raises_error(self): + with pytest.raises( + CombineError, match="idle_duration values 3600 and -1 cannot be combined" + ): + _combine_idle_duration_optional(3600, -1) + + def test_negative_and_positive_raises_error(self): + with pytest.raises( + CombineError, match="idle_duration values -1 and 3600 cannot be combined" + ): + _combine_idle_duration_optional(-1, 3600) + + def test_zero_and_positive_returns_zero(self): + assert _combine_idle_duration_optional(0, 3600) == 0 + assert _combine_idle_duration_optional(3600, 0) == 0 + + def test_zero_and_negative_raises_error(self): + with pytest.raises(CombineError, match="idle_duration values 0 and -1 cannot be combined"): + _combine_idle_duration_optional(0, -1) + with pytest.raises(CombineError, match="idle_duration values -1 and 0 cannot be combined"): + _combine_idle_duration_optional(-1, 0) + + +class TestCombineSpotPolicy: + def test_both_none_returns_none(self): + assert _combine_spot_policy_optional(None, None) is None + + def test_first_none_returns_second(self): + assert _combine_spot_policy_optional(None, SpotPolicy.SPOT) == SpotPolicy.SPOT + assert _combine_spot_policy_optional(None, SpotPolicy.ONDEMAND) == SpotPolicy.ONDEMAND + assert _combine_spot_policy_optional(None, SpotPolicy.AUTO) == SpotPolicy.AUTO + + def test_second_none_returns_first(self): + assert _combine_spot_policy_optional(SpotPolicy.SPOT, None) == SpotPolicy.SPOT + assert _combine_spot_policy_optional(SpotPolicy.ONDEMAND, None) == SpotPolicy.ONDEMAND + assert _combine_spot_policy_optional(SpotPolicy.AUTO, None) == SpotPolicy.AUTO + + def test_auto_with_other_returns_other(self): + assert _combine_spot_policy_optional(SpotPolicy.AUTO, SpotPolicy.SPOT) == SpotPolicy.SPOT + assert ( + _combine_spot_policy_optional(SpotPolicy.AUTO, SpotPolicy.ONDEMAND) + == SpotPolicy.ONDEMAND + ) + assert _combine_spot_policy_optional(SpotPolicy.SPOT, SpotPolicy.AUTO) == SpotPolicy.SPOT + assert ( + _combine_spot_policy_optional(SpotPolicy.ONDEMAND, SpotPolicy.AUTO) + == SpotPolicy.ONDEMAND + ) + + def test_auto_with_auto_returns_auto(self): + assert _combine_spot_policy_optional(SpotPolicy.AUTO, SpotPolicy.AUTO) == SpotPolicy.AUTO + + def test_same_non_auto_values_return_same(self): + assert _combine_spot_policy_optional(SpotPolicy.SPOT, SpotPolicy.SPOT) == SpotPolicy.SPOT + assert ( + _combine_spot_policy_optional(SpotPolicy.ONDEMAND, SpotPolicy.ONDEMAND) + == SpotPolicy.ONDEMAND + ) + + def test_different_non_auto_values_raise_error(self): + with pytest.raises(CombineError): + _combine_spot_policy_optional(SpotPolicy.SPOT, SpotPolicy.ONDEMAND) + with pytest.raises(CombineError): + _combine_spot_policy_optional(SpotPolicy.ONDEMAND, SpotPolicy.SPOT) + + +class TestCombineResources: + def test_combines_all_resource_specs(self): + resources1 = ResourcesSpec( + cpu=CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=2, max=8)), + memory=Range(min=Memory(4), max=Memory(16)), + shm_size=Memory(2), + gpu=GPUSpec(vendor=gpuhunt.AcceleratorVendor.NVIDIA), + disk=DiskSpec(size=Range(min=Memory(100), max=Memory(500))), + ) + resources2 = ResourcesSpec( + cpu=CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=4, max=6)), + memory=Range(min=Memory(8), max=Memory(12)), + shm_size=Memory(1), + gpu=GPUSpec(vendor=gpuhunt.AcceleratorVendor.NVIDIA), + disk=DiskSpec(size=Range(min=Memory(100), max=Memory(400))), + ) + result = _combine_resources(resources1, resources2) + expected = ResourcesSpec( + cpu=CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=4, max=6)), + memory=Range(min=Memory(8), max=Memory(12)), + shm_size=Memory(1), + gpu=GPUSpec(vendor=gpuhunt.AcceleratorVendor.NVIDIA), + disk=DiskSpec(size=Range(min=Memory(100), max=Memory(400))), + ) + assert result == expected + + +class TestCombineCpu: + def test_combines_compatible_cpu_specs(self): + cpu1 = CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=2, max=8)) + cpu2 = CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=4, max=6)) + result = _combine_cpu(cpu1, cpu2) + expected = CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=4, max=6)) + assert result == expected + + def test_incompatible_architectures_raises_error(self): + cpu1 = CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=2, max=4)) + cpu2 = CPUSpec(arch=gpuhunt.CPUArchitecture.ARM, count=Range(min=2, max=4)) + with pytest.raises(CombineError): + _combine_cpu(cpu1, cpu2) + + def test_non_overlapping_count_ranges_raises_error(self): + cpu1 = CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=1, max=2)) + cpu2 = CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=4, max=6)) + with pytest.raises(CombineError): + _combine_cpu(cpu1, cpu2) + + def test_handles_none_architecture(self): + cpu1 = CPUSpec(arch=None, count=Range(min=2, max=4)) + cpu2 = CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=2, max=4)) + result = _combine_cpu(cpu1, cpu2) + expected = CPUSpec(arch=gpuhunt.CPUArchitecture.X86, count=Range(min=2, max=4)) + assert result == expected + + def test_both_none_architecture(self): + cpu1 = CPUSpec(arch=None, count=Range(min=2, max=4)) + cpu2 = CPUSpec(arch=None, count=Range(min=3, max=5)) + result = _combine_cpu(cpu1, cpu2) + expected = CPUSpec(arch=None, count=Range(min=3, max=4)) + assert result == expected + + +class TestCombineGpu: + def test_both_none_returns_none(self): + assert _combine_gpu_optional(None, None) is None + + def test_first_none_returns_copy_of_second(self): + gpu2 = GPUSpec(count=Range(min=1, max=2)) + result = _combine_gpu_optional(None, gpu2) + assert result == gpu2 + assert result is not gpu2 # Should be a copy + + def test_second_none_returns_copy_of_first(self): + gpu1 = GPUSpec(count=Range(min=2, max=4)) + result = _combine_gpu_optional(gpu1, None) + assert result == gpu1 + assert result is not gpu1 # Should be a copy + + def test_combines_compatible_gpu_specs(self): + gpu1 = GPUSpec( + vendor=gpuhunt.AcceleratorVendor.NVIDIA, + name=["A100", "V100"], + count=Range(min=1, max=4), + memory=Range(min=Memory(8), max=Memory(32)), + compute_capability=ComputeCapability((7, 0)), + ) + gpu2 = GPUSpec( + vendor=gpuhunt.AcceleratorVendor.NVIDIA, + name=["V100", "T4"], + count=Range(min=2, max=3), + memory=Range(min=Memory(16), max=Memory(24)), + compute_capability=ComputeCapability((7, 8)), + ) + assert _combine_gpu_optional(gpu1, gpu2) == GPUSpec( + vendor=gpuhunt.AcceleratorVendor.NVIDIA, + name=["V100"], + count=Range(min=2, max=3), + memory=Range(min=Memory(16), max=Memory(24)), + compute_capability=ComputeCapability((7, 0)), + ) + + def test_incompatible_vendors_raises_error(self): + gpu1 = GPUSpec(vendor=gpuhunt.AcceleratorVendor.NVIDIA, count=Range(min=1, max=2)) + gpu2 = GPUSpec(vendor=gpuhunt.AcceleratorVendor.AMD, count=Range(min=1, max=2)) + with pytest.raises(CombineError): + _combine_gpu_optional(gpu1, gpu2) + + def test_non_overlapping_count_ranges_raises_error(self): + gpu1 = GPUSpec(count=Range(min=1, max=2)) + gpu2 = GPUSpec(count=Range(min=4, max=6)) + with pytest.raises(CombineError): + _combine_gpu_optional(gpu1, gpu2) + + def test_non_overlapping_memory_ranges_raises_error(self): + gpu1 = GPUSpec(count=Range(min=1, max=2), memory=Range(min=Memory(8), max=Memory(16))) + gpu2 = GPUSpec(count=Range(min=1, max=2), memory=Range(min=Memory(32), max=Memory(64))) + with pytest.raises(CombineError): + _combine_gpu_optional(gpu1, gpu2)