diff --git a/pyproject.toml b/pyproject.toml index 5cceee68bf..3d4f6f1cb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ include = [ "src/dstack/_internal/server", "src/dstack/_internal/core/services", "src/dstack/_internal/core/backends/kubernetes", + "src/dstack/_internal/core/backends/runpod", "src/dstack/_internal/cli/services/configurators", "src/dstack/_internal/cli/commands", ] diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 5609829339..7f70cddd5e 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -17,6 +17,7 @@ from gpuhunt import CPUArchitecture from dstack._internal import settings +from dstack._internal.core.backends.base.models import JobConfiguration from dstack._internal.core.backends.base.offers import OfferModifier, filter_offers_by_requirements from dstack._internal.core.consts import ( DSTACK_RUNNER_HTTP_PORT, @@ -24,6 +25,7 @@ DSTACK_SHIM_HTTP_PORT, ) from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.compute_groups import ComputeGroup, ComputeGroupProvisioningData from dstack._internal.core.models.configurations import LEGACY_REPO_DIR from dstack._internal.core.models.gateways import ( GatewayComputeConfiguration, @@ -324,6 +326,23 @@ def _restrict_instance_offer_az_to_volumes_az( ] +class ComputeWithGroupProvisioningSupport(ABC): + @abstractmethod + def run_jobs( + self, + run: Run, + job_configurations: List[JobConfiguration], + instance_offer: InstanceOfferWithAvailability, + project_ssh_public_key: str, + project_ssh_private_key: str, + ) -> ComputeGroupProvisioningData: + pass + + @abstractmethod + def terminate_compute_group(self, compute_group: ComputeGroup): + pass + + class ComputeWithPrivilegedSupport: """ Must be subclassed to support runs with `privileged: true`. 
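For context, the two abstract methods above are the entire group-provisioning contract: a backend opts in by subclassing the new ABC alongside `Compute`, as the RunPod changes further down do. A minimal skeleton under that reading (`ExampleCompute` is hypothetical and not part of this diff; a real subclass must also implement the `Compute` abstract methods):

from typing import List

from dstack._internal.core.backends.base.backend import Compute
from dstack._internal.core.backends.base.compute import ComputeWithGroupProvisioningSupport
from dstack._internal.core.backends.base.models import JobConfiguration
from dstack._internal.core.models.compute_groups import ComputeGroup, ComputeGroupProvisioningData
from dstack._internal.core.models.instances import InstanceOfferWithAvailability
from dstack._internal.core.models.runs import Run


class ExampleCompute(ComputeWithGroupProvisioningSupport, Compute):
    def run_jobs(
        self,
        run: Run,
        job_configurations: List[JobConfiguration],
        instance_offer: InstanceOfferWithAvailability,
        project_ssh_public_key: str,
        project_ssh_private_key: str,
    ) -> ComputeGroupProvisioningData:
        # Provision all instances in a single provider API call and return
        # one JobProvisioningData per entry of job_configurations, in order.
        ...

    def terminate_compute_group(self, compute_group: ComputeGroup) -> None:
        # Delete the whole group via the provider API; member instances are
        # not terminated one-by-one.
        ...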
diff --git a/src/dstack/_internal/core/backends/base/models.py b/src/dstack/_internal/core/backends/base/models.py index 00cecb7a72..b65024c1bb 100644 --- a/src/dstack/_internal/core/backends/base/models.py +++ b/src/dstack/_internal/core/backends/base/models.py @@ -1,4 +1,14 @@ from pathlib import Path +from typing import List + +from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.runs import Job +from dstack._internal.core.models.volumes import Volume + + +class JobConfiguration(CoreModel): + job: Job + volumes: List[Volume] def fill_data(values: dict, filename_field: str = "filename", data_field: str = "data") -> dict: diff --git a/src/dstack/_internal/core/backends/base/offers.py b/src/dstack/_internal/core/backends/base/offers.py index ea20055b71..c8236002bf 100644 --- a/src/dstack/_internal/core/backends/base/offers.py +++ b/src/dstack/_internal/core/backends/base/offers.py @@ -25,6 +25,7 @@ "gcp-a4", "gcp-g4", "gcp-dws-calendar-mode", + "runpod-cluster", ] diff --git a/src/dstack/_internal/core/backends/features.py b/src/dstack/_internal/core/backends/features.py index 76d64470b6..4bb9f99905 100644 --- a/src/dstack/_internal/core/backends/features.py +++ b/src/dstack/_internal/core/backends/features.py @@ -1,6 +1,7 @@ from dstack._internal.core.backends.base.compute import ( ComputeWithCreateInstanceSupport, ComputeWithGatewaySupport, + ComputeWithGroupProvisioningSupport, ComputeWithMultinodeSupport, ComputeWithPlacementGroupSupport, ComputeWithPrivateGatewaySupport, @@ -39,6 +40,10 @@ def _get_backends_with_compute_feature( configurator_classes=_configurator_classes, compute_feature_class=ComputeWithCreateInstanceSupport, ) +BACKENDS_WITH_GROUP_PROVISIONING_SUPPORT = _get_backends_with_compute_feature( + configurator_classes=_configurator_classes, + compute_feature_class=ComputeWithGroupProvisioningSupport, +) BACKENDS_WITH_PRIVILEGED_SUPPORT = _get_backends_with_compute_feature( configurator_classes=_configurator_classes, compute_feature_class=ComputeWithPrivilegedSupport, diff --git a/src/dstack/_internal/core/backends/runpod/api_client.py b/src/dstack/_internal/core/backends/runpod/api_client.py index 59d8e20045..b38254a6fb 100644 --- a/src/dstack/_internal/core/backends/runpod/api_client.py +++ b/src/dstack/_internal/core/backends/runpod/api_client.py @@ -11,6 +11,14 @@ API_URL = "https://api.runpod.io/graphql" +class RunpodApiClientError(BackendError): + errors: List[Dict] + + def __init__(self, errors: List[Dict]): + self.errors = errors + super().__init__(errors) + + class RunpodApiClient: def __init__(self, api_key: str): self.api_key = api_key @@ -23,7 +31,19 @@ def validate_api_key(self) -> bool: return True def get_user_details(self) -> Dict: - resp = self._make_request({"query": user_details_query, "variable": {}}) + resp = self._make_request( + { + "query": """ + query myself { + myself { + id + authId + email + } + } + """ + } + ) return resp.json() def create_pod( @@ -52,28 +72,28 @@ def create_pod( ) -> Dict: resp = self._make_request( { - "query": generate_pod_deployment_mutation( - name, - image_name, - gpu_type_id, - cloud_type, - support_public_ip, - start_ssh, - data_center_id, - country_code, - gpu_count, - volume_in_gb, - container_disk_in_gb, - min_vcpu_count, - min_memory_in_gb, - docker_args, - ports, - volume_mount_path, - env, - template_id, - network_volume_id, - allowed_cuda_versions, - bid_per_gpu, + "query": _generate_pod_deployment_mutation( + name=name, + image_name=image_name, + gpu_type_id=gpu_type_id, + 
cloud_type=cloud_type, + support_public_ip=support_public_ip, + start_ssh=start_ssh, + data_center_id=data_center_id, + country_code=country_code, + gpu_count=gpu_count, + volume_in_gb=volume_in_gb, + container_disk_in_gb=container_disk_in_gb, + min_vcpu_count=min_vcpu_count, + min_memory_in_gb=min_memory_in_gb, + docker_args=docker_args, + ports=ports, + volume_mount_path=volume_mount_path, + env=env, + template_id=template_id, + network_volume_id=network_volume_id, + allowed_cuda_versions=allowed_cuda_versions, + bid_per_gpu=bid_per_gpu, ) } ) @@ -86,7 +106,9 @@ def edit_pod( image_name: str, container_disk_in_gb: int, container_registry_auth_id: str, - volume_in_gb: int = 0, + # Default pod volume is 20GB. + # RunPod errors if it's not specified for podEditJob. + volume_in_gb: int = 20, ) -> str: resp = self._make_request( { @@ -108,12 +130,12 @@ def edit_pod( return resp.json()["data"]["podEditJob"]["id"] def get_pod(self, pod_id: str) -> Dict: - resp = self._make_request({"query": generate_pod_query(pod_id)}) + resp = self._make_request({"query": _generate_pod_query(pod_id)}) data = resp.json() return data["data"]["pod"] def terminate_pod(self, pod_id: str) -> Dict: - resp = self._make_request({"query": generate_pod_terminate_mutation(pod_id)}) + resp = self._make_request({"query": _generate_pod_terminate_mutation(pod_id)}) data = resp.json() return data["data"] @@ -213,7 +235,7 @@ def create_network_volume(self, name: str, region: str, size: int) -> str: ) return response.json()["data"]["createNetworkVolume"]["id"] - def delete_network_volume(self, volume_id: str): + def delete_network_volume(self, volume_id: str) -> None: self._make_request( { "query": f""" @@ -228,7 +250,66 @@ def delete_network_volume(self, volume_id: str): } ) - def _make_request(self, data: Any = None) -> Response: + def create_cluster( + self, + cluster_name: str, + gpu_type_id: str, + pod_count: int, + gpu_count_per_pod: int, + image_name: str, + deploy_cost: str, + template_id: Optional[str] = None, + cluster_type: str = "TRAINING", + network_volume_id: Optional[str] = None, + volume_in_gb: Optional[int] = None, + throughput: Optional[int] = None, + allowed_cuda_versions: Optional[List[str]] = None, + volume_key: Optional[str] = None, + data_center_id: Optional[str] = None, + start_jupyter: bool = False, + start_ssh: bool = False, + container_disk_in_gb: Optional[int] = None, + docker_args: Optional[str] = None, + env: Optional[Dict[str, Any]] = None, + volume_mount_path: Optional[str] = None, + ports: Optional[str] = None, + ) -> Dict: + resp = self._make_request( + { + "query": _generate_create_cluster_mutation( + cluster_name=cluster_name, + gpu_type_id=gpu_type_id, + pod_count=pod_count, + gpu_count_per_pod=gpu_count_per_pod, + image_name=image_name, + cluster_type=cluster_type, + deploy_cost=deploy_cost, + template_id=template_id, + network_volume_id=network_volume_id, + volume_in_gb=volume_in_gb, + throughput=throughput, + allowed_cuda_versions=allowed_cuda_versions, + volume_key=volume_key, + data_center_id=data_center_id, + start_jupyter=start_jupyter, + start_ssh=start_ssh, + container_disk_in_gb=container_disk_in_gb, + docker_args=docker_args, + env=env, + volume_mount_path=volume_mount_path, + ports=ports, + ) + } + ) + data = resp.json()["data"] + return data["createCluster"] + + def delete_cluster(self, cluster_id: str) -> bool: + resp = self._make_request({"query": _generate_delete_cluster_mutation(cluster_id)}) + data = resp.json()["data"] + return data["deleteCluster"] + + def 
_make_request(self, data: Optional[Dict[str, Any]] = None) -> Response: try: response = requests.request( method="POST", @@ -237,10 +318,10 @@ def _make_request(self, data: Any = None) -> Response: timeout=120, ) response.raise_for_status() - if "errors" in response.json(): - if "podTerminate" in response.json()["errors"][0]["path"]: - raise BackendError("Instance Not Found") - raise BackendError(response.json()["errors"][0]["message"]) + response_json = response.json() + # RunPod returns 200 on client errors + if "errors" in response_json: + raise RunpodApiClientError(errors=response_json["errors"]) return response except requests.HTTPError as e: if e.response is not None and e.response.status_code in ( @@ -250,7 +331,7 @@ def _make_request(self, data: Any = None) -> Response: raise BackendInvalidCredentialsError(e.response.text) raise - def wait_for_instance(self, instance_id) -> Optional[Dict]: + def wait_for_instance(self, instance_id: str) -> Optional[Dict]: start = get_current_datetime() wait_for_instance_interval = 5 # To change the status to "running," the image must be pulled and then started. @@ -263,18 +344,7 @@ def wait_for_instance(self, instance_id) -> Optional[Dict]: return -user_details_query = """ -query myself { - myself { - id - authId - email - } -} -""" - - -def generate_pod_query(pod_id: str) -> str: +def _generate_pod_query(pod_id: str) -> str: """ Generate a query for a specific GPU type """ @@ -283,6 +353,7 @@ def generate_pod_query(pod_id: str) -> str: query pod {{ pod(input: {{podId: "{pod_id}"}}) {{ id + clusterIp containerDiskInGb costPerHr desiredStatus @@ -319,26 +390,26 @@ def generate_pod_query(pod_id: str) -> str: """ -def generate_pod_deployment_mutation( +def _generate_pod_deployment_mutation( name: str, image_name: str, gpu_type_id: str, cloud_type: str, support_public_ip: bool = True, start_ssh: bool = True, - data_center_id=None, - country_code=None, - gpu_count=None, - volume_in_gb=None, - container_disk_in_gb=None, - min_vcpu_count=None, - min_memory_in_gb=None, - docker_args=None, - ports=None, - volume_mount_path=None, + data_center_id: Optional[str] = None, + country_code: Optional[str] = None, + gpu_count: Optional[int] = None, + volume_in_gb: Optional[int] = None, + container_disk_in_gb: Optional[int] = None, + min_vcpu_count: Optional[int] = None, + min_memory_in_gb: Optional[int] = None, + docker_args: Optional[str] = None, + ports: Optional[str] = None, + volume_mount_path: Optional[str] = None, env: Optional[Dict[str, Any]] = None, - template_id=None, - network_volume_id=None, + template_id: Optional[str] = None, + network_volume_id: Optional[str] = None, allowed_cuda_versions: Optional[List[str]] = None, bid_per_gpu: Optional[float] = None, ) -> str: @@ -425,7 +496,7 @@ def generate_pod_deployment_mutation( """ -def generate_pod_terminate_mutation(pod_id: str) -> str: +def _generate_pod_terminate_mutation(pod_id: str) -> str: """ Generates a mutation to terminate a pod. """ @@ -434,3 +505,118 @@ def generate_pod_terminate_mutation(pod_id: str) -> str: podTerminate(input: {{ podId: "{pod_id}" }}) }} """ + + +def _generate_delete_cluster_mutation(cluster_id: str) -> str: + """ + Generates a mutation to delete a cluster. 
+ """ + return f""" + mutation {{ + deleteCluster( + input: {{ + id: "{cluster_id}" + }} + ) + }} + """ + + +def _generate_create_cluster_mutation( + cluster_name: str, + gpu_type_id: str, + pod_count: int, + gpu_count_per_pod: int, + image_name: str, + cluster_type: str, + deploy_cost: str, + template_id: Optional[str] = None, + network_volume_id: Optional[str] = None, + volume_in_gb: Optional[int] = None, + throughput: Optional[int] = None, + allowed_cuda_versions: Optional[List[str]] = None, + volume_key: Optional[str] = None, + data_center_id: Optional[str] = None, + start_jupyter: bool = False, + start_ssh: bool = False, + container_disk_in_gb: Optional[int] = None, + docker_args: Optional[str] = None, + env: Optional[Dict[str, Any]] = None, + volume_mount_path: Optional[str] = None, + ports: Optional[str] = None, +) -> str: + """ + Generates a mutation to create a cluster. + """ + input_fields = [] + + # ------------------------------ Required Fields ----------------------------- # + input_fields.append(f'clusterName: "{cluster_name}"') + input_fields.append(f'gpuTypeId: "{gpu_type_id}"') + input_fields.append(f"podCount: {pod_count}") + input_fields.append(f'imageName: "{image_name}"') + input_fields.append(f"type: {cluster_type}") + input_fields.append(f"gpuCountPerPod: {gpu_count_per_pod}") + # If deploy_cost is not specified, Runpod returns Insufficient resources error. + input_fields.append(f"deployCost: {deploy_cost}") + + # ------------------------------ Optional Fields ----------------------------- # + if template_id is not None: + input_fields.append(f'templateId: "{template_id}"') + if network_volume_id is not None: + input_fields.append(f'networkVolumeId: "{network_volume_id}"') + if volume_in_gb is not None: + input_fields.append(f"volumeInGb: {volume_in_gb}") + if throughput is not None: + input_fields.append(f"throughput: {throughput}") + if allowed_cuda_versions is not None: + allowed_cuda_versions_string = ", ".join( + [f'"{version}"' for version in allowed_cuda_versions] + ) + input_fields.append(f"allowedCudaVersions: [{allowed_cuda_versions_string}]") + if volume_key is not None: + input_fields.append(f'volumeKey: "{volume_key}"') + if data_center_id is not None: + input_fields.append(f'dataCenterId: "{data_center_id}"') + if start_jupyter: + input_fields.append("startJupyter: true") + if start_ssh: + input_fields.append("startSsh: true") + if container_disk_in_gb is not None: + input_fields.append(f"containerDiskInGb: {container_disk_in_gb}") + if docker_args is not None: + input_fields.append(f'dockerArgs: "{docker_args}"') + if env is not None: + env_string = ", ".join( + [f'{{ key: "{key}", value: "{value}" }}' for key, value in env.items()] + ) + input_fields.append(f"env: [{env_string}]") + if volume_mount_path is not None: + input_fields.append(f'volumeMountPath: "{volume_mount_path}"') + if ports is not None: + ports = ports.replace(" ", "") + input_fields.append(f'ports: "{ports}"') + + # Format input fields + input_string = ", ".join(input_fields) + return f""" + mutation {{ + createCluster( + input: {{ + {input_string} + }} + ) {{ + id + name + pods {{ + id + clusterIp + lastStatusChange + imageName + machine {{ + podHostId + }} + }} + }} + }} + """ diff --git a/src/dstack/_internal/core/backends/runpod/compute.py b/src/dstack/_internal/core/backends/runpod/compute.py index 29e521d249..e78ffa1f11 100644 --- a/src/dstack/_internal/core/backends/runpod/compute.py +++ b/src/dstack/_internal/core/backends/runpod/compute.py @@ -2,31 +2,34 @@ import uuid from 
collections.abc import Iterable from datetime import timedelta -from typing import List, Optional +from typing import Callable, List, Optional from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( ComputeWithAllOffersCached, + ComputeWithGroupProvisioningSupport, + ComputeWithMultinodeSupport, ComputeWithVolumeSupport, generate_unique_instance_name, generate_unique_volume_name, get_docker_commands, get_job_instance_name, ) +from dstack._internal.core.backends.base.models import JobConfiguration from dstack._internal.core.backends.base.offers import ( OfferModifier, get_catalog_offers, get_offers_disk_modifier, ) -from dstack._internal.core.backends.runpod.api_client import RunpodApiClient +from dstack._internal.core.backends.runpod.api_client import RunpodApiClient, RunpodApiClientError from dstack._internal.core.backends.runpod.models import RunpodConfig from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ( - BackendError, ComputeError, ) from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import RegistryAuth +from dstack._internal.core.models.common import CoreModel, RegistryAuth +from dstack._internal.core.models.compute_groups import ComputeGroup, ComputeGroupProvisioningData from dstack._internal.core.models.instances import ( InstanceAvailability, InstanceConfiguration, @@ -36,7 +39,7 @@ from dstack._internal.core.models.resources import Memory, Range from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume, VolumeProvisioningData -from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.common import get_current_datetime, get_or_error from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -50,9 +53,15 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("1GB"), max=None) +class RunpodOfferBackendData(CoreModel): + pod_counts: Optional[list[int]] = None + + class RunpodCompute( ComputeWithAllOffersCached, ComputeWithVolumeSupport, + ComputeWithMultinodeSupport, + ComputeWithGroupProvisioningSupport, Compute, ): _last_cleanup_time = None @@ -80,6 +89,18 @@ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability def get_offers_modifiers(self, requirements: Requirements) -> Iterable[OfferModifier]: return [get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)] + def get_offers_post_filter( + self, requirements: Requirements + ) -> Optional[Callable[[InstanceOfferWithAvailability], bool]]: + def offers_post_filter(offer: InstanceOfferWithAvailability) -> bool: + pod_counts = _get_offer_pod_counts(offer) + is_cluster_offer = len(pod_counts) > 0 and any(pc != 1 for pc in pod_counts) + if requirements.multinode: + return is_cluster_offer + return not is_cluster_offer + + return offers_post_filter + def run_job( self, run: Run, @@ -151,6 +172,8 @@ def run_job( instance_id = resp["id"] + # Call edit_pod to pass container_registry_auth_id. + # Expect a long time (~5m) for the pod to pick up the creds. # TODO: remove editPod once createPod supports docker's username and password # editPod is temporary solution to set container_registry_auth_id because createPod does not # support it currently. 
This will be removed once createPod supports container_registry_auth_id @@ -186,14 +209,127 @@ def run_job( backend_data=None, ) + def run_jobs( + self, + run: Run, + job_configurations: List[JobConfiguration], + instance_offer: InstanceOfferWithAvailability, + project_ssh_public_key: str, + project_ssh_private_key: str, + ) -> ComputeGroupProvisioningData: + master_job_configuration = job_configurations[0] + master_job = master_job_configuration.job + master_job_volumes = master_job_configuration.volumes + all_volumes_names = set(v.name for jc in job_configurations for v in jc.volumes) + instance_config = InstanceConfiguration( + project_name=run.project_name, + instance_name=get_job_instance_name(run, master_job), + ssh_keys=[ + SSHKey(public=get_or_error(run.run_spec.ssh_key_pub).strip()), + SSHKey(public=project_ssh_public_key.strip()), + ], + user=run.user, + ) + + pod_name = generate_unique_instance_name(instance_config, max_length=MAX_RESOURCE_NAME_LEN) + authorized_keys = instance_config.get_public_keys() + disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + + network_volume_id = None + volume_mount_path = None + if len(master_job_volumes) > 1: + raise ComputeError("Mounting more than one network volume is not supported in runpod") + if len(all_volumes_names) > 1: + raise ComputeError( + "Mounting different volumes to different jobs is not supported in runpod" + ) + if len(master_job_volumes) == 1: + network_volume_id = master_job_volumes[0].volume_id + volume_mount_path = run.run_spec.configuration.volumes[0].path + + offer_pod_counts = _get_offer_pod_counts(instance_offer) + pod_count = len(job_configurations) + gpu_count = len(instance_offer.instance.resources.gpus) + data_center_id = instance_offer.region + + if pod_count not in offer_pod_counts: + raise ComputeError( + f"Failed to provision {pod_count} pods. Available pod counts: {offer_pod_counts}" + ) + + container_registry_auth_id = self._generate_container_registry_auth_id( + master_job.job_spec.registry_auth + ) + resp = self.api_client.create_cluster( + cluster_name=pod_name, + gpu_type_id=instance_offer.instance.name, + pod_count=pod_count, + gpu_count_per_pod=gpu_count, + deploy_cost=f"{instance_offer.price * pod_count:.2f}", + image_name=master_job.job_spec.image_name, + cluster_type="TRAINING", + data_center_id=data_center_id, + container_disk_in_gb=disk_size, + docker_args=_get_docker_args(authorized_keys), + ports=f"{DSTACK_RUNNER_SSH_PORT}/tcp", + network_volume_id=network_volume_id, + volume_mount_path=volume_mount_path, + env={"RUNPOD_POD_USER": "0"}, + ) + + # An "edit pod" trick to pass container registry creds. 
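+ # Mirrors the edit_pod workaround in run_job: createCluster, like createPod, does not take registry credentials directly.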
+ if container_registry_auth_id is not None: + for pod in resp["pods"]: + self.api_client.edit_pod( + pod_id=pod["id"], + image_name=master_job.job_spec.image_name, + container_disk_in_gb=disk_size, + container_registry_auth_id=container_registry_auth_id, + ) + + jpds = [ + JobProvisioningData( + backend=instance_offer.backend, + instance_type=instance_offer.instance, + instance_id=pod["id"], + hostname=None, + internal_ip=pod["clusterIp"], + region=instance_offer.region, + price=instance_offer.price, + username="root", + dockerized=False, + ) + for pod in resp["pods"] + ] + return ComputeGroupProvisioningData( + compute_group_id=resp["id"], + compute_group_name=resp["name"], + backend=BackendType.RUNPOD, + region=instance_offer.region, + job_provisioning_datas=jpds, + ) + def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ) -> None: + ): try: self.api_client.terminate_pod(instance_id) - except BackendError as e: - if e.args[0] == "Instance Not Found": - logger.debug("The instance with name %s not found", instance_id) + except RunpodApiClientError as e: + if len(e.errors) > 0 and e.errors[0]["message"] == "pod not found to terminate": + logger.debug("The instance %s not found. Skipping deletion.", instance_id) + return + raise + + def terminate_compute_group(self, compute_group: ComputeGroup): + provisioning_data = compute_group.provisioning_data + try: + self.api_client.delete_cluster(provisioning_data.compute_group_id) + except RunpodApiClientError as e: + if len(e.errors) > 0 and e.errors[0]["extensions"]["code"] == "Cluster not found": + logger.debug( + "The cluster %s not found. Skipping deletion.", + provisioning_data.compute_group_id, + ) return raise @@ -216,7 +352,9 @@ def update_provisioning_data( provisioning_data.ssh_port = port["publicPort"] def register_volume(self, volume: Volume) -> VolumeProvisioningData: - volume_data = self.api_client.get_network_volume(volume_id=volume.configuration.volume_id) + volume_data = self.api_client.get_network_volume( + volume_id=get_or_error(volume.configuration.volume_id) + ) if volume_data is None: raise ComputeError(f"Volume {volume.configuration.volume_id} not found") size_gb = volume_data["size"] @@ -258,14 +396,12 @@ def _generate_container_registry_auth_id( ) -> Optional[str]: if registry_auth is None: return None - return self.api_client.add_container_registry_auth( uuid.uuid4().hex, registry_auth.username, registry_auth.password ) def _clean_stale_container_registry_auths(self) -> None: container_registry_auths = self.api_client.get_container_registry_auths() - # Container_registry_auths sorted by creation time so try to delete the oldest first # when we reach container_registry_auths that is still in use, we stop for container_registry_auth in container_registry_auths: @@ -289,9 +425,17 @@ def _get_volume_price(size: int) -> float: return 0.05 * size -def _is_secure_cloud(region: str) -> str: +def _is_secure_cloud(region: str) -> bool: """ Secure cloud regions are datacenter IDs: CA-MTL-1, EU-NL-1, etc. Community cloud regions are country codes: CA, NL, etc. 
""" return "-" in region + + +def _get_offer_pod_counts(offer: InstanceOfferWithAvailability) -> list[int]: + backend_data: RunpodOfferBackendData = RunpodOfferBackendData.__response__.parse_obj( + offer.backend_data + ) + pod_counts = backend_data.pod_counts or [] + return pod_counts diff --git a/src/dstack/_internal/core/models/compute_groups.py b/src/dstack/_internal/core/models/compute_groups.py new file mode 100644 index 0000000000..66e1292eff --- /dev/null +++ b/src/dstack/_internal/core/models/compute_groups.py @@ -0,0 +1,39 @@ +import enum +import uuid +from datetime import datetime +from typing import List, Optional + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.models.runs import JobProvisioningData + + +class ComputeGroupStatus(str, enum.Enum): + RUNNING = "running" + TERMINATED = "terminated" + + +class ComputeGroupProvisioningData(CoreModel): + compute_group_id: str + compute_group_name: str + backend: BackendType + # In case backend provisions instance in another backend, + # it may set that backend as base_backend. + base_backend: Optional[BackendType] = None + region: str + job_provisioning_datas: List[JobProvisioningData] + backend_data: Optional[str] = None # backend-specific data in json + + +class ComputeGroup(CoreModel): + """ + Compute group is a group of instances managed as a single unit via the provider API, + i.e. instances are not created/deleted one-by-one but all at once. + """ + + id: uuid.UUID + name: str + project_name: str + created_at: datetime + status: ComputeGroupStatus + provisioning_data: ComputeGroupProvisioningData diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 0a5b174d23..5e6d4b4806 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -207,6 +207,9 @@ class Requirements(CoreModel): max_price: Optional[float] = None spot: Optional[bool] = None reservation: Optional[str] = None + # Backends can use `multinode` to filter out offers if + # some offers support multinode and some do not. 
+ multinode: Optional[bool] = None def pretty_format(self, resources_only: bool = False): res = self.resources.pretty_format() diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index 099f8ce51c..df7d41b9d9 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -2,6 +2,7 @@ from apscheduler.triggers.interval import IntervalTrigger from dstack._internal.server import settings +from dstack._internal.server.background.tasks.process_compute_groups import process_compute_groups from dstack._internal.server.background.tasks.process_fleets import process_fleets from dstack._internal.server.background.tasks.process_gateways import ( process_gateways, @@ -122,5 +123,11 @@ def start_background_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 5}, max_instances=2 if replica == 0 else 1, ) + _scheduler.add_job( + process_compute_groups, + IntervalTrigger(seconds=15, jitter=2), + kwargs={"batch_size": 1}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.start() return _scheduler diff --git a/src/dstack/_internal/server/background/tasks/process_compute_groups.py b/src/dstack/_internal/server/background/tasks/process_compute_groups.py new file mode 100644 index 0000000000..5f7b6820a4 --- /dev/null +++ b/src/dstack/_internal/server/background/tasks/process_compute_groups.py @@ -0,0 +1,164 @@ +import asyncio +import datetime +from datetime import timedelta + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.backends.base.compute import ComputeWithGroupProvisioningSupport +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.compute_groups import ComputeGroupStatus +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + ComputeGroupModel, + ProjectModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.compute_groups import compute_group_model_to_compute_group +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +MIN_PROCESSING_INTERVAL = timedelta(seconds=30) + +TERMINATION_RETRY_TIMEOUT = timedelta(seconds=60) +TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) + + +async def process_compute_groups(batch_size: int = 1): + tasks = [] + for _ in range(batch_size): + tasks.append(_process_next_compute_group()) + await asyncio.gather(*tasks) + + +@sentry_utils.instrument_background_task +async def _process_next_compute_group(): + lock, lockset = get_locker(get_db().dialect_name).get_lockset(ComputeGroupModel.__tablename__) + async with get_session_ctx() as session: + async with lock: + res = await session.execute( + select(ComputeGroupModel) + .where( + ComputeGroupModel.deleted == False, + ComputeGroupModel.id.not_in(lockset), + ComputeGroupModel.last_processed_at + < get_current_datetime() - MIN_PROCESSING_INTERVAL, + ) + .options(load_only(ComputeGroupModel.id)) + .order_by(ComputeGroupModel.last_processed_at.asc()) + .limit(1) + .with_for_update(skip_locked=True, key_share=True) + ) + 
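# Pick the least recently processed group; skip_locked avoids blocking on rows already locked by another server replica.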
compute_group_model = res.scalar() + if compute_group_model is None: + return + compute_group_model_id = compute_group_model.id + lockset.add(compute_group_model_id) + try: + await _process_compute_group( + session=session, + compute_group_model=compute_group_model, + ) + finally: + lockset.difference_update([compute_group_model_id]) + + +async def _process_compute_group(session: AsyncSession, compute_group_model: ComputeGroupModel): + # Refetch to load related attributes. + res = await session.execute( + select(ComputeGroupModel) + .where(ComputeGroupModel.id == compute_group_model.id) + .options( + joinedload(ComputeGroupModel.instances), + joinedload(ComputeGroupModel.project).joinedload(ProjectModel.backends), + ) + .execution_options(populate_existing=True) + ) + compute_group_model = res.unique().scalar_one() + if all(i.status == InstanceStatus.TERMINATING for i in compute_group_model.instances): + await _terminate_compute_group(compute_group_model) + compute_group_model.last_processed_at = get_current_datetime() + await session.commit() + + +async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> None: + if ( + compute_group_model.last_termination_retry_at is not None + and _next_termination_retry_at(compute_group_model) > get_current_datetime() + ): + return + compute_group = compute_group_model_to_compute_group(compute_group_model) + cgpd = compute_group.provisioning_data + backend = await backends_services.get_project_backend_by_type( + project=compute_group_model.project, + backend_type=cgpd.backend, + ) + if backend is None: + logger.error( + "Failed to terminate compute group %s. Backend %s not available.", + compute_group.name, + cgpd.backend, + ) + else: + logger.debug("Terminating compute group %s", compute_group.name) + compute = backend.compute() + assert isinstance(compute, ComputeWithGroupProvisioningSupport) + try: + await run_async( + compute.terminate_compute_group, + compute_group, + ) + except Exception as e: + if compute_group_model.first_termination_retry_at is None: + compute_group_model.first_termination_retry_at = get_current_datetime() + compute_group_model.last_termination_retry_at = get_current_datetime() + if _next_termination_retry_at(compute_group_model) < _get_termination_deadline( + compute_group_model + ): + logger.warning( + "Failed to terminate compute group %s. Will retry. Error: %r", + compute_group.name, + e, + exc_info=not isinstance(e, BackendError), + ) + return + logger.error( + "Failed all attempts to terminate compute group %s." + " Please terminate it manually to avoid unexpected charges." + " Error: %r", + compute_group.name, + e, + exc_info=not isinstance(e, BackendError), + ) + + compute_group_model.deleted = True + compute_group_model.deleted_at = get_current_datetime() + compute_group_model.status = ComputeGroupStatus.TERMINATED + # Terminating instances belonging to a compute group are locked implicitly + # by locking the compute group. 
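+ # (process_instances explicitly skips terminating instances that have a compute_group_id.)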
+ for instance_model in compute_group_model.instances: + instance_model.deleted = True + instance_model.deleted_at = get_current_datetime() + instance_model.finished_at = get_current_datetime() + instance_model.status = InstanceStatus.TERMINATED + logger.info( + "Terminated compute group %s", + compute_group.name, + ) + + +def _next_termination_retry_at(compute_group_model: ComputeGroupModel) -> datetime.datetime: + assert compute_group_model.last_termination_retry_at is not None + return compute_group_model.last_termination_retry_at + TERMINATION_RETRY_TIMEOUT + + +def _get_termination_deadline(compute_group_model: ComputeGroupModel) -> datetime.datetime: + assert compute_group_model.first_termination_retry_at is not None + return compute_group_model.first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 7a0815f3b0..a2c9f47420 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -8,7 +8,7 @@ from paramiko.pkey import PKey from paramiko.ssh_exception import PasswordRequiredException from pydantic import ValidationError -from sqlalchemy import delete, func, select +from sqlalchemy import and_, delete, func, not_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -57,7 +57,6 @@ ) from dstack._internal.core.models.runs import ( JobProvisioningData, - Retry, ) from dstack._internal.server import settings as server_settings from dstack._internal.server.background.tasks.common import get_provisioning_timeout @@ -167,6 +166,14 @@ async def _process_next_instance(): InstanceStatus.TERMINATING, ] ), + # Terminating instances belonging to a compute group + # are handled by process_compute_groups. + not_( + and_( + InstanceModel.status == InstanceStatus.TERMINATING, + InstanceModel.compute_group_id.is_not(None), + ) + ), InstanceModel.id.not_in(lockset), InstanceModel.last_processed_at < get_current_datetime() - MIN_PROCESSING_INTERVAL, @@ -918,51 +925,48 @@ async def _terminate(instance: InstanceModel) -> None: ): return jpd = get_instance_provisioning_data(instance) - if jpd is not None: - if jpd.backend != BackendType.REMOTE: - backend = await backends_services.get_project_backend_by_type( - project=instance.project, backend_type=jpd.backend + if jpd is not None and jpd.backend != BackendType.REMOTE: + backend = await backends_services.get_project_backend_by_type( + project=instance.project, backend_type=jpd.backend + ) + if backend is None: + logger.error( + "Failed to terminate instance %s. Backend %s not available.", + instance.name, + jpd.backend, ) - if backend is None: + else: + logger.debug("Terminating runner instance %s", jpd.hostname) + try: + await run_async( + backend.compute().terminate_instance, + jpd.instance_id, + jpd.region, + jpd.backend_data, + ) + except Exception as e: + if instance.first_termination_retry_at is None: + instance.first_termination_retry_at = get_current_datetime() + instance.last_termination_retry_at = get_current_datetime() + if _next_termination_retry_at(instance) < _get_termination_deadline(instance): + if isinstance(e, NotYetTerminated): + logger.debug("Instance %s termination in progress: %s", instance.name, e) + else: + logger.warning( + "Failed to terminate instance %s. Will retry. 
Error: %r", + instance.name, + e, + exc_info=not isinstance(e, BackendError), + ) + return logger.error( - "Failed to terminate instance %s. Backend %s not available.", + "Failed all attempts to terminate instance %s." + " Please terminate the instance manually to avoid unexpected charges." + " Error: %r", instance.name, - jpd.backend, + e, + exc_info=not isinstance(e, BackendError), ) - else: - logger.debug("Terminating runner instance %s", jpd.hostname) - try: - await run_async( - backend.compute().terminate_instance, - jpd.instance_id, - jpd.region, - jpd.backend_data, - ) - except Exception as e: - if instance.first_termination_retry_at is None: - instance.first_termination_retry_at = get_current_datetime() - instance.last_termination_retry_at = get_current_datetime() - if _next_termination_retry_at(instance) < _get_termination_deadline(instance): - if isinstance(e, NotYetTerminated): - logger.debug( - "Instance %s termination in progress: %s", instance.name, e - ) - else: - logger.warning( - "Failed to terminate instance %s. Will retry. Error: %r", - instance.name, - e, - exc_info=not isinstance(e, BackendError), - ) - return - logger.error( - "Failed all attempts to terminate instance %s." - " Please terminate the instance manually to avoid unexpected charges." - " Error: %r", - instance.name, - e, - exc_info=not isinstance(e, BackendError), - ) instance.deleted = True instance.deleted_at = get_current_datetime() @@ -1126,10 +1130,6 @@ def _get_instance_idle_duration(instance: InstanceModel) -> datetime.timedelta: return get_current_datetime() - last_time -def _get_retry_duration_deadline(instance: InstanceModel, retry: Retry) -> datetime.datetime: - return instance.created_at + timedelta(seconds=retry.duration) - - def _get_provisioning_deadline( instance: InstanceModel, job_provisioning_data: JobProvisioningData, diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 2814840b55..bc4183f689 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -3,16 +3,22 @@ import math import uuid from datetime import datetime, timedelta -from typing import List, Optional +from typing import List, Optional, Union from sqlalchemy import and_, func, not_, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, load_only, noload, selectinload from dstack._internal.core.backends.base.backend import Backend -from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport +from dstack._internal.core.backends.base.compute import ( + ComputeWithGroupProvisioningSupport, + ComputeWithVolumeSupport, +) +from dstack._internal.core.backends.base.models import JobConfiguration +from dstack._internal.core.backends.features import BACKENDS_WITH_GROUP_PROVISIONING_SUPPORT from dstack._internal.core.errors import BackendError, ServerClientError from dstack._internal.core.models.common import NetworkMode +from dstack._internal.core.models.compute_groups import ComputeGroupProvisioningData from dstack._internal.core.models.fleets import ( Fleet, FleetConfiguration, @@ -42,8 +48,10 @@ from dstack._internal.core.models.volumes import Volume from dstack._internal.core.services.profiles import get_termination from dstack._internal.server import settings +from dstack._internal.server.background.tasks.process_compute_groups import 
ComputeGroupStatus from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( + ComputeGroupModel, FleetModel, InstanceModel, JobModel, @@ -69,6 +77,7 @@ from dstack._internal.server.services.jobs import ( check_can_attach_job_volumes, find_job, + find_jobs, get_instances_ids_with_detaching_volumes, get_job_configured_volume_models, get_job_configured_volumes, @@ -132,6 +141,7 @@ async def _process_next_submitted_job(): .join(JobModel.run) .where( JobModel.status == JobStatus.SUBMITTED, + JobModel.waiting_master_job.is_not(True), JobModel.id.not_in(lockset), ) .options(load_only(JobModel.id)) @@ -190,6 +200,8 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): run_spec = run.run_spec run_profile = run_spec.merged_profile job = find_job(run.jobs, job_model.replica_num, job_model.job_num) + replica_jobs = find_jobs(run.jobs, replica_num=job_model.replica_num) + replica_job_models = _get_job_models_for_jobs(run_model.jobs, replica_jobs) multinode = job.job_spec.jobs_per_replica > 1 # Master job chooses fleet for the run. @@ -323,6 +335,10 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): return # If no instances were locked, we can proceed in the same transaction. + # TODO: Volume attachment for compute groups is not yet supported since + # currently supported compute groups (e.g. Runpod) don't need explicit volume attachment. + need_volume_attachment = True + if job_model.instance is not None: res = await session.execute( select(InstanceModel) @@ -333,7 +349,6 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): instance = res.unique().scalar_one() job_model.status = JobStatus.PROVISIONING else: - # Assigned no instance, create a new one if run_profile.creation_policy == CreationPolicy.REUSE: logger.debug("%s: reuse instance failed", fmt(job_model)) job_model.status = JobStatus.TERMINATING @@ -342,13 +357,23 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): await session.commit() return - # Create a new cloud instance - run_job_result = await _run_job_on_new_instance( + jobs_to_provision = [job] + if ( + multinode + and job.job_spec.job_num == 0 + # job_model.waiting_master_job is not set for legacy jobs. + # In this case compute group provisioning not supported + # and jobs always provision one-by-one. + and job_model.waiting_master_job is not None + ): + jobs_to_provision = replica_jobs + + run_job_result = await _run_jobs_on_new_instances( project=project, fleet_model=fleet_model, job_model=job_model, run=run, - job=job, + jobs=jobs_to_provision, project_ssh_public_key=project.ssh_public_key, project_ssh_private_key=project.ssh_private_key, master_job_provisioning_data=master_job_provisioning_data, @@ -362,72 +387,102 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): await session.commit() return - logger.info("%s: now is provisioning a new instance", fmt(job_model)) - job_provisioning_data, offer, effective_profile, _ = run_job_result - job_model.job_provisioning_data = job_provisioning_data.json() - job_model.status = JobStatus.PROVISIONING if fleet_model is None: fleet_model = await _create_fleet_model_for_job( session=session, project=project, run=run, ) - # FIXME: Fleet is not locked which may lead to duplicate instance_num. - # This is currently hard to fix without locking the fleet for entire provisioning duration. 
- # Processing should be done in multiple steps so that - # InstanceModel is created before provisioning. - instance_num = await _get_next_instance_num( - session=session, - fleet_model=fleet_model, - ) - instance = _create_instance_model_for_job( - project=project, - fleet_model=fleet_model, - job_model=job_model, - job_provisioning_data=job_provisioning_data, - offer=offer, - instance_num=instance_num, - profile=effective_profile, - ) - job_model.job_runtime_data = _prepare_job_runtime_data(offer, multinode).json() - # Both this task and process_fleets can add instances to fleets. - # TODO: Ensure this does not violate nodes.max when it's enforced. - instance.fleet_id = fleet_model.id - logger.info( - "The job %s created the new instance %s", - job_model.job_name, - instance.name, - extra={ - "instance_name": instance.name, - "instance_status": InstanceStatus.PROVISIONING.value, - }, - ) - session.add(instance) - session.add(fleet_model) - job_model.used_instance_id = instance.id - - volumes_ids = sorted([v.id for vs in volume_models for v in vs]) - # TODO: lock instances for attaching volumes? - # Take lock to prevent attaching volumes that are to be deleted. - # If the volume was deleted before the lock, the volume will fail to attach and the job will fail. - await session.execute( - select(VolumeModel) - .where(VolumeModel.id.in_(volumes_ids)) - .options(joinedload(VolumeModel.user).load_only(UserModel.name)) - .order_by(VolumeModel.id) # take locks in order - .with_for_update(key_share=True, of=VolumeModel) - ) - async with get_locker(get_db().dialect_name).lock_ctx(VolumeModel.__tablename__, volumes_ids): - if len(volume_models) > 0: - await _attach_volumes( + session.add(fleet_model) + + provisioning_data, offer, effective_profile, _ = run_job_result + compute_group_model = None + if isinstance(provisioning_data, ComputeGroupProvisioningData): + need_volume_attachment = False + provisioned_jobs = jobs_to_provision + jpds = provisioning_data.job_provisioning_datas + compute_group_model = ComputeGroupModel( + id=uuid.uuid4(), + project=project, + fleet=fleet_model, + status=ComputeGroupStatus.RUNNING, + provisioning_data=provisioning_data.json(), + ) + session.add(compute_group_model) + else: + provisioned_jobs = [job] + jpds = [provisioning_data] + if len(jobs_to_provision) > 1: + # Tried provisioning multiple jobs but provisioned only one. + # Allow other jobs to provision one-by-one. + for replica_job_model in replica_job_models: + replica_job_model.waiting_master_job = False + + logger.info("%s: provisioned %s new instance(s)", fmt(job_model), len(provisioned_jobs)) + provisioned_job_models = _get_job_models_for_jobs(run_model.jobs, provisioned_jobs) + instance = None # Instance for attaching volumes in case of single job provisioned + for provisioned_job_model, jpd in zip(provisioned_job_models, jpds): + provisioned_job_model.job_provisioning_data = jpd.json() + provisioned_job_model.status = JobStatus.PROVISIONING + # FIXME: Fleet is not locked which may lead to duplicate instance_num. + # This is currently hard to fix without locking the fleet for entire provisioning duration. + # Processing should be done in multiple steps so that + # InstanceModel is created before provisioning. 
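+ # Each provisioned job gets its own InstanceModel in the same fleet.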
+ instance_num = await _get_next_instance_num( session=session, + fleet_model=fleet_model, + ) + instance = _create_instance_model_for_job( project=project, + fleet_model=fleet_model, + compute_group_model=compute_group_model, + job_model=provisioned_job_model, + job_provisioning_data=jpd, + offer=offer, + instance_num=instance_num, + profile=effective_profile, ) + provisioned_job_model.job_runtime_data = _prepare_job_runtime_data( + offer, multinode + ).json() + logger.info( + "Created a new instance %s for job %s", + instance.name, + provisioned_job_model.job_name, + extra={ + "instance_name": instance.name, + "instance_status": InstanceStatus.PROVISIONING.value, + }, + ) + session.add(instance) + provisioned_job_model.used_instance_id = instance.id + provisioned_job_model.last_processed_at = common_utils.get_current_datetime() + + volumes_ids = sorted([v.id for vs in volume_models for v in vs]) + if need_volume_attachment: + # TODO: Lock instances for attaching volumes? + # Take lock to prevent attaching volumes that are to be deleted. + # If the volume was deleted before the lock, the volume will fail to attach and the job will fail. + await session.execute( + select(VolumeModel) + .where(VolumeModel.id.in_(volumes_ids)) + .options(joinedload(VolumeModel.user).load_only(UserModel.name)) + .order_by(VolumeModel.id) # take locks in order + .with_for_update(key_share=True, of=VolumeModel) + ) + async with get_locker(get_db().dialect_name).lock_ctx( + VolumeModel.__tablename__, volumes_ids + ): + if len(volume_models) > 0: + assert instance is not None + await _attach_volumes( + session=session, + project=project, + job_model=job_model, + instance=instance, + volume_models=volume_models, + ) + await session.commit() async def _select_fleet_models( @@ -553,10 +608,9 @@ async def _find_optimal_fleet_with_offers( except ValueError: fleet_backend_offers = [] else: - multinode = ( - candidate_fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER - or job.job_spec.jobs_per_replica > 1 - ) + # Handle multinode for old jobs that don't have requirements.multinode set. + # TODO: Drop multinode param. + multinode = requirements.multinode or job.job_spec.jobs_per_replica > 1 fleet_backend_offers = await get_offers_by_requirements( project=project, profile=profile, @@ -728,19 +782,33 @@ async def _assign_job_to_fleet_instance( return instance -async def _run_job_on_new_instance( +async def _run_jobs_on_new_instances( project: ProjectModel, job_model: JobModel, run: Run, - job: Job, + jobs: list[Job], project_ssh_public_key: str, project_ssh_private_key: str, master_job_provisioning_data: Optional[JobProvisioningData] = None, - volumes: Optional[List[List[Volume]]] = None, + volumes: Optional[list[list[Volume]]] = None, fleet_model: Optional[FleetModel] = None, -) -> Optional[tuple[JobProvisioningData, InstanceOfferWithAvailability, Profile, Requirements]]: +) -> Optional[ + tuple[ + Union[JobProvisioningData, ComputeGroupProvisioningData], + InstanceOfferWithAvailability, + Profile, + Requirements, + ] +]: + """ + Provisions an instance for a job or a compute group for multiple jobs and runs the jobs. + Even when multiple jobs are passed, it may still provision only one instance + and run only the master job in case there are no offers supporting group provisioning. + Other jobs should be provisioned one-by-one later.
+ """ if volumes is None: volumes = [] + job = jobs[0] profile = run.run_spec.merged_profile requirements = job.job_spec.requirements fleet = None @@ -758,9 +826,7 @@ async def _run_job_on_new_instance( return None # TODO: Respect fleet provisioning properties such as tags - multinode = job.job_spec.jobs_per_replica > 1 or ( - fleet is not None and fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER - ) + multinode = requirements.multinode or job.job_spec.jobs_per_replica > 1 offers = await get_offers_by_requirements( project=project, profile=profile, @@ -784,17 +850,31 @@ async def _run_job_on_new_instance( offer.price, ) offer_volumes = _get_offer_volumes(volumes, offer) + job_configurations = [JobConfiguration(job=j, volumes=offer_volumes) for j in jobs] + compute = backend.compute() try: - job_provisioning_data = await common_utils.run_async( - backend.compute().run_job, - run, - job, - offer, - project_ssh_public_key, - project_ssh_private_key, - offer_volumes, - ) - return job_provisioning_data, offer, profile, requirements + if len(jobs) > 1 and offer.backend in BACKENDS_WITH_GROUP_PROVISIONING_SUPPORT: + assert isinstance(compute, ComputeWithGroupProvisioningSupport) + cgpd = await common_utils.run_async( + compute.run_jobs, + run, + job_configurations, + offer, + project_ssh_public_key, + project_ssh_private_key, + ) + return cgpd, offer, profile, requirements + else: + jpd = await common_utils.run_async( + compute.run_job, + run, + job, + offer, + project_ssh_public_key, + project_ssh_private_key, + offer_volumes, + ) + return jpd, offer, profile, requirements except BackendError as e: logger.warning( "%s: %s launch in %s/%s failed: %s", @@ -912,6 +992,7 @@ async def _get_next_instance_num(session: AsyncSession, fleet_model: FleetModel) def _create_instance_model_for_job( project: ProjectModel, fleet_model: FleetModel, + compute_group_model: Optional[ComputeGroupModel], job_model: JobModel, job_provisioning_data: JobProvisioningData, offer: InstanceOfferWithAvailability, @@ -931,6 +1012,8 @@ def _create_instance_model_for_job( name=f"{fleet_model.name}-{instance_num}", instance_num=instance_num, project=project, + fleet=fleet_model, + compute_group=compute_group_model, created_at=common_utils.get_current_datetime(), started_at=common_utils.get_current_datetime(), status=InstanceStatus.PROVISIONING, @@ -1081,3 +1164,15 @@ async def _attach_volume( instance.volume_attachments.append(volume_attachment_model) volume_model.last_job_processed_at = common_utils.get_current_datetime() + + +def _get_job_models_for_jobs( + job_models: list[JobModel], + jobs: list[Job], +) -> list[JobModel]: + """ + Returns job models of latest submissions for a list of jobs. + Preserves jobs order. 
+ """ + id_to_job_model_map = {jm.id: jm for jm in job_models} + return [id_to_job_model_map[j.job_submissions[-1].id] for j in jobs] diff --git a/src/dstack/_internal/server/migrations/env.py b/src/dstack/_internal/server/migrations/env.py index 0b2f73a19c..81d8ba0694 100644 --- a/src/dstack/_internal/server/migrations/env.py +++ b/src/dstack/_internal/server/migrations/env.py @@ -6,7 +6,7 @@ from sqlalchemy import Connection, MetaData, text from dstack._internal.server.db import get_db -from dstack._internal.server.models import BaseModel +from dstack._internal.server.models import BaseModel, EnumAsString config = context.config @@ -21,6 +21,14 @@ def set_target_metadata(metadata: MetaData): target_metadata = metadata +def render_item(type_, obj, autogen_context): + """Apply custom rendering for selected items.""" + if type_ == "type" and isinstance(obj, EnumAsString): + return f"sa.String(length={obj.length})" + # default rendering for other objects + return False + + def run_migrations_offline(): """Run migrations in 'offline' mode. This configures the context with just a URL @@ -35,6 +43,7 @@ def run_migrations_offline(): target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}, + render_item=render_item, ) with context.begin_transaction(): context.run_migrations() @@ -71,6 +80,7 @@ def run_migrations(connection: Connection): target_metadata=target_metadata, compare_type=True, render_as_batch=True, + render_item=render_item, # Running each migration in a separate transaction. # Running all migrations in one transaction may lead to deadlocks in HA deployments # because lock ordering is not respected across all migrations. diff --git a/src/dstack/_internal/server/migrations/versions/7d1ec2b920ac_add_computegroupmodel.py b/src/dstack/_internal/server/migrations/versions/7d1ec2b920ac_add_computegroupmodel.py new file mode 100644 index 0000000000..94eccec2b4 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/7d1ec2b920ac_add_computegroupmodel.py @@ -0,0 +1,93 @@ +"""Add ComputeGroupModel + +Revision ID: 7d1ec2b920ac +Revises: ff1d94f65b08 +Create Date: 2025-10-21 16:01:23.739395 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "7d1ec2b920ac" +down_revision = "ff1d94f65b08" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "compute_groups", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("fleet_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.Column("status", sa.String(length=100), nullable=False), + sa.Column( + "last_processed_at", dstack._internal.server.models.NaiveDateTime(), nullable=False + ), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", dstack._internal.server.models.NaiveDateTime(), nullable=True), + sa.Column("provisioning_data", sa.Text(), nullable=False), + sa.Column( + "first_termination_retry_at", + dstack._internal.server.models.NaiveDateTime(), + nullable=True, + ), + sa.Column( + "last_termination_retry_at", + dstack._internal.server.models.NaiveDateTime(), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["fleet_id"], ["fleets.id"], name=op.f("fk_compute_groups_fleet_id_fleets") + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_compute_groups_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_compute_groups")), + ) + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "compute_group_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=True, + ) + ) + batch_op.create_foreign_key( + batch_op.f("fk_instances_compute_group_id_compute_groups"), + "compute_groups", + ["compute_group_id"], + ["id"], + ) + + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column(sa.Column("waiting_master_job", sa.Boolean(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("waiting_master_job") + + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_instances_compute_group_id_compute_groups"), type_="foreignkey" + ) + batch_op.drop_column("compute_group_id") + + op.drop_table("compute_groups") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index c6d97b810e..e88f83d599 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -25,6 +25,7 @@ from dstack._internal.core.errors import DstackError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreConfig, generate_dual_core_model +from dstack._internal.core.models.compute_groups import ComputeGroupStatus from dstack._internal.core.models.fleets import FleetStatus from dstack._internal.core.models.gateways import GatewayStatus from dstack._internal.core.models.health import HealthStatus @@ -448,6 +449,12 @@ class JobModel(BaseModel): # Whether the replica is registered to receive service requests. # Always `False` for non-service runs. registered: Mapped[bool] = mapped_column(Boolean, server_default=false()) + # `waiting_master_job` is `True` for non-master jobs that have to wait + # for master processing before they can be processed. + # This allows updating all replica jobs even when only master is locked, + # e.g. 
to provision instances for all jobs when processing master. + # If not set, all jobs should be processed only one-by-one. + waiting_master_job: Mapped[Optional[bool]] = mapped_column(Boolean) class GatewayModel(BaseModel): @@ -592,6 +599,9 @@ class InstanceModel(BaseModel): fleet_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("fleets.id")) fleet: Mapped[Optional["FleetModel"]] = relationship(back_populates="instances") + compute_group_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("compute_groups.id")) + compute_group: Mapped[Optional["ComputeGroupModel"]] = relationship(back_populates="instances") + status: Mapped[InstanceStatus] = mapped_column(EnumAsString(InstanceStatus, 100), index=True) unreachable: Mapped[bool] = mapped_column(Boolean) @@ -743,6 +753,35 @@ class PlacementGroupModel(BaseModel): provisioning_data: Mapped[Optional[str]] = mapped_column(Text) +class ComputeGroupModel(BaseModel): + __tablename__ = "compute_groups" + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + + project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) + project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) + + fleet_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("fleets.id")) + fleet: Mapped["FleetModel"] = relationship(foreign_keys=[fleet_id]) + + created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + status: Mapped[ComputeGroupStatus] = mapped_column(EnumAsString(ComputeGroupStatus, 100)) + last_processed_at: Mapped[datetime] = mapped_column( + NaiveDateTime, default=get_current_datetime + ) + deleted: Mapped[bool] = mapped_column(Boolean, default=False) + deleted_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + + provisioning_data: Mapped[str] = mapped_column(Text) + + first_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + last_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + + instances: Mapped[List["InstanceModel"]] = relationship(back_populates="compute_group") + + class JobMetricsPoint(BaseModel): __tablename__ = "job_metrics_points" diff --git a/src/dstack/_internal/server/services/compute_groups.py b/src/dstack/_internal/server/services/compute_groups.py new file mode 100644 index 0000000000..4d759e0d21 --- /dev/null +++ b/src/dstack/_internal/server/services/compute_groups.py @@ -0,0 +1,22 @@ +from dstack._internal.core.models.compute_groups import ComputeGroup, ComputeGroupProvisioningData +from dstack._internal.server.models import ComputeGroupModel + + +def compute_group_model_to_compute_group(compute_group_model: ComputeGroupModel) -> ComputeGroup: + provisioning_data = get_compute_group_provisioning_data(compute_group_model) + return ComputeGroup( + id=compute_group_model.id, + project_name=compute_group_model.project.name, + status=compute_group_model.status, + name=provisioning_data.compute_group_name, + created_at=compute_group_model.created_at, + provisioning_data=provisioning_data, + ) + + +def get_compute_group_provisioning_data( + compute_group_model: ComputeGroupModel, +) -> ComputeGroupProvisioningData: + return ComputeGroupProvisioningData.__response__.parse_raw( + compute_group_model.provisioning_data + ) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 277ed41b32..0e3aaf2d4b 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ 
b/src/dstack/_internal/server/services/fleets.py @@ -650,6 +650,7 @@ def get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements: max_price=profile.max_price, spot=get_policy_map(profile.spot_policy, default=SpotPolicy.ONDEMAND), reservation=fleet_spec.configuration.reservation, + multinode=fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER, ) return requirements diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index cbb089b2c5..4d4f75e75a 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -96,6 +96,19 @@ def find_job(jobs: List[Job], replica_num: int, job_num: int) -> Job: ) +def find_jobs( + jobs: List[Job], + replica_num: Optional[int] = None, + job_num: Optional[int] = None, +) -> list[Job]: + res = jobs + if replica_num is not None: + res = [j for j in res if j.job_spec.replica_num == replica_num] + if job_num is not None: + res = [j for j in res if j.job_spec.job_num == job_num] + return res + + async def get_run_job_model( session: AsyncSession, project: ProjectModel, diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 02cdc70b3c..18f6f14e07 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -161,7 +161,7 @@ async def _get_job_spec( stop_duration=self._stop_duration(), utilization_policy=self._utilization_policy(), registry_auth=self._registry_auth(), - requirements=self._requirements(), + requirements=self._requirements(jobs_per_replica), retry=self._retry(), working_dir=self._working_dir(), volumes=self._volumes(job_num), @@ -295,13 +295,14 @@ def _utilization_policy(self) -> Optional[UtilizationPolicy]: def _registry_auth(self) -> Optional[RegistryAuth]: return self.run_spec.configuration.registry_auth - def _requirements(self) -> Requirements: + def _requirements(self, jobs_per_replica: int) -> Requirements: spot_policy = self._spot_policy() return Requirements( resources=self.run_spec.configuration.resources, max_price=self.run_spec.merged_profile.max_price, spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT), reservation=self.run_spec.merged_profile.reservation, + multinode=jobs_per_replica > 1, ) def _retry(self) -> Optional[Retry]: diff --git a/src/dstack/_internal/server/services/requirements/combine.py b/src/dstack/_internal/server/services/requirements/combine.py index a5830ddef7..e090601f92 100644 --- a/src/dstack/_internal/server/services/requirements/combine.py +++ b/src/dstack/_internal/server/services/requirements/combine.py @@ -63,6 +63,7 @@ def combine_fleet_and_run_requirements( reservation=_get_single_value_optional( fleet_requirements.reservation, run_requirements.reservation ), + multinode=fleet_requirements.multinode or run_requirements.multinode, ) except CombineError: return None diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 14ad4c1c86..ed64aa7219 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -609,6 +609,7 @@ def create_job_model_for_new_submission( job_spec_data=job.job_spec.json(), job_provisioning_data=None, probes=[], + waiting_master_job=job.job_spec.job_num != 0, ) diff --git a/src/dstack/_internal/server/testing/common.py 
b/src/dstack/_internal/server/testing/common.py index e6de272911..883ce14535 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -13,6 +13,7 @@ Compute, ComputeWithCreateInstanceSupport, ComputeWithGatewaySupport, + ComputeWithGroupProvisioningSupport, ComputeWithMultinodeSupport, ComputeWithPlacementGroupSupport, ComputeWithPrivateGatewaySupport, @@ -22,6 +23,10 @@ ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode +from dstack._internal.core.models.compute_groups import ( + ComputeGroupProvisioningData, + ComputeGroupStatus, +) from dstack._internal.core.models.configurations import ( AnyRunConfiguration, DevEnvironmentConfiguration, @@ -83,6 +88,7 @@ ) from dstack._internal.server.models import ( BackendModel, + ComputeGroupModel, DecryptedString, FileArchiveModel, FleetModel, @@ -353,6 +359,7 @@ async def create_job( instance_assigned: bool = False, disconnected_at: Optional[datetime] = None, registered: bool = False, + waiting_master_job: Optional[bool] = None, ) -> JobModel: if deployment_num is None: deployment_num = run.deployment_num @@ -384,6 +391,7 @@ async def create_job( disconnected_at=disconnected_at, probes=[], registered=registered, + waiting_master_job=waiting_master_job, ) session.add(job) await session.commit() @@ -455,6 +463,48 @@ def get_job_runtime_data( ) +def get_compute_group_provisioning_data( + compute_group_id: str = "test_compute_group", + compute_group_name: str = "test_compute_group", + backend: BackendType = BackendType.RUNPOD, + region: str = "US", + job_provisioning_datas: Optional[list[JobProvisioningData]] = None, + backend_data: Optional[str] = None, +) -> ComputeGroupProvisioningData: + if job_provisioning_datas is None: + job_provisioning_datas = [] + return ComputeGroupProvisioningData( + compute_group_id=compute_group_id, + compute_group_name=compute_group_name, + backend=backend, + region=region, + job_provisioning_datas=job_provisioning_datas, + backend_data=backend_data, + ) + + +async def create_compute_group( + session: AsyncSession, + project: ProjectModel, + fleet: FleetModel, + status: ComputeGroupStatus = ComputeGroupStatus.RUNNING, + provisioning_data: Optional[ComputeGroupProvisioningData] = None, + last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), +): + if provisioning_data is None: + provisioning_data = get_compute_group_provisioning_data() + compute_group = ComputeGroupModel( + project=project, + fleet=fleet, + status=status, + provisioning_data=provisioning_data.json(), + last_processed_at=last_processed_at, + ) + session.add(compute_group) + await session.commit() + return compute_group + + async def create_probe( session: AsyncSession, job: JobModel, @@ -1136,6 +1186,7 @@ async def __aexit__(self, exc_type, exc, traceback): class ComputeMockSpec( Compute, ComputeWithCreateInstanceSupport, + ComputeWithGroupProvisioningSupport, ComputeWithPrivilegedSupport, ComputeWithMultinodeSupport, ComputeWithReservationSupport, diff --git a/src/dstack/_internal/settings.py b/src/dstack/_internal/settings.py index 2c0035bad4..a97abf5e2a 100644 --- a/src/dstack/_internal/settings.py +++ b/src/dstack/_internal/settings.py @@ -34,3 +34,5 @@ class FeatureFlags: large features. This class may be empty if there are no such features in development. 
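
The `create_compute_group` factory above persists `provisioning_data` as serialized JSON in a `Text` column, mirroring what `get_compute_group_provisioning_data` parses back on the service side. A minimal sketch of that round-trip, using a plain pydantic model as a stand-in for `ComputeGroupProvisioningData` (whose real base is dstack's `CoreModel`); the field names follow the factory above:

```python
from typing import List, Optional

from pydantic import BaseModel


class ComputeGroupProvisioningDataStub(BaseModel):
    # Stand-in for ComputeGroupProvisioningData; in dstack itself,
    # `backend` is a BackendType enum and `job_provisioning_datas`
    # holds JobProvisioningData models.
    compute_group_id: str
    compute_group_name: str
    backend: str
    region: str
    job_provisioning_datas: List[dict] = []
    backend_data: Optional[str] = None


# What the factory stores in the Text column...
stored = ComputeGroupProvisioningDataStub(
    compute_group_id="test_compute_group",
    compute_group_name="test_compute_group",
    backend="runpod",
    region="US",
).json()

# ...and what the service reads back via parse_raw().
parsed = ComputeGroupProvisioningDataStub.parse_raw(stored)
assert parsed.compute_group_name == "test_compute_group"
```

Storing the serialized model rather than normalized columns keeps the schema stable while the provisioning payload evolves with backends.
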
Feature flags are environment variables of the form DSTACK_FF_* """ + + pass diff --git a/src/tests/_internal/server/background/tasks/test_process_compute_groups.py b/src/tests/_internal/server/background/tasks/test_process_compute_groups.py new file mode 100644 index 0000000000..11ce734606 --- /dev/null +++ b/src/tests/_internal/server/background/tasks/test_process_compute_groups.py @@ -0,0 +1,83 @@ +from datetime import datetime, timezone +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.server.background.tasks.process_compute_groups import ( + ComputeGroupStatus, + process_compute_groups, +) +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_compute_group, + create_fleet, + create_project, +) + + +class TestProcessComputeGroups: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_terminates_compute_group(self, test_db, session: AsyncSession): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + compute_group = await create_compute_group( + session=session, + project=project, + fleet=fleet, + ) + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + await process_compute_groups() + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status == ComputeGroupStatus.TERMINATED + assert compute_group.deleted + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_retries_compute_group_termination(self, test_db, session: AsyncSession): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + compute_group = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc), + ) + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + compute_mock.terminate_compute_group.side_effect = BackendError() + await process_compute_groups() + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status != ComputeGroupStatus.TERMINATED + assert compute_group.first_termination_retry_at is not None + assert compute_group.last_termination_retry_at is not None + # Simulate termination deadline exceeded + compute_group.first_termination_retry_at = datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc) + compute_group.last_termination_retry_at = datetime(2023, 1, 2, 4, 0, tzinfo=timezone.utc) + compute_group.last_processed_at = datetime(2023, 1, 2, 4, 0, tzinfo=timezone.utc) + await session.commit() + with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = 
compute_mock + m.return_value = backend_mock + backend_mock.TYPE = BackendType.RUNPOD + compute_mock.terminate_compute_group.side_effect = BackendError() + await process_compute_groups() + compute_mock.terminate_compute_group.assert_called_once() + await session.refresh(compute_group) + assert compute_group.status == ComputeGroupStatus.TERMINATED diff --git a/src/tests/_internal/server/background/tasks/test_process_fleets.py b/src/tests/_internal/server/background/tasks/test_process_fleets.py index 4370f77b57..ae7155c3ca 100644 --- a/src/tests/_internal/server/background/tasks/test_process_fleets.py +++ b/src/tests/_internal/server/background/tasks/test_process_fleets.py @@ -20,7 +20,7 @@ ) -class TestProcessEmptyFleets: +class TestProcessFleets: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_deletes_empty_autocreated_fleet(self, test_db, session: AsyncSession): diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index b4ebf9fb59..545349e585 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -31,7 +31,12 @@ _prepare_job_runtime_data, process_submitted_jobs, ) -from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel +from dstack._internal.server.models import ( + ComputeGroupModel, + InstanceModel, + JobModel, + VolumeAttachmentModel, +) from dstack._internal.server.settings import JobNetworkMode from dstack._internal.server.testing.common import ( ComputeMockSpec, @@ -43,6 +48,7 @@ create_run, create_user, create_volume, + get_compute_group_provisioning_data, get_fleet_spec, get_instance_offer_with_availability, get_job_provisioning_data, @@ -1116,6 +1122,73 @@ async def test_picks_high_priority_jobs_first(self, test_db, session: AsyncSessi await session.refresh(job2) assert job2.status == JobStatus.PROVISIONING + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_provisions_compute_group(self, test_db, session: AsyncSession): + project = await create_project(session) + user = await create_user(session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project) + run_name = "test-run" + run_spec = get_run_spec( + repo_id=repo.name, + run_name=run_name, + ) + run_spec.configuration = TaskConfiguration(nodes=2, commands=["echo"]) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_name=run_name, + run_spec=run_spec, + ) + job1 = await create_job( + session=session, + run=run, + instance_assigned=True, + job_num=0, + status=JobStatus.SUBMITTED, + waiting_master_job=False, + ) + job2 = await create_job( + session=session, + run=run, + instance_assigned=False, + job_num=1, + status=JobStatus.SUBMITTED, + waiting_master_job=True, + ) + offer = get_instance_offer_with_availability( + backend=BackendType.RUNPOD, + availability=InstanceAvailability.AVAILABLE, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + compute_mock = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value = compute_mock + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.RUNPOD + compute_mock.get_offers.return_value = 
[offer] + jpds = [ + get_job_provisioning_data(), + get_job_provisioning_data(), + ] + compute_mock.run_jobs.return_value = get_compute_group_provisioning_data( + job_provisioning_datas=jpds + ) + await process_submitted_jobs() + m.assert_called_once() + compute_mock.get_offers.assert_called_once() + compute_mock.run_jobs.assert_called_once() + await session.refresh(job1) + await session.refresh(job2) + assert job1.status == JobStatus.PROVISIONING + assert job2.status == JobStatus.PROVISIONING + res = await session.execute(select(ComputeGroupModel)) + assert res.scalar() is not None + @pytest.mark.parametrize( ["job_network_mode", "blocks", "multinode", "network_mode", "constraints_are_set"], diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index f4e481f539..76eb4dbb51 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -251,6 +251,7 @@ def get_dev_env_run_plan_dict( "max_price": None, "spot": True, "reservation": None, + "multinode": False, }, "retry": None, "volumes": volumes, @@ -459,6 +460,7 @@ def get_dev_env_run_dict( "max_price": None, "spot": True, "reservation": None, + "multinode": False, }, "retry": None, "volumes": [], diff --git a/src/tests/_internal/server/services/test_instances.py b/src/tests/_internal/server/services/test_instances.py index 414a360c21..aa248aa485 100644 --- a/src/tests/_internal/server/services/test_instances.py +++ b/src/tests/_internal/server/services/test_instances.py @@ -58,12 +58,12 @@ async def test_returns_multinode_instances(self, test_db, session: AsyncSession) project=project, backend=BackendType.AWS, ) - runpod_instance = await create_instance( + vastai_instance = await create_instance( session=session, project=project, - backend=BackendType.RUNPOD, + backend=BackendType.VASTAI, ) - instances = [aws_instance, runpod_instance] + instances = [aws_instance, vastai_instance] res = instances_services.filter_pool_instances( pool_instances=instances, profile=Profile(name="test"), diff --git a/src/tests/_internal/server/services/test_offers.py b/src/tests/_internal/server/services/test_offers.py index 3e67bc7c3f..685369fc42 100644 --- a/src/tests/_internal/server/services/test_offers.py +++ b/src/tests/_internal/server/services/test_offers.py @@ -46,11 +46,11 @@ async def test_returns_multinode_offers(self): aws_backend_mock.TYPE = BackendType.AWS aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS) aws_backend_mock.compute.return_value.get_offers.return_value = [aws_offer] - runpod_backend_mock = Mock() - runpod_backend_mock.TYPE = BackendType.RUNPOD - runpod_offer = get_instance_offer_with_availability(backend=BackendType.RUNPOD) - runpod_backend_mock.compute.return_value.get_offers.return_value = [runpod_offer] - m.return_value = [aws_backend_mock, runpod_backend_mock] + vastai_backend_mock = Mock() + vastai_backend_mock.TYPE = BackendType.VASTAI + vastai_offer = get_instance_offer_with_availability(backend=BackendType.VASTAI) + vastai_backend_mock.compute.return_value.get_offers.return_value = [vastai_offer] + m.return_value = [aws_backend_mock, vastai_backend_mock] res = await get_offers_by_requirements( project=Mock(), profile=profile,
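
The RunPod-to-VastAI swaps in `test_instances.py` and `test_offers.py` follow from the feature itself: with `runpod-cluster` offers and group provisioning, RunPod presumably no longer serves as the example of a backend that multinode filtering excludes, so VastAI takes that role. A rough sketch of the filtering these tests exercise; the set contents and function name are illustrative, not dstack's actual implementation:

```python
from enum import Enum
from typing import List, Tuple


class BackendType(str, Enum):
    AWS = "aws"
    RUNPOD = "runpod"
    VASTAI = "vastai"


# Illustrative registry in the spirit of features.py; in dstack it is
# derived from which Compute classes subclass ComputeWithMultinodeSupport.
BACKENDS_WITH_MULTINODE_SUPPORT = {BackendType.AWS, BackendType.RUNPOD}

Offer = Tuple[BackendType, str]  # (backend, instance type)


def filter_multinode_offers(offers: List[Offer], multinode: bool) -> List[Offer]:
    # Non-multinode runs keep every offer; multinode runs keep only
    # offers from backends that can interconnect instances.
    if not multinode:
        return offers
    return [o for o in offers if o[0] in BACKENDS_WITH_MULTINODE_SUPPORT]


offers = [(BackendType.AWS, "p4d.24xlarge"), (BackendType.VASTAI, "rtx4090")]
assert filter_multinode_offers(offers, multinode=True) == offers[:1]
assert filter_multinode_offers(offers, multinode=False) == offers
```
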