From 921f18360066004096d9e22b9f27d54637e713fd Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Tue, 23 Sep 2025 14:44:51 +0000 Subject: [PATCH] Kubernetes: request resources according to RequirementsSpec Other fixes and improvements: * Handle errors in `_create_jump_pod_service_if_not_exists` * Check both Service and Pod to decide if the jump pod must be (re)created * Respect `Node.status.nodeinfo.architecture` * Add `namespace` option to the backend config Part-of: https://github.com/dstackai/dstack/issues/3126 --- pyproject.toml | 3 +- .../_internal/core/backends/base/compute.py | 32 +- .../core/backends/kubernetes/compute.py | 584 ++++++++++++------ .../core/backends/kubernetes/models.py | 17 +- .../core/backends/kubernetes/utils.py | 153 ++++- .../background/tasks/process_instances.py | 4 +- .../_internal/server/utils/provisioning.py | 13 +- .../core/backends/base/test_compute.py | 5 +- .../core/backends/kubernetes/test_compute.py | 8 +- .../core/backends/kubernetes/test_utils.py | 68 ++ .../tasks/test_process_instances.py | 3 +- 11 files changed, 643 insertions(+), 247 deletions(-) create mode 100644 src/tests/_internal/core/backends/kubernetes/test_utils.py diff --git a/pyproject.toml b/pyproject.toml index 35f3ad4565..23340efb19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,8 +83,9 @@ include = [ "src/dstack/plugins", "src/dstack/_internal/server", "src/dstack/_internal/core/services", - "src/dstack/_internal/cli/commands", + "src/dstack/_internal/core/backends/kubernetes", "src/dstack/_internal/cli/services/configurators", + "src/dstack/_internal/cli/commands", ] ignore = [ "src/dstack/_internal/server/migrations/versions", diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index bba603f901..b8cbac23e1 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -5,14 +5,16 @@ import threading from abc import ABC, 
abstractmethod from collections.abc import Iterable +from enum import Enum from functools import lru_cache from pathlib import Path -from typing import Callable, Dict, List, Literal, Optional +from typing import Callable, Dict, List, Optional import git import requests import yaml from cachetools import TTLCache, cachedmethod +from gpuhunt import CPUArchitecture from dstack._internal import settings from dstack._internal.core.backends.base.offers import filter_offers_by_requirements @@ -49,7 +51,21 @@ DSTACK_RUNNER_BINARY_NAME = "dstack-runner" DEFAULT_PRIVATE_SUBNETS = ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16") -GoArchType = Literal["amd64", "arm64"] + +class GoArchType(str, Enum): + """ + A subset of GOARCH values + """ + + AMD64 = "amd64" + ARM64 = "arm64" + + def to_cpu_architecture(self) -> CPUArchitecture: + if self == self.AMD64: + return CPUArchitecture.X86 + if self == self.ARM64: + return CPUArchitecture.ARM + assert False, self class Compute(ABC): @@ -688,14 +704,14 @@ def normalize_arch(arch: Optional[str] = None) -> GoArchType: If the arch is not specified, falls back to `amd64`. 
""" if not arch: - return "amd64" + return GoArchType.AMD64 arch_lower = arch.lower() if "32" in arch_lower or arch_lower in ["i386", "i686"]: raise ValueError(f"32-bit architectures are not supported: {arch}") if arch_lower.startswith("x86") or arch_lower.startswith("amd"): - return "amd64" + return GoArchType.AMD64 if arch_lower.startswith("arm") or arch_lower.startswith("aarch"): - return "arm64" + return GoArchType.ARM64 raise ValueError(f"Unsupported architecture: {arch}") @@ -711,8 +727,7 @@ def get_dstack_runner_download_url(arch: Optional[str] = None) -> str: "/{version}/binaries/dstack-runner-linux-{arch}" ) version = get_dstack_runner_version() - arch = normalize_arch(arch) - return url_template.format(version=version, arch=arch) + return url_template.format(version=version, arch=normalize_arch(arch).value) def get_dstack_shim_download_url(arch: Optional[str] = None) -> str: @@ -727,8 +742,7 @@ def get_dstack_shim_download_url(arch: Optional[str] = None) -> str: "/{version}/binaries/dstack-shim-linux-{arch}" ) version = get_dstack_runner_version() - arch = normalize_arch(arch) - return url_template.format(version=version, arch=arch) + return url_template.format(version=version, arch=normalize_arch(arch).value) def get_setup_cloud_instance_commands( diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 8307c7672c..f2bc714232 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -2,7 +2,7 @@ import tempfile import threading import time -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple from gpuhunt import KNOWN_NVIDIA_GPUS, AcceleratorVendor from kubernetes import client @@ -15,6 +15,7 @@ generate_unique_instance_name_for_job, get_docker_commands, get_dstack_gateway_commands, + normalize_arch, ) from dstack._internal.core.backends.base.offers import 
filter_offers_by_requirements from dstack._internal.core.backends.kubernetes.models import ( @@ -22,8 +23,10 @@ KubernetesNetworkingConfig, ) from dstack._internal.core.backends.kubernetes.utils import ( + call_api_method, get_api_from_config_data, get_cluster_public_ip, + get_value, ) from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ComputeError @@ -44,6 +47,7 @@ Resources, SSHConnectionParams, ) +from dstack._internal.core.models.resources import CPUSpec from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run from dstack._internal.core.models.volumes import Volume from dstack._internal.utils.common import parse_memory @@ -52,7 +56,6 @@ logger = get_logger(__name__) JUMP_POD_SSH_PORT = 22 -DEFAULT_NAMESPACE = "default" NVIDIA_GPU_NAME_TO_GPU_INFO = {gpu.name: gpu for gpu in KNOWN_NVIDIA_GPUS} NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO.keys() @@ -75,25 +78,43 @@ def __init__(self, config: KubernetesConfig): def get_offers_by_requirements( self, requirements: Requirements ) -> List[InstanceOfferWithAvailability]: - nodes = self.api.list_node() - instance_offers = [] - for node in nodes.items: + instance_offers: list[InstanceOfferWithAvailability] = [] + node_list = call_api_method( + self.api.list_node, + client.V1NodeList, + ) + nodes = get_value(node_list, ".items", list[client.V1Node], required=True) + for node in nodes: + try: + labels = get_value(node, ".metadata.labels", dict[str, str]) or {} + name = get_value(node, ".metadata.name", str, required=True) + cpus = _parse_cpu( + get_value(node, ".status.allocatable['cpu']", str, required=True) + ) + cpu_arch = normalize_arch( + get_value(node, ".status.node_info.architecture", str) + ).to_cpu_architecture() + memory_mib = _parse_memory( + get_value(node, ".status.allocatable['memory']", str, required=True) + ) + gpus, _ = _get_gpus_from_node_labels(labels) + disk_size_mib = _parse_memory( + get_value(node, 
".status.allocatable['ephemeral-storage']", str, required=True) + ) + except (AttributeError, KeyError, ValueError) as e: + logger.exception("Failed to process node: %s: %s", type(e).__name__, e) + continue instance_offer = InstanceOfferWithAvailability( backend=BackendType.KUBERNETES, instance=InstanceType( - name=node.metadata.name, + name=name, resources=Resources( - cpus=node.status.capacity["cpu"], - memory_mib=int(parse_memory(node.status.capacity["memory"], as_untis="M")), - gpus=_get_gpus_from_node_labels(node.metadata.labels), + cpus=cpus, + cpu_arch=cpu_arch, + memory_mib=memory_mib, + gpus=gpus, spot=False, - disk=Disk( - size_mib=int( - parse_memory( - node.status.capacity["ephemeral-storage"], as_untis="M" - ) - ) - ), + disk=Disk(size_mib=disk_size_mib), ), ), price=0, @@ -132,6 +153,7 @@ def run_job( ) jump_pod_port, created = _create_jump_pod_service_if_not_exists( api=self.api, + namespace=self.config.namespace, project_name=run.project_name, ssh_public_keys=[project_ssh_public_key.strip(), run.run_spec.ssh_key_pub.strip()], jump_pod_port=self.networking_config.ssh_port, @@ -141,6 +163,7 @@ def run_job( target=_continue_setup_jump_pod, kwargs={ "api": self.api, + "namespace": self.config.namespace, "project_name": run.project_name, "project_ssh_private_key": project_ssh_private_key.strip(), "user_ssh_public_key": run.run_spec.ssh_key_pub.strip(), @@ -148,41 +171,114 @@ def run_job( "jump_pod_port": jump_pod_port, }, ).start() - self.api.create_namespaced_pod( - namespace=DEFAULT_NAMESPACE, - body=client.V1Pod( - metadata=client.V1ObjectMeta( - name=instance_name, - labels={"app.kubernetes.io/name": instance_name}, - ), - spec=client.V1PodSpec( - containers=[ - client.V1Container( - name=f"{instance_name}-container", - image=job.job_spec.image_name, - command=["/bin/sh"], - args=["-c", " && ".join(commands)], - ports=[ - client.V1ContainerPort( - container_port=DSTACK_RUNNER_SSH_PORT, - ) + resources_spec = job.job_spec.requirements.resources + 
assert isinstance(resources_spec.cpu, CPUSpec) + resources_requests: dict[str, str] = {} + resources_limits: dict[str, str] = {} + node_affinity: Optional[client.V1NodeAffinity] = None + if (cpu_min := resources_spec.cpu.count.min) is not None: + resources_requests["cpu"] = str(cpu_min) + if (gpu_spec := resources_spec.gpu) is not None: + gpu_min = gpu_spec.count.min + if gpu_min is not None and gpu_min > 0: + if not (offer_gpus := instance_offer.instance.resources.gpus): + raise ComputeError( + "GPU is requested but the offer has no GPUs:" + f" {gpu_spec=} {instance_offer=}", + ) + offer_gpu = offer_gpus[0] + matching_gpu_label_values: set[str] = set() + # We cannot generate an expected GPU label value from the Gpu model instance + # as the actual values may have additional components (socket, memory type, etc.) + # that we don't preserve in the Gpu model, e.g., "NVIDIA-H100-80GB-HBM3". + # Moreover, a single Gpu may match multiple label values. + # As a workaround, we iterate and process all node labels once again (we already + # processed them in `get_offers_by_requirements()`). 
+ node_list = call_api_method( + self.api.list_node, + client.V1NodeList, + ) + nodes = get_value(node_list, ".items", list[client.V1Node], required=True) + for node in nodes: + labels = get_value(node, ".metadata.labels", dict[str, str]) + if not labels: + continue + gpus, gpu_label_value = _get_gpus_from_node_labels(labels) + if not gpus or gpu_label_value is None: + continue + if gpus[0] == offer_gpu: + matching_gpu_label_values.add(gpu_label_value) + if not matching_gpu_label_values: + raise ComputeError( + f"GPU is requested but no matching GPU labels found: {gpu_spec=}" + ) + logger.debug( + "Requesting %d GPU(s), node labels: %s", gpu_min, matching_gpu_label_values + ) + # TODO: support other GPU vendors + resources_requests["nvidia.com/gpu"] = str(gpu_min) + resources_limits["nvidia.com/gpu"] = str(gpu_min) + node_affinity = client.V1NodeAffinity( + required_during_scheduling_ignored_during_execution=[ + client.V1NodeSelectorTerm( + match_expressions=[ + client.V1NodeSelectorRequirement( + key="nvidia.com/gpu.product", + operator="In", + values=list(matching_gpu_label_values), + ), ], - security_context=client.V1SecurityContext( - # TODO(#1535): support non-root images properly - run_as_user=0, - run_as_group=0, - ), - # TODO: Pass cpu, memory, gpu as requests. - # Beware that node capacity != allocatable, so - # if the node has 2xCPU – then cpu=2 request will probably fail. 
- resources=client.V1ResourceRequirements(requests={}), - ) - ] - ), + ), + ], + ) + if (memory_min := resources_spec.memory.min) is not None: + resources_requests["memory"] = f"{float(memory_min)}Gi" + if ( + resources_spec.disk is not None + and (disk_min := resources_spec.disk.size.min) is not None + ): + resources_requests["ephemeral-storage"] = f"{float(disk_min)}Gi" + pod = client.V1Pod( + metadata=client.V1ObjectMeta( + name=instance_name, + labels={"app.kubernetes.io/name": instance_name}, ), + spec=client.V1PodSpec( + containers=[ + client.V1Container( + name=f"{instance_name}-container", + image=job.job_spec.image_name, + command=["/bin/sh"], + args=["-c", " && ".join(commands)], + ports=[ + client.V1ContainerPort( + container_port=DSTACK_RUNNER_SSH_PORT, + ) + ], + security_context=client.V1SecurityContext( + # TODO(#1535): support non-root images properly + run_as_user=0, + run_as_group=0, + ), + resources=client.V1ResourceRequirements( + requests=resources_requests, + limits=resources_limits, + ), + ) + ], + affinity=node_affinity, + ), + ) + call_api_method( + self.api.create_namespaced_pod, + client.V1Pod, + namespace=self.config.namespace, + body=pod, ) - service_response = self.api.create_namespaced_service( - namespace=DEFAULT_NAMESPACE, + service = call_api_method( + self.api.create_namespaced_service, + client.V1Service, + namespace=self.config.namespace, body=client.V1Service( metadata=client.V1ObjectMeta(name=_get_pod_service_name(instance_name)), spec=client.V1ServiceSpec( @@ -192,7 +288,7 @@ def run_job( ), ), ) - service_ip = service_response.spec.cluster_ip + service_ip = get_value(service, ".spec.cluster_ip", str, required=True) return JobProvisioningData( backend=instance_offer.backend, instance_type=instance_offer.instance, @@ -215,22 +311,22 @@ def run_job( def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): - try: - self.api.delete_namespaced_service( - 
name=_get_pod_service_name(instance_id), - namespace=DEFAULT_NAMESPACE, - body=client.V1DeleteOptions(), - ) - except client.ApiException as e: - if e.status != 404: - raise - try: - self.api.delete_namespaced_pod( - name=instance_id, namespace=DEFAULT_NAMESPACE, body=client.V1DeleteOptions() - ) - except client.ApiException as e: - if e.status != 404: - raise + call_api_method( + self.api.delete_namespaced_service, + client.V1Service, + expected=404, + name=_get_pod_service_name(instance_id), + namespace=self.config.namespace, + body=client.V1DeleteOptions(), + ) + call_api_method( + self.api.delete_namespaced_pod, + client.V1Pod, + expected=404, + name=instance_id, + namespace=self.config.namespace, + body=client.V1DeleteOptions(), + ) def create_gateway( self, @@ -247,67 +343,75 @@ def create_gateway( # https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html instance_name = generate_unique_gateway_instance_name(configuration) commands = _get_gateway_commands(authorized_keys=[configuration.ssh_key_pub]) - self.api.create_namespaced_pod( - namespace=DEFAULT_NAMESPACE, - body=client.V1Pod( - metadata=client.V1ObjectMeta( - name=instance_name, - labels={"app.kubernetes.io/name": instance_name}, - ), - spec=client.V1PodSpec( - containers=[ - client.V1Container( - name=f"{instance_name}-container", - image="ubuntu:22.04", - command=["/bin/sh"], - args=["-c", " && ".join(commands)], - ports=[ - client.V1ContainerPort( - container_port=22, - ), - client.V1ContainerPort( - container_port=80, - ), - client.V1ContainerPort( - container_port=443, - ), - ], - ) - ] - ), + pod = client.V1Pod( + metadata=client.V1ObjectMeta( + name=instance_name, + labels={"app.kubernetes.io/name": instance_name}, + ), + spec=client.V1PodSpec( + containers=[ + client.V1Container( + name=f"{instance_name}-container", + image="ubuntu:22.04", + command=["/bin/sh"], + args=["-c", " && ".join(commands)], + ports=[ + client.V1ContainerPort( + container_port=22, + ), + 
client.V1ContainerPort( + container_port=80, + ), + client.V1ContainerPort( + container_port=443, + ), + ], + ) + ] ), ) - self.api.create_namespaced_service( - namespace=DEFAULT_NAMESPACE, - body=client.V1Service( - metadata=client.V1ObjectMeta( - name=_get_pod_service_name(instance_name), - ), - spec=client.V1ServiceSpec( - type="LoadBalancer", - selector={"app.kubernetes.io/name": instance_name}, - ports=[ - client.V1ServicePort( - name="ssh", - port=22, - target_port=22, - ), - client.V1ServicePort( - name="http", - port=80, - target_port=80, - ), - client.V1ServicePort( - name="https", - port=443, - target_port=443, - ), - ], - ), + call_api_method( + self.api.create_namespaced_pod, + client.V1Pod, + namespace=self.config.namespace, + body=pod, + ) + service = client.V1Service( + metadata=client.V1ObjectMeta( + name=_get_pod_service_name(instance_name), + ), + spec=client.V1ServiceSpec( + type="LoadBalancer", + selector={"app.kubernetes.io/name": instance_name}, + ports=[ + client.V1ServicePort( + name="ssh", + port=22, + target_port=22, + ), + client.V1ServicePort( + name="http", + port=80, + target_port=80, + ), + client.V1ServicePort( + name="https", + port=443, + target_port=443, + ), + ], ), ) + call_api_method( + self.api.create_namespaced_service, + client.V1Service, + namespace=self.config.namespace, + body=service, + ) hostname = _wait_for_load_balancer_hostname( - api=self.api, service_name=_get_pod_service_name(instance_name) + api=self.api, + namespace=self.config.namespace, + service_name=_get_pod_service_name(instance_name), ) if hostname is None: self.terminate_instance(instance_name, region="-") @@ -334,15 +438,30 @@ def terminate_gateway( ) -def _get_gpus_from_node_labels(labels: Dict) -> List[Gpu]: - # We rely on https://github.com/NVIDIA/gpu-feature-discovery to detect gpus. - # Note that "nvidia.com/gpu.product" is not a short gpu name like "T4" or "A100" but a product name - # from nvidia-smi like "Tesla-T4" or "A100-SXM4-40GB". 
+def _parse_cpu(cpu: str) -> int: + if cpu.endswith("m"): + # "m" means millicpu (1/1000 CPU), e.g., 7900m -> 7.9 -> 7 + return int(float(cpu[:-1]) / 1000) + return int(cpu) + + +def _parse_memory(memory: str) -> int: + if memory.isdigit(): + # no suffix means that the value is in bytes + return int(memory) // 2**20 + return int(parse_memory(memory, as_untis="M")) + + +def _get_gpus_from_node_labels(labels: dict[str, str]) -> tuple[list[Gpu], Optional[str]]: + # We rely on https://github.com/NVIDIA/k8s-device-plugin/tree/main/docs/gpu-feature-discovery + # to detect gpus. Note that "nvidia.com/gpu.product" is not a short gpu name like "T4" or + # "A100" but a product name like "Tesla-T4" or "A100-SXM4-40GB". # Thus, we convert the product name to a known gpu name. + # TODO: support other GPU vendors gpu_count = labels.get("nvidia.com/gpu.count") gpu_product = labels.get("nvidia.com/gpu.product") if gpu_count is None or gpu_product is None: - return [] + return [], None gpu_count = int(gpu_count) gpu_name = None for known_gpu_name in NVIDIA_GPU_NAMES: @@ -350,20 +469,22 @@ def _get_gpus_from_node_labels(labels: Dict) -> List[Gpu]: gpu_name = known_gpu_name break if gpu_name is None: - return [] + return [], None gpu_info = NVIDIA_GPU_NAME_TO_GPU_INFO[gpu_name] gpu_memory = gpu_info.memory * 1024 # A100 may come in two variants if "40GB" in gpu_product: gpu_memory = 40 * 1024 - return [ + gpus = [ Gpu(vendor=AcceleratorVendor.NVIDIA, name=gpu_name, memory_mib=gpu_memory) for _ in range(gpu_count) ] + return gpus, gpu_product def _continue_setup_jump_pod( api: client.CoreV1Api, + namespace: str, project_name: str, project_ssh_private_key: str, user_ssh_public_key: str, @@ -372,6 +493,7 @@ def _continue_setup_jump_pod( ): _wait_for_pod_ready( api=api, + namespace=namespace, pod_name=_get_jump_pod_name(project_name), ) _add_authorized_key_to_jump_pod( @@ -384,82 +506,135 @@ def _continue_setup_jump_pod( def _create_jump_pod_service_if_not_exists( api: client.CoreV1Api, 
+ namespace: str, project_name: str, ssh_public_keys: List[str], jump_pod_port: Optional[int], ) -> Tuple[int, bool]: created = False - try: - service = api.read_namespaced_service( + service: Optional[client.V1Service] = None + pod: Optional[client.V1Pod] = None + _namespace = call_api_method( + api.read_namespace, + client.V1Namespace, + expected=404, + name=namespace, + ) + if _namespace is None: + _namespace = client.V1Namespace( + metadata=client.V1ObjectMeta( + name=namespace, + labels={"app.kubernetes.io/name": namespace}, + ), + ) + call_api_method( + api.create_namespace, + client.V1Namespace, + body=_namespace, + ) + else: + service = call_api_method( + api.read_namespaced_service, + client.V1Service, + expected=404, name=_get_jump_pod_service_name(project_name), - namespace=DEFAULT_NAMESPACE, + namespace=namespace, ) - except client.ApiException as e: - if e.status == 404: - service = _create_jump_pod_service( - api=api, - project_name=project_name, - ssh_public_keys=ssh_public_keys, - jump_pod_port=jump_pod_port, - ) - created = True - else: - raise - return service.spec.ports[0].node_port, created + pod = call_api_method( + api.read_namespaced_pod, + client.V1Pod, + expected=404, + name=_get_jump_pod_name(project_name), + namespace=namespace, + ) + # The service may exist without the pod if the node on which the jump pod was running + # has been deleted. + if service is None or pod is None: + service = _create_jump_pod_service( + api=api, + namespace=namespace, + project_name=project_name, + ssh_public_keys=ssh_public_keys, + jump_pod_port=jump_pod_port, + ) + created = True + port = get_value(service, ".spec.ports[0].node_port", int, required=True) + return port, created def _create_jump_pod_service( api: client.CoreV1Api, + namespace: str, project_name: str, ssh_public_keys: List[str], jump_pod_port: Optional[int], ) -> client.V1Service: # TODO use restricted ssh-forwarding-only user for jump pod instead of root. 
- commands = _get_jump_pod_commands(authorized_keys=ssh_public_keys) pod_name = _get_jump_pod_name(project_name) - api.create_namespaced_pod( - namespace=DEFAULT_NAMESPACE, - body=client.V1Pod( - metadata=client.V1ObjectMeta( - name=pod_name, - labels={"app.kubernetes.io/name": pod_name}, - ), - spec=client.V1PodSpec( - containers=[ - client.V1Container( - name=f"{pod_name}-container", - # TODO: Choose appropriate image for jump pod - image="dstackai/base:py3.11-0.4rc4", - command=["/bin/sh"], - args=["-c", " && ".join(commands)], - ports=[ - client.V1ContainerPort( - container_port=JUMP_POD_SSH_PORT, - ) - ], - ) - ] - ), + call_api_method( + api.delete_namespaced_pod, + client.V1Pod, + expected=404, + namespace=namespace, + name=pod_name, + ) + commands = _get_jump_pod_commands(authorized_keys=ssh_public_keys) + pod = client.V1Pod( + metadata=client.V1ObjectMeta( + name=pod_name, + labels={"app.kubernetes.io/name": pod_name}, + ), + spec=client.V1PodSpec( + containers=[ + client.V1Container( + name=f"{pod_name}-container", + # TODO: Choose appropriate image for jump pod + image="dstackai/base:py3.11-0.4rc4", + command=["/bin/sh"], + args=["-c", " && ".join(commands)], + ports=[ + client.V1ContainerPort( + container_port=JUMP_POD_SSH_PORT, + ) + ], + ) + ] ), ) - service_response = api.create_namespaced_service( - namespace=DEFAULT_NAMESPACE, - body=client.V1Service( - metadata=client.V1ObjectMeta(name=_get_jump_pod_service_name(project_name)), - spec=client.V1ServiceSpec( - type="NodePort", - selector={"app.kubernetes.io/name": pod_name}, - ports=[ - client.V1ServicePort( - port=JUMP_POD_SSH_PORT, - target_port=JUMP_POD_SSH_PORT, - node_port=jump_pod_port, - ) - ], - ), + call_api_method( + api.create_namespaced_pod, + client.V1Pod, + namespace=namespace, + body=pod, + ) + service_name = _get_jump_pod_service_name(project_name) + call_api_method( + api.delete_namespaced_service, + client.V1Service, + expected=404, + namespace=namespace, + name=service_name, + ) + 
service = client.V1Service( + metadata=client.V1ObjectMeta(name=service_name), + spec=client.V1ServiceSpec( + type="NodePort", + selector={"app.kubernetes.io/name": pod_name}, + ports=[ + client.V1ServicePort( + port=JUMP_POD_SSH_PORT, + target_port=JUMP_POD_SSH_PORT, + node_port=jump_pod_port, + ) + ], ), ) - return service_response + return call_api_method( + api.create_namespaced_service, + client.V1Service, + namespace=namespace, + body=service, + ) def _get_jump_pod_commands(authorized_keys: List[str]) -> List[str]: @@ -484,20 +659,25 @@ def _get_jump_pod_commands(authorized_keys: List[str]) -> List[str]: def _wait_for_pod_ready( api: client.CoreV1Api, + namespace: str, pod_name: str, timeout_seconds: int = 300, ): start_time = time.time() while True: - try: - pod = api.read_namespaced_pod(name=pod_name, namespace=DEFAULT_NAMESPACE) - except client.ApiException as e: - if e.status != 404: - raise - else: - if pod.status.phase == "Running" and all( - container_status.ready for container_status in pod.status.container_statuses - ): + pod = call_api_method( + api.read_namespaced_pod, + client.V1Pod, + expected=404, + name=pod_name, + namespace=namespace, + ) + if pod is not None: + phase = get_value(pod, ".status.phase", str, required=True) + container_statuses = get_value( + pod, ".status.container_statuses", list[client.V1ContainerStatus], required=True + ) + if phase == "Running" and all(status.ready for status in container_statuses): return True elapsed_time = time.time() - start_time if elapsed_time >= timeout_seconds: @@ -508,19 +688,23 @@ def _wait_for_pod_ready( def _wait_for_load_balancer_hostname( api: client.CoreV1Api, + namespace: str, service_name: str, timeout_seconds: int = 120, ) -> Optional[str]: start_time = time.time() while True: - try: - service = api.read_namespaced_service(name=service_name, namespace=DEFAULT_NAMESPACE) - except client.ApiException as e: - if e.status != 404: - raise - else: - if service.status.load_balancer.ingress is not 
None: - return service.status.load_balancer.ingress[0].hostname + service = call_api_method( + api.read_namespaced_service, + client.V1Service, + expected=404, + name=service_name, + namespace=namespace, + ) + if service is not None: + hostname = get_value(service, ".status.load_balancer.ingress[0].hostname", str) + if hostname is not None: + return hostname elapsed_time = time.time() - start_time if elapsed_time >= timeout_seconds: logger.warning("Timeout waiting for load balancer %s to get ip", service_name) diff --git a/src/dstack/_internal/core/backends/kubernetes/models.py b/src/dstack/_internal/core/backends/kubernetes/models.py index ed8af7de7a..a4f05cab13 100644 --- a/src/dstack/_internal/core/backends/kubernetes/models.py +++ b/src/dstack/_internal/core/backends/kubernetes/models.py @@ -5,6 +5,8 @@ from dstack._internal.core.backends.base.models import fill_data from dstack._internal.core.models.common import CoreModel +DEFAULT_NAMESPACE = "default" + class KubernetesNetworkingConfig(CoreModel): ssh_host: Annotated[ @@ -25,13 +27,12 @@ class KubernetesBackendConfig(CoreModel): networking: Annotated[ Optional[KubernetesNetworkingConfig], Field(description="The networking configuration") ] = None + namespace: Annotated[ + str, Field(description="The namespace for resources managed by `dstack`") + ] = DEFAULT_NAMESPACE -class KubernetesBackendConfigWithCreds(CoreModel): - type: Annotated[Literal["kubernetes"], Field(description="The type of backend")] = "kubernetes" - networking: Annotated[ - Optional[KubernetesNetworkingConfig], Field(description="The networking configuration") - ] = None +class KubernetesBackendConfigWithCreds(KubernetesBackendConfig): kubeconfig: Annotated[KubeconfigConfig, Field(description="The kubeconfig configuration")] @@ -53,11 +54,7 @@ def fill_data(cls, values): return fill_data(values) -class KubernetesBackendFileConfigWithCreds(CoreModel): - type: Annotated[Literal["kubernetes"], Field(description="The type of backend")] = 
"kubernetes" - networking: Annotated[ - Optional[KubernetesNetworkingConfig], Field(description="The networking configuration") - ] = None +class KubernetesBackendFileConfigWithCreds(KubernetesBackendConfig): kubeconfig: Annotated[KubeconfigFileConfig, Field(description="The kubeconfig configuration")] diff --git a/src/dstack/_internal/core/backends/kubernetes/utils.py b/src/dstack/_internal/core/backends/kubernetes/utils.py index b4a19fd448..d50489c8fa 100644 --- a/src/dstack/_internal/core/backends/kubernetes/utils.py +++ b/src/dstack/_internal/core/backends/kubernetes/utils.py @@ -1,20 +1,157 @@ -from typing import Dict, List, Optional +import ast +from typing import Any, Callable, List, Literal, Optional, TypeVar, Union, get_origin, overload -import kubernetes import yaml +from kubernetes import client as kubernetes_client +from kubernetes import config as kubernetes_config +from typing_extensions import ParamSpec +T = TypeVar("T") +P = ParamSpec("P") -def get_api_from_config_data(kubeconfig_data: str) -> kubernetes.client.CoreV1Api: + +def get_api_from_config_data(kubeconfig_data: str) -> kubernetes_client.CoreV1Api: config_dict = yaml.load(kubeconfig_data, yaml.FullLoader) return get_api_from_config_dict(config_dict) -def get_api_from_config_dict(kubeconfig: Dict) -> kubernetes.client.CoreV1Api: - api_client = kubernetes.config.new_client_from_config_dict(config_dict=kubeconfig) - return kubernetes.client.CoreV1Api(api_client=api_client) +def get_api_from_config_dict(kubeconfig: dict) -> kubernetes_client.CoreV1Api: + api_client = kubernetes_config.new_client_from_config_dict(config_dict=kubeconfig) + return kubernetes_client.CoreV1Api(api_client=api_client) + + +@overload +def call_api_method( + method: Callable[P, Any], + type_: type[T], + expected: None = None, + *args: P.args, + **kwargs: P.kwargs, +) -> T: ... 
+ + +@overload +def call_api_method( + method: Callable[P, Any], + type_: type[T], + expected: Union[int, tuple[int, ...], list[int]], + *args: P.args, + **kwargs: P.kwargs, +) -> Optional[T]: ... + + +def call_api_method( + method: Callable[P, Any], + type_: type[T], + expected: Optional[Union[int, tuple[int, ...], list[int]]] = None, + *args: P.args, + **kwargs: P.kwargs, +) -> Optional[T]: + """ + Returns the result of the API method call, optionally ignoring specified HTTP status codes. + + Args: + method: the `CoreV1Api` bound method. + type_: The expected type of the return value, used for runtime type checking and + as a type hint for a static type checker (as kubernetes package is not type-annotated). + NB: For composite types, only "origin" type is checked, e.g., list, not list[Node] + expected: Expected error statuses, e.g., 404. + args: positional arguments of the method. + kwargs: keyword arguments of the method. + Returns: + The return value or `None` in case of the expected error. + """ + if isinstance(expected, int): + expected = (expected,) + result: T + try: + result = method(*args, **kwargs) + except kubernetes_client.ApiException as e: + if expected is None or e.status not in expected: + raise + return None + if not isinstance(result, get_origin(type_) or type_): + raise TypeError( + f"{method.__name__} returned {type(result).__name__}, expected {type_.__name__}" + ) + return result + + +@overload +def get_value( + obj: object, path: str, type_: type[T], *, required: Literal[False] = False +) -> Optional[T]: ... + + +@overload +def get_value(obj: object, path: str, type_: type[T], *, required: Literal[True]) -> T: ... + + +def get_value(obj: object, path: str, type_: type[T], *, required: bool = False) -> Optional[T]: + """ + Returns the value at a given path. + Supports object attributes, sequence indices, and mapping keys. + + Args: + obj: The object to traverse. + path: The path to the value, regular Python syntax. 
The leading dot is optional, all the + following are correct: `.attr`, `attr`, `.[0]`, `[0]`, `.['key']`, `['key']`. + type_: The expected type of the value, used for runtime type checking and as a type hint + for a static type checker (as kubernetes package is not type-annotated). + NB: For composite types, only "origin" type is checked, e.g., list, not list[Node] + required: If `True`, the value must exist and must not be `None`. If `False` (safe + navigation mode), the value may not exist and may be `None`. + + Returns: + The requested value or `None` in case of failed traverse when required=False. + """ + _path = path.removeprefix(".") + if _path.startswith("["): + src = f"obj{_path}" + else: + src = f"obj.{_path}" + module = ast.parse(src) + assert len(module.body) == 1, ast.dump(module, indent=4) + root_expr = module.body[0] + assert isinstance(root_expr, ast.Expr), ast.dump(module, indent=4) + varname: Optional[str] = None + expr = root_expr.value + while True: + if isinstance(expr, ast.Name): + varname = expr.id + break + if __debug__: + if isinstance(expr, ast.Subscript): + if isinstance(expr.slice, ast.UnaryOp): + # .items[-1] + assert isinstance(expr.slice.op, ast.USub), ast.dump(expr, indent=4) + assert isinstance(expr.slice.operand, ast.Constant), ast.dump(expr, indent=4) + assert isinstance(expr.slice.operand.value, int), ast.dump(expr, indent=4) + else: + # .items[0], .labels["name"] + assert isinstance(expr.slice, ast.Constant), ast.dump(expr, indent=4) + else: + assert isinstance(expr, ast.Attribute), ast.dump(expr, indent=4) + else: + assert isinstance(expr, (ast.Attribute, ast.Subscript)) + expr = expr.value + assert varname is not None, ast.dump(module) + try: + value = eval(src, {"__builtins__": {}}, {"obj": obj}) + except (AttributeError, KeyError, IndexError, TypeError) as e: + if required: + raise type(e)(f"Failed to traverse {path}: {e}") from e + return None + if value is None: + if required: + raise TypeError(f"Required {path} is None") + 
return value + if not isinstance(value, get_origin(type_) or type_): + raise TypeError(f"{path} value is {type(value).__name__}, expected {type_.__name__}") + return value -def get_cluster_public_ip(api_client: kubernetes.client.CoreV1Api) -> Optional[str]: +def get_cluster_public_ip(api_client: kubernetes_client.CoreV1Api) -> Optional[str]: """ Returns public IP of any cluster node. """ @@ -24,7 +161,7 @@ def get_cluster_public_ip(api_client: kubernetes.client.CoreV1Api) -> Optional[s return public_ips[0] -def get_cluster_public_ips(api_client: kubernetes.client.CoreV1Api) -> List[str]: +def get_cluster_public_ips(api_client: kubernetes_client.CoreV1Api) -> List[str]: """ Returns public IPs of all cluster nodes. """ diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 5c4e78a85c..b44c9271b4 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -307,7 +307,7 @@ async def _add_remote(instance: InstanceModel) -> None: ) deploy_timeout = 20 * 60 # 20 minutes result = await asyncio.wait_for(future, timeout=deploy_timeout) - health, host_info, cpu_arch = result + health, host_info, arch = result except (asyncio.TimeoutError, TimeoutError) as e: raise ProvisioningError(f"Deploy timeout: {e}") from e except Exception as e: @@ -327,7 +327,7 @@ async def _add_remote(instance: InstanceModel) -> None: instance.status = InstanceStatus.PENDING return - instance_type = host_info_to_instance_type(host_info, cpu_arch) + instance_type = host_info_to_instance_type(host_info, arch) instance_network = None internal_ip = None try: diff --git a/src/dstack/_internal/server/utils/provisioning.py b/src/dstack/_internal/server/utils/provisioning.py index b77efe7db4..632dce777a 100644 --- a/src/dstack/_internal/server/utils/provisioning.py +++ b/src/dstack/_internal/server/utils/provisioning.py 
@@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, List, Optional import paramiko -from gpuhunt import AcceleratorVendor, CPUArchitecture, correct_gpu_memory_gib +from gpuhunt import AcceleratorVendor, correct_gpu_memory_gib from dstack._internal.core.backends.base.compute import GoArchType, normalize_arch from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT @@ -248,14 +248,7 @@ def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[str]: return out -def host_info_to_instance_type(host_info: Dict[str, Any], cpu_arch: GoArchType) -> InstanceType: - _cpu_arch: CPUArchitecture - if cpu_arch == "amd64": - _cpu_arch = CPUArchitecture.X86 - elif cpu_arch == "arm64": - _cpu_arch = CPUArchitecture.ARM - else: - raise ValueError(f"Unexpected cpu_arch: {cpu_arch}") +def host_info_to_instance_type(host_info: Dict[str, Any], arch: GoArchType) -> InstanceType: gpu_count = host_info.get("gpu_count", 0) if gpu_count > 0: gpu_vendor = AcceleratorVendor.cast(host_info.get("gpu_vendor", "nvidia")) @@ -280,7 +273,7 @@ def host_info_to_instance_type(host_info: Dict[str, Any], cpu_arch: GoArchType) instance_type = InstanceType( name="instance", resources=Resources( - cpu_arch=_cpu_arch, + cpu_arch=arch.to_cpu_architecture(), cpus=host_info["cpus"], memory_mib=host_info["memory"] / 1024 / 1024, spot=False, diff --git a/src/tests/_internal/core/backends/base/test_compute.py b/src/tests/_internal/core/backends/base/test_compute.py index 8b50893c53..848aea822c 100644 --- a/src/tests/_internal/core/backends/base/test_compute.py +++ b/src/tests/_internal/core/backends/base/test_compute.py @@ -4,6 +4,7 @@ import pytest from dstack._internal.core.backends.base.compute import ( + GoArchType, generate_unique_backend_name, generate_unique_gateway_instance_name, generate_unique_instance_name, @@ -63,11 +64,11 @@ def test_validates_project_name(self): class TestNormalizeArch: @pytest.mark.parametrize("arch", [None, "", "X86", "x86_64", "AMD64"]) def test_amd64(self, arch: 
Optional[str]): - assert normalize_arch(arch) == "amd64" + assert normalize_arch(arch) is GoArchType.AMD64 @pytest.mark.parametrize("arch", ["arm", "ARM64", "AArch64"]) def test_arm64(self, arch: str): - assert normalize_arch(arch) == "arm64" + assert normalize_arch(arch) is GoArchType.ARM64 @pytest.mark.parametrize("arch", ["IA32", "i686", "ARM32", "aarch32"]) def test_32bit_not_supported(self, arch: str): diff --git a/src/tests/_internal/core/backends/kubernetes/test_compute.py b/src/tests/_internal/core/backends/kubernetes/test_compute.py index 3e6d1d8bc8..5736e35a01 100644 --- a/src/tests/_internal/core/backends/kubernetes/test_compute.py +++ b/src/tests/_internal/core/backends/kubernetes/test_compute.py @@ -4,10 +4,10 @@ class TestGetGPUsFromNodeLabels: def test_returns_no_gpus_if_no_labels(self): - assert _get_gpus_from_node_labels({}) == [] + assert _get_gpus_from_node_labels({}) == ([], None) def test_returns_no_gpus_if_missing_labels(self): - assert _get_gpus_from_node_labels({"nvidia.com/gpu.count": 1}) == [] + assert _get_gpus_from_node_labels({"nvidia.com/gpu.count": 1}) == ([], None) def test_returns_correct_memory_for_different_A100(self): assert _get_gpus_from_node_labels( @@ -15,10 +15,10 @@ def test_returns_correct_memory_for_different_A100(self): "nvidia.com/gpu.count": 1, "nvidia.com/gpu.product": "A100-SXM4-40GB", } - ) == [Gpu(name="A100", memory_mib=40 * 1024)] + ) == ([Gpu(name="A100", memory_mib=40 * 1024)], "A100-SXM4-40GB") assert _get_gpus_from_node_labels( { "nvidia.com/gpu.count": 1, "nvidia.com/gpu.product": "A100-SXM4-80GB", } - ) == [Gpu(name="A100", memory_mib=80 * 1024)] + ) == ([Gpu(name="A100", memory_mib=80 * 1024)], "A100-SXM4-80GB") diff --git a/src/tests/_internal/core/backends/kubernetes/test_utils.py b/src/tests/_internal/core/backends/kubernetes/test_utils.py new file mode 100644 index 0000000000..4404d0492b --- /dev/null +++ b/src/tests/_internal/core/backends/kubernetes/test_utils.py @@ -0,0 +1,68 @@ +from argparse 
import Namespace + +import pytest + +from dstack._internal.core.backends.kubernetes.utils import get_value + + +class TestGetValue: + def test_attribute_with_dot(self): + assert get_value(Namespace(field=False), ".field", bool) is False + + def test_attribute_without_dot(self): + assert get_value(Namespace(field=False), "field", bool) is False + + def test_index_with_dot(self): + assert get_value([False, True], ".[1]", bool) is True + + def test_index_without_dot(self): + assert get_value([False, True], "[1]", bool) is True + + def test_key_with_dot(self): + assert get_value({"field": True}, ".['field']", bool) is True + + def test_key_without_dot(self): + assert get_value({"field": True}, "['field']", bool) is True + + def test_nested(self): + obj = Namespace(sensors=[{"speed": Namespace(values=[127, 112, 98])}]) + assert get_value(obj, ".sensors[0]['speed'].values[-1]", int) == 98 + + def test_optional_is_missing(self): + obj = Namespace(sensors=[{"speed": Namespace(values=[127, 112, 98])}]) + assert get_value(obj, ".sensors[0]['altitude'].values[-1]", int) is None + + @pytest.mark.parametrize( + ["obj", "path", "exctype"], + [ + pytest.param(Namespace(), ".field", AttributeError, id="attribute"), + pytest.param([], ".[0]", IndexError, id="index"), + pytest.param({}, ".['test']", KeyError, id="key"), + pytest.param(Namespace(), ".['test']", TypeError, id="not-subscriptable"), + ], + ) + def test_required_is_missing(self, obj: object, path: str, exctype: type[Exception]): + with pytest.raises(exctype, match="Failed to traverse"): + get_value(obj, path, int, required=True) + + def test_required_is_null(self): + obj = Namespace(version=None) + with pytest.raises(TypeError, match="Required version is None"): + get_value(obj, "version", int, required=True) + + def test_unexpected_type(self): + obj = Namespace(version="1") + with pytest.raises(TypeError, match="version value is str, expected int"): + get_value(obj, "version", int, required=True) + + 
@pytest.mark.parametrize( + "path", + [ + pytest.param(".[var]", id="variable"), + pytest.param(".[1 + 2]", id="expression"), + pytest.param("print('test')", id="function-call"), + ], + ) + def test_assertions(self, path: str): + with pytest.raises(AssertionError): + get_value(None, path, str) diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py b/src/tests/_internal/server/background/tasks/test_process_instances.py index c1983fbed6..690cb71d95 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/tasks/test_process_instances.py @@ -11,6 +11,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal.core.backends.base.compute import GoArchType from dstack._internal.core.errors import ( BackendError, NoCapacityError, @@ -1057,7 +1058,7 @@ def host_info(self) -> dict: @pytest.fixture def deploy_instance_mock(self, monkeypatch: pytest.MonkeyPatch, host_info: dict): - mock = Mock(return_value=(InstanceCheck(reachable=True), host_info, "amd64")) + mock = Mock(return_value=(InstanceCheck(reachable=True), host_info, GoArchType.AMD64)) monkeypatch.setattr( "dstack._internal.server.background.tasks.process_instances._deploy_instance", mock )