diff --git a/mdfactory/performance/__init__.py b/mdfactory/performance/__init__.py new file mode 100644 index 0000000..bbce0ad --- /dev/null +++ b/mdfactory/performance/__init__.py @@ -0,0 +1,9 @@ +# ABOUTME: HPC performance optimization package for mdfactory. +# ABOUTME: Cluster autodiscovery, CPU affinity, benchmarking, and GPU MPS management. +"""HPC performance optimization utilities. + +Modules +------- +cluster + SLURM cluster autodiscovery — query partitions, node types, accounts, and QOS. +""" diff --git a/mdfactory/performance/cluster.py b/mdfactory/performance/cluster.py new file mode 100644 index 0000000..66d91a0 --- /dev/null +++ b/mdfactory/performance/cluster.py @@ -0,0 +1,632 @@ +# ABOUTME: SLURM cluster autodiscovery — query partitions, node types, accounts, QOS. +# ABOUTME: Parses sinfo/sacctmgr output into structured dataclasses for resource-aware scheduling. +"""SLURM cluster autodiscovery. + +Query the local SLURM scheduler and return a structured representation of +available resources (partitions, node types, accounts, QOS policies, GPU types). + +Functions +--------- +discover_cluster + Main entry point — returns ``ClusterInfo`` or ``None`` if SLURM is unavailable. +select_partition + Heuristic partition selection given resource requirements. + +Examples +-------- +>>> from mdfactory.performance.cluster import discover_cluster, select_partition +>>> cluster = discover_cluster() +>>> if cluster is not None: +... gpu_part = select_partition(cluster, needs_gpu=True) +""" + +from __future__ import annotations + +import functools +import os +import shutil +import subprocess +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class NodeType: + """Hardware specification of a node type within a partition. + + Parameters + ---------- + cpus : int + Number of CPU cores per node. + memory_mb : int + Memory in megabytes per node. + gpus : int + Number of GPUs per node (0 if CPU-only). + gpu_type : str or None + GPU model identifier (e.g., ``"a100"``, ``"h100"``), or None. + features : tuple of str + SLURM feature/constraint tags on this node type (immutable). + """ + + cpus: int + memory_mb: int + gpus: int = 0 + gpu_type: str | None = None + features: tuple[str, ...] = field(default_factory=tuple) + + +@dataclass(frozen=True) +class Partition: + """A SLURM partition with its node types and limits. + + Parameters + ---------- + name : str + Partition name (e.g., ``"gpu"``, ``"cpu"``). + state : str + Partition-level state: ``"up"`` if any node is schedulable, otherwise + the last observed unhealthy state (e.g., ``"down"``, ``"drained"``). + max_time : str + Maximum walltime (SLURM format, e.g., ``"3-00:00:00"``). + default_time : str + Default walltime assigned when user does not specify one. + Populated from sinfo ``%L``; equals ``max_time`` on legacy output. + node_types : list of NodeType + Distinct hardware configurations available in this partition. + total_nodes : int + Total number of nodes in the partition. + is_default : bool + Whether this is the cluster's default partition. + """ + + name: str + state: str + max_time: str + default_time: str + node_types: list[NodeType] = field(default_factory=list) + total_nodes: int = 0 + is_default: bool = False + + +@dataclass(frozen=True) +class ClusterInfo: + """Structured representation of a SLURM cluster's resources. + + Parameters + ---------- + partitions : list of Partition + All discovered partitions. + accounts : list of str + SLURM accounts available to the current user. + qos_policies : list of str + Available QOS policy names. + default_account : str or None + The user's default account, if determinable. + """ + + partitions: list[Partition] = field(default_factory=list) + accounts: list[str] = field(default_factory=list) + qos_policies: list[str] = field(default_factory=list) + default_account: str | None = None + + +def _run_command(cmd: list[str], *, timeout: int = 30) -> str | None: + """Run a shell command and return stdout, or None on failure. + + Parameters + ---------- + cmd : list of str + Command and arguments. + timeout : int + Timeout in seconds. + + Returns + ------- + str or None + Stripped stdout on success, None on any failure. + """ + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + check=False, + ) + if result.returncode != 0: + return None + return result.stdout.strip() + except (FileNotFoundError, subprocess.TimeoutExpired, OSError): + return None + + +def _parse_gres(gres_str: str) -> tuple[int, str | None]: + """Parse SLURM GRES string to extract GPU count and type. + + Parameters + ---------- + gres_str : str + GRES field from sinfo (e.g., ``"gpu:a100:4"``, ``"gpu:2"``, ``"(null)"``). + + Returns + ------- + tuple of (int, str or None) + (gpu_count, gpu_type). Returns (0, None) when no GPUs. + + Examples + -------- + >>> _parse_gres("gpu:a100:4") + (4, 'a100') + >>> _parse_gres("gpu:2") + (2, None) + >>> _parse_gres("(null)") + (0, None) + """ + if not gres_str or gres_str == "(null)": + return 0, None + + # Handle multiple GRES entries separated by commas + for raw_entry in gres_str.split(","): + entry = raw_entry.strip() + if not entry.startswith("gpu"): + continue + parts = entry.split(":") + if len(parts) == 3: + # gpu:type:count + _, gpu_type, count_str = parts + return int(count_str), gpu_type + elif len(parts) == 2: + # gpu:count (no type specified) + _, count_str = parts + # Check if second part is a number or a type + try: + count = int(count_str) + return count, None + except ValueError: + # gpu:type with implicit count of 1 + return 1, count_str + + return 0, None + + +def _parse_time_limit(time_str: str) -> str: + """Normalize SLURM time limit strings. + + Parameters + ---------- + time_str : str + Time limit from sinfo (e.g., ``"3-00:00:00"``, ``"infinite"``, ``"2:00:00"``). + + Returns + ------- + str + Cleaned time string. + """ + if not time_str or time_str == "n/a": + return "unknown" + return time_str.strip() + + +def _parse_memory_mb(mem_str: str) -> int: + """Parse memory string from sinfo to megabytes. + + Parameters + ---------- + mem_str : str + Memory field from sinfo (numeric, in MB by default). + + Returns + ------- + int + Memory in MB. Returns 0 on parse failure. + """ + try: + # sinfo %m gives memory in MB as an integer + cleaned = mem_str.strip().rstrip("+") + return int(cleaned) + except (ValueError, AttributeError): + return 0 + + +def _parse_features(features_str: str) -> tuple[str, ...]: + """Parse SLURM features/constraints string. + + Parameters + ---------- + features_str : str + Features field from sinfo (comma-separated or ``"(null)"``). + + Returns + ------- + tuple of str + Feature strings (immutable). + """ + if not features_str or features_str == "(null)": + return () + return tuple(f.strip() for f in features_str.split(",") if f.strip()) + + +def _parse_sinfo(output: str) -> list[Partition]: + """Parse sinfo output into Partition objects. + + Expects output from: + sinfo -N --noheader -o "%P %n %c %m %G %f %l %L %T" + + Fields: Partition, NodeName, CPUs, Memory(MB), GRES, Features, + MaxTimeLimit, DefaultTimeLimit, State + + Also supports the legacy 8-field format (without %L) for backward + compatibility — in that case ``default_time`` equals ``max_time``. + + Parameters + ---------- + output : str + Raw sinfo output. + + Returns + ------- + list of Partition + Parsed partition list with deduplicated node types. + """ + # Collect data per partition + partition_data: dict[str, dict] = {} + + for raw_line in output.splitlines(): + line = raw_line.strip() + if not line: + continue + + parts = line.split() + if len(parts) >= 9: + # 9-field format: %P %n %c %m %G %f %l %L %T + partition_name = parts[0] + cpus_str = parts[2] + mem_str = parts[3] + gres_str = parts[4] + features_str = parts[5] + max_time = parts[6] + default_time = parts[7] + state = parts[8] + elif len(parts) == 8: + # Legacy 8-field format (no %L): default_time = max_time + partition_name = parts[0] + cpus_str = parts[2] + mem_str = parts[3] + gres_str = parts[4] + features_str = parts[5] + max_time = parts[6] + default_time = parts[6] + state = parts[7] + else: + continue + + # Handle default partition marker (trailing asterisk) + is_default = partition_name.endswith("*") + if is_default: + partition_name = partition_name.rstrip("*") + + # Parse node specs + try: + cpus = int(cpus_str) + except ValueError: + continue + + memory_mb = _parse_memory_mb(mem_str) + gpus, gpu_type = _parse_gres(gres_str) + features = _parse_features(features_str) + + if partition_name not in partition_data: + partition_data[partition_name] = { + "max_time": _parse_time_limit(max_time), + "default_time": _parse_time_limit(default_time), + "node_types": set(), + "node_count": 0, + "is_default": is_default, + "has_healthy_node": False, + "last_unhealthy_state": "down", + } + + # Use a hashable representation for deduplication (features is already a tuple) + node_key = (cpus, memory_mb, gpus, gpu_type, features) + partition_data[partition_name]["node_types"].add(node_key) + partition_data[partition_name]["node_count"] += 1 + + # Track if any line marks this as default + if is_default: + partition_data[partition_name]["is_default"] = True + + # Track node health: partition is "up" if ANY node is schedulable + if state.lower() in ("idle", "mixed", "allocated", "completing", "planned"): + partition_data[partition_name]["has_healthy_node"] = True + else: + partition_data[partition_name]["last_unhealthy_state"] = state + + # Build Partition objects + partitions = [] + for name, data in partition_data.items(): + # Partition state: "up" if any node is healthy, otherwise report + # the last observed unhealthy state + if data["has_healthy_node"]: + partition_state = "up" + else: + partition_state = data["last_unhealthy_state"] + + node_types = [ + NodeType( + cpus=cpus, + memory_mb=mem, + gpus=gpus, + gpu_type=gtype, + features=feats, + ) + for cpus, mem, gpus, gtype, feats in data["node_types"] + ] + # Sort node types by CPU count for deterministic ordering + node_types.sort(key=lambda n: (n.cpus, n.memory_mb, n.gpus)) + + partitions.append( + Partition( + name=name, + state=partition_state, + max_time=data["max_time"], + default_time=data["default_time"], + node_types=node_types, + total_nodes=data["node_count"], + is_default=data["is_default"], + ) + ) + + # Sort partitions: default first, then alphabetical + partitions.sort(key=lambda p: (not p.is_default, p.name)) + return partitions + + +def _parse_accounts(output: str) -> list[str]: + """Parse sacctmgr account output. + + Parameters + ---------- + output : str + Raw sacctmgr output (one account per line, parsable2 format). + + Returns + ------- + list of str + Unique account names, sorted. + """ + accounts = set() + for line in output.splitlines(): + account = line.strip() + if account: + accounts.add(account) + return sorted(accounts) + + +def _parse_qos(output: str) -> list[str]: + """Parse sacctmgr QOS output. + + Parameters + ---------- + output : str + Raw sacctmgr output (pipe-separated: Name|MaxWall|MaxTRES). + + Returns + ------- + list of str + QOS policy names, sorted. + """ + qos_names = set() + for raw_line in output.splitlines(): + line = raw_line.strip() + if not line: + continue + # Format: Name|MaxWall|MaxTRES + parts = line.split("|") + if parts and parts[0].strip(): + qos_names.add(parts[0].strip()) + return sorted(qos_names) + + +def _discover_partitions() -> list[Partition] | None: + """Query sinfo for partition and node information. + + Returns + ------- + list of Partition or None + Parsed partitions, or None if sinfo is unavailable. + """ + output = _run_command(["sinfo", "-N", "--noheader", "-o", "%P %n %c %m %G %f %l %L %T"]) + if output is None: + return None + return _parse_sinfo(output) + + +def _discover_accounts() -> list[str] | None: + """Query sacctmgr for the current user's accounts. + + Returns + ------- + list of str or None + Account names, or None if sacctmgr is unavailable. + """ + user = os.environ.get("USER", os.environ.get("LOGNAME", "")) + if not user: + return None + output = _run_command( + [ + "sacctmgr", + "show", + "assoc", + f"user={user}", + "format=Account", + "--noheader", + "--parsable2", + ] + ) + if output is None: + return None + return _parse_accounts(output) + + +def _discover_qos() -> list[str] | None: + """Query sacctmgr for available QOS policies. + + Returns + ------- + list of str or None + QOS names, or None if sacctmgr is unavailable. + """ + output = _run_command( + [ + "sacctmgr", + "show", + "qos", + "format=Name,MaxWall,MaxTRES", + "--noheader", + "--parsable2", + ] + ) + if output is None: + return None + return _parse_qos(output) + + +def _discover_default_account() -> str | None: + """Query sacctmgr for the current user's default SLURM account. + + Returns + ------- + str or None + The user's default account, or None if unavailable. + """ + user = os.environ.get("USER", os.environ.get("LOGNAME", "")) + if not user: + return None + output = _run_command( + [ + "sacctmgr", + "show", + "user", + user, + "format=DefaultAccount", + "--noheader", + "--parsable2", + ] + ) + if output is None: + return None + account = output.strip().splitlines()[0].strip() if output.strip() else None + return account if account else None + + +@functools.lru_cache(maxsize=1) +def discover_cluster() -> ClusterInfo | None: + """Query SLURM and return structured cluster information. + + Calls ``sinfo`` and ``sacctmgr`` to discover partitions, node types, + accounts, and QOS policies. Returns None gracefully when SLURM commands + are not available (e.g., running on a laptop). + + Results are cached for the session (cluster topology doesn't change + mid-session). Call ``discover_cluster.cache_clear()`` to force re-query. + + Returns + ------- + ClusterInfo or None + Structured cluster information, or None if SLURM is unavailable. + + Examples + -------- + >>> cluster = discover_cluster() + >>> if cluster is not None: + ... for p in cluster.partitions: + ... print(f"{p.name}: {p.total_nodes} nodes") + """ + # sinfo is the minimum requirement — if it's not available, we're not + # on a SLURM cluster + if shutil.which("sinfo") is None: + return None + + partitions = _discover_partitions() + if partitions is None: + return None + + # Accounts and QOS are best-effort (sacctmgr may be restricted) + accounts = _discover_accounts() or [] + qos_policies = _discover_qos() or [] + + # Query the real SLURM default account; fall back to first available + default_account = _discover_default_account() + if default_account is None and accounts: + default_account = accounts[0] + + return ClusterInfo( + partitions=partitions, + accounts=accounts, + qos_policies=qos_policies, + default_account=default_account, + ) + + +def select_partition( + cluster: ClusterInfo, + *, + needs_gpu: bool = False, + min_cpus: int = 1, + min_mem_gb: int = 1, +) -> Partition | None: + """Heuristic partition selection given resource requirements. + + Selects the best-matching partition from the cluster based on hardware + needs. Prefers partitions that are ``"up"`` and have nodes meeting the + specified requirements. + + Parameters + ---------- + cluster : ClusterInfo + Cluster information from ``discover_cluster()``. + needs_gpu : bool + If True, only consider partitions with GPU-equipped nodes. + min_cpus : int + Minimum CPUs per node required. + min_mem_gb : int + Minimum memory per node in GB. + + Returns + ------- + Partition or None + Best matching partition, or None if no partition meets requirements. + + Examples + -------- + >>> cluster = discover_cluster() + >>> gpu_partition = select_partition(cluster, needs_gpu=True, min_cpus=8) + """ + min_mem_mb = min_mem_gb * 1024 + candidates: list[Partition] = [] + + for partition in cluster.partitions: + # Skip partitions with no schedulable nodes + if partition.state.lower() != "up": + continue + + # Check if any node type meets requirements + has_qualifying_node = False + for node in partition.node_types: + if node.cpus < min_cpus: + continue + if node.memory_mb < min_mem_mb: + continue + if needs_gpu and node.gpus == 0: + continue + has_qualifying_node = True + break + + if has_qualifying_node: + candidates.append(partition) + + if not candidates: + return None + + # Prefer: default partition > most nodes > alphabetical + candidates.sort(key=lambda p: (not p.is_default, -p.total_nodes, p.name)) + return candidates[0] diff --git a/mdfactory/tests/test_cluster.py b/mdfactory/tests/test_cluster.py new file mode 100644 index 0000000..9dec77d --- /dev/null +++ b/mdfactory/tests/test_cluster.py @@ -0,0 +1,558 @@ +# ABOUTME: Unit tests for mdfactory.performance.cluster (SLURM autodiscovery). +# ABOUTME: Uses mocked sinfo/sacctmgr output — no SLURM required to run. +"""Tests for SLURM cluster autodiscovery.""" + +from __future__ import annotations + +import subprocess +from unittest.mock import patch + +import pytest + +from mdfactory.performance.cluster import ( + ClusterInfo, + NodeType, + Partition, + _parse_accounts, + _parse_gres, + _parse_qos, + _parse_sinfo, + _run_command, + discover_cluster, + select_partition, +) + +# --------------------------------------------------------------------------- +# Fixtures: realistic sinfo / sacctmgr output +# --------------------------------------------------------------------------- + +# 9-field format: Partition Node CPUs Mem GRES Features MaxTime DefTime State +SINFO_OUTPUT_MIXED = """\ +cpu* node001 128 512000 (null) epyc9555,avx512 3-00:00:00 1-00:00:00 idle +cpu* node002 128 512000 (null) epyc9555,avx512 3-00:00:00 1-00:00:00 mixed +cpu* node003 128 512000 (null) epyc9555,avx512 3-00:00:00 1-00:00:00 allocated +gpu node010 64 256000 gpu:a100:4 a100,nvlink 1-00:00:00 4:00:00 idle +gpu node011 64 256000 gpu:a100:4 a100,nvlink 1-00:00:00 4:00:00 idle +gpu node012 96 512000 gpu:h100:8 h100,nvlink 1-00:00:00 4:00:00 mixed +bigmem node020 256 2048000 (null) bigmem,epyc 7-00:00:00 1-00:00:00 idle +""" + +SINFO_OUTPUT_SINGLE_PARTITION = """\ +compute node001 64 128000 (null) (null) 2-00:00:00 1-00:00:00 idle +compute node002 64 128000 (null) (null) 2-00:00:00 1-00:00:00 idle +""" + +SINFO_OUTPUT_GPU_ONLY = """\ +gpu-short node001 32 64000 gpu:v100:2 v100 4:00:00 2:00:00 idle +gpu-short node002 32 64000 gpu:v100:2 v100 4:00:00 2:00:00 idle +gpu-long node003 64 128000 gpu:a100:4 a100 2-00:00:00 4:00:00 idle +""" + +SINFO_OUTPUT_NO_TYPE_GPU = """\ +gpu node001 64 256000 gpu:4 (null) 1-00:00:00 4:00:00 idle +""" + +# First node drained, rest healthy → partition should be "up" +SINFO_OUTPUT_MIXED_HEALTH = """\ +cpu node001 64 128000 (null) (null) 2-00:00:00 1-00:00:00 drained +cpu node002 64 128000 (null) (null) 2-00:00:00 1-00:00:00 idle +cpu node003 64 128000 (null) (null) 2-00:00:00 1-00:00:00 idle +""" + +# All nodes unhealthy → partition should report unhealthy state +SINFO_OUTPUT_ALL_DOWN = """\ +cpu node001 64 128000 (null) (null) 2-00:00:00 1-00:00:00 down +cpu node002 64 128000 (null) (null) 2-00:00:00 1-00:00:00 drained +""" + +# Legacy 8-field format (no %L default time) — backward compatibility +SINFO_OUTPUT_LEGACY_8FIELD = """\ +compute node001 64 128000 (null) (null) 2-00:00:00 idle +compute node002 64 128000 (null) (null) 2-00:00:00 idle +""" + +SACCTMGR_ACCOUNTS = """\ +myproject +shared-account +default-account +""" + +SACCTMGR_QOS = """\ +normal|| +high|1-00:00:00|cpu=128,mem=512G +gpu|12:00:00|cpu=64,gres/gpu=4 +""" + + +# --------------------------------------------------------------------------- +# Tests: GRES parsing +# --------------------------------------------------------------------------- + + +class TestParseGres: + """Test GPU GRES string parsing.""" + + def test_gpu_with_type_and_count(self): + assert _parse_gres("gpu:a100:4") == (4, "a100") + + def test_gpu_with_count_only(self): + assert _parse_gres("gpu:2") == (2, None) + + def test_gpu_with_type_only(self): + assert _parse_gres("gpu:h100") == (1, "h100") + + def test_null_gres(self): + assert _parse_gres("(null)") == (0, None) + + def test_empty_string(self): + assert _parse_gres("") == (0, None) + + def test_multi_gres_with_gpu(self): + assert _parse_gres("mps:shared,gpu:a100:4") == (4, "a100") + + def test_non_gpu_gres(self): + assert _parse_gres("mps:shared") == (0, None) + + +# --------------------------------------------------------------------------- +# Tests: sinfo parsing +# --------------------------------------------------------------------------- + + +class TestParseSinfo: + """Test sinfo output parsing into Partition objects.""" + + def test_mixed_cluster(self): + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + + # Should find 3 partitions + assert len(partitions) == 3 + names = [p.name for p in partitions] + assert "cpu" in names + assert "gpu" in names + assert "bigmem" in names + + def test_default_partition_marker(self): + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + + cpu_part = next(p for p in partitions if p.name == "cpu") + assert cpu_part.is_default is True + + gpu_part = next(p for p in partitions if p.name == "gpu") + assert gpu_part.is_default is False + + def test_default_partition_sorted_first(self): + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + assert partitions[0].name == "cpu" + assert partitions[0].is_default is True + + def test_cpu_partition_node_types(self): + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + cpu_part = next(p for p in partitions if p.name == "cpu") + + # All 3 nodes have same spec → 1 unique node type + assert len(cpu_part.node_types) == 1 + nt = cpu_part.node_types[0] + assert nt.cpus == 128 + assert nt.memory_mb == 512000 + assert nt.gpus == 0 + assert nt.gpu_type is None + assert "epyc9555" in nt.features + + def test_gpu_partition_multiple_node_types(self): + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + gpu_part = next(p for p in partitions if p.name == "gpu") + + # 2 distinct node types: a100 (64 core) and h100 (96 core) + assert len(gpu_part.node_types) == 2 + + a100_node = next(n for n in gpu_part.node_types if n.gpu_type == "a100") + assert a100_node.cpus == 64 + assert a100_node.gpus == 4 + + h100_node = next(n for n in gpu_part.node_types if n.gpu_type == "h100") + assert h100_node.cpus == 96 + assert h100_node.gpus == 8 + + def test_total_node_count(self): + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + + cpu_part = next(p for p in partitions if p.name == "cpu") + assert cpu_part.total_nodes == 3 + + gpu_part = next(p for p in partitions if p.name == "gpu") + assert gpu_part.total_nodes == 3 + + bigmem_part = next(p for p in partitions if p.name == "bigmem") + assert bigmem_part.total_nodes == 1 + + def test_time_limit_parsed(self): + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + cpu_part = next(p for p in partitions if p.name == "cpu") + assert cpu_part.max_time == "3-00:00:00" + + def test_single_partition(self): + partitions = _parse_sinfo(SINFO_OUTPUT_SINGLE_PARTITION) + assert len(partitions) == 1 + assert partitions[0].name == "compute" + assert partitions[0].total_nodes == 2 + + def test_gpu_without_type(self): + partitions = _parse_sinfo(SINFO_OUTPUT_NO_TYPE_GPU) + assert len(partitions) == 1 + nt = partitions[0].node_types[0] + assert nt.gpus == 4 + assert nt.gpu_type is None + + def test_default_time_parsed_separately(self): + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + cpu_part = next(p for p in partitions if p.name == "cpu") + assert cpu_part.max_time == "3-00:00:00" + assert cpu_part.default_time == "1-00:00:00" + + gpu_part = next(p for p in partitions if p.name == "gpu") + assert gpu_part.max_time == "1-00:00:00" + assert gpu_part.default_time == "4:00:00" + + def test_legacy_8field_format(self): + """Parser handles legacy 8-field sinfo output (no %L).""" + partitions = _parse_sinfo(SINFO_OUTPUT_LEGACY_8FIELD) + assert len(partitions) == 1 + assert partitions[0].name == "compute" + # default_time falls back to max_time + assert partitions[0].default_time == partitions[0].max_time + + def test_partition_state_up_when_any_node_healthy(self): + """Partition with mixed healthy/unhealthy nodes should be 'up'.""" + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED_HEALTH) + assert len(partitions) == 1 + assert partitions[0].state == "up" + + def test_partition_state_unhealthy_when_all_nodes_down(self): + """Partition with no healthy nodes reports unhealthy state.""" + partitions = _parse_sinfo(SINFO_OUTPUT_ALL_DOWN) + assert len(partitions) == 1 + assert partitions[0].state != "up" + assert partitions[0].state in ("down", "drained") + + def test_empty_output(self): + partitions = _parse_sinfo("") + assert partitions == [] + + def test_malformed_lines_skipped(self): + output = "this is not valid sinfo output\n" + SINFO_OUTPUT_SINGLE_PARTITION + partitions = _parse_sinfo(output) + # Should still parse valid lines + assert len(partitions) == 1 + + +# --------------------------------------------------------------------------- +# Tests: account / QOS parsing +# --------------------------------------------------------------------------- + + +class TestParseAccounts: + """Test sacctmgr account output parsing.""" + + def test_multiple_accounts(self): + accounts = _parse_accounts(SACCTMGR_ACCOUNTS) + assert accounts == ["default-account", "myproject", "shared-account"] + + def test_empty_output(self): + assert _parse_accounts("") == [] + + def test_whitespace_handling(self): + assert _parse_accounts(" acc1 \n acc2 \n") == ["acc1", "acc2"] + + def test_deduplication(self): + assert _parse_accounts("acc1\nacc1\nacc2") == ["acc1", "acc2"] + + +class TestParseQos: + """Test sacctmgr QOS output parsing.""" + + def test_multiple_qos(self): + qos = _parse_qos(SACCTMGR_QOS) + assert qos == ["gpu", "high", "normal"] + + def test_empty_output(self): + assert _parse_qos("") == [] + + def test_pipe_separated_format(self): + qos = _parse_qos("standard|2-00:00:00|cpu=256\n") + assert qos == ["standard"] + + +# --------------------------------------------------------------------------- +# Tests: discover_cluster integration (mocked subprocess) +# --------------------------------------------------------------------------- + + +class TestDiscoverCluster: + """Test discover_cluster with mocked subprocess calls.""" + + def setup_method(self): + """Clear LRU cache between tests.""" + discover_cluster.cache_clear() + + def test_returns_none_without_sinfo(self): + with patch("mdfactory.performance.cluster.shutil.which", return_value=None): + result = discover_cluster() + assert result is None + + def test_returns_cluster_info_with_sinfo(self): + with ( + patch("mdfactory.performance.cluster.shutil.which", return_value="/usr/bin/sinfo"), + patch( + "mdfactory.performance.cluster._run_command", + side_effect=[ + SINFO_OUTPUT_MIXED, # sinfo call + SACCTMGR_ACCOUNTS, # sacctmgr accounts + SACCTMGR_QOS, # sacctmgr qos + "myproject", # sacctmgr default account + ], + ), + ): + result = discover_cluster() + + assert result is not None + assert isinstance(result, ClusterInfo) + assert len(result.partitions) == 3 + assert len(result.accounts) == 3 + assert len(result.qos_policies) == 3 + assert result.default_account == "myproject" + + def test_default_account_falls_back_to_first(self): + """When default account query fails, fall back to first account.""" + with ( + patch("mdfactory.performance.cluster.shutil.which", return_value="/usr/bin/sinfo"), + patch( + "mdfactory.performance.cluster._run_command", + side_effect=[ + SINFO_OUTPUT_MIXED, # sinfo call + SACCTMGR_ACCOUNTS, # sacctmgr accounts + SACCTMGR_QOS, # sacctmgr qos + None, # sacctmgr default account fails + ], + ), + ): + result = discover_cluster() + + assert result is not None + # Falls back to first alphabetical account + assert result.default_account == "default-account" + + def test_graceful_without_sacctmgr(self): + """When sacctmgr fails, still return partitions.""" + with ( + patch("mdfactory.performance.cluster.shutil.which", return_value="/usr/bin/sinfo"), + patch( + "mdfactory.performance.cluster._run_command", + side_effect=[ + SINFO_OUTPUT_MIXED, # sinfo succeeds + None, # sacctmgr accounts fails + None, # sacctmgr qos fails + None, # sacctmgr default account fails + ], + ), + ): + result = discover_cluster() + + assert result is not None + assert len(result.partitions) == 3 + assert result.accounts == [] + assert result.qos_policies == [] + assert result.default_account is None + + def test_caching(self): + """Second call returns cached result without re-querying.""" + with ( + patch("mdfactory.performance.cluster.shutil.which", return_value="/usr/bin/sinfo"), + patch( + "mdfactory.performance.cluster._run_command", + side_effect=[ + SINFO_OUTPUT_SINGLE_PARTITION, + None, + None, + None, + ], + ) as mock_cmd, + ): + result1 = discover_cluster() + result2 = discover_cluster() + + assert result1 is result2 + # _run_command called 4 times for 1 discover (sinfo, accounts, qos, default_account) + assert mock_cmd.call_count == 4 + + def test_returns_none_when_sinfo_fails(self): + """If sinfo exists but returns error, return None.""" + with ( + patch("mdfactory.performance.cluster.shutil.which", return_value="/usr/bin/sinfo"), + patch( + "mdfactory.performance.cluster._run_command", + return_value=None, # sinfo call fails + ), + ): + result = discover_cluster() + assert result is None + + +# --------------------------------------------------------------------------- +# Tests: select_partition +# --------------------------------------------------------------------------- + + +class TestSelectPartition: + """Test heuristic partition selection.""" + + @pytest.fixture() + def cluster(self) -> ClusterInfo: + """Build a ClusterInfo from the mixed sinfo output.""" + partitions = _parse_sinfo(SINFO_OUTPUT_MIXED) + return ClusterInfo( + partitions=partitions, + accounts=["myproject"], + qos_policies=["normal"], + default_account="myproject", + ) + + def test_select_default_cpu_partition(self, cluster: ClusterInfo): + result = select_partition(cluster) + assert result is not None + assert result.name == "cpu" + + def test_select_gpu_partition(self, cluster: ClusterInfo): + result = select_partition(cluster, needs_gpu=True) + assert result is not None + assert result.name == "gpu" + + def test_select_with_high_cpu_requirement(self, cluster: ClusterInfo): + # Need 200+ CPUs → only bigmem (256) qualifies + result = select_partition(cluster, min_cpus=200) + assert result is not None + assert result.name == "bigmem" + + def test_select_with_high_memory_requirement(self, cluster: ClusterInfo): + # Need 1TB+ → only bigmem qualifies + result = select_partition(cluster, min_mem_gb=1500) + assert result is not None + assert result.name == "bigmem" + + def test_returns_none_when_impossible(self, cluster: ClusterInfo): + # Need 1000 CPUs — nobody has that + result = select_partition(cluster, min_cpus=1000) + assert result is None + + def test_returns_none_gpu_when_no_gpu_partition(self): + partitions = _parse_sinfo(SINFO_OUTPUT_SINGLE_PARTITION) + cluster = ClusterInfo(partitions=partitions) + result = select_partition(cluster, needs_gpu=True) + assert result is None + + def test_prefers_default_partition(self, cluster: ClusterInfo): + # Both cpu and bigmem meet min_cpus=1 — prefer cpu (default) + result = select_partition(cluster, min_cpus=1) + assert result is not None + assert result.name == "cpu" + + def test_gpu_with_min_cpus(self, cluster: ClusterInfo): + # Need GPU + 90 CPUs → only h100 node qualifies (96 cpus) + result = select_partition(cluster, needs_gpu=True, min_cpus=90) + assert result is not None + assert result.name == "gpu" + + def test_skips_down_partitions(self): + """Partitions where all nodes are down are not selectable.""" + partitions = _parse_sinfo(SINFO_OUTPUT_ALL_DOWN) + cluster = ClusterInfo(partitions=partitions) + result = select_partition(cluster) + assert result is None + + +# --------------------------------------------------------------------------- +# Tests: dataclass properties +# --------------------------------------------------------------------------- + + +class TestDataclasses: + """Test dataclass construction and immutability.""" + + def test_node_type_frozen(self): + nt = NodeType(cpus=64, memory_mb=256000, gpus=4, gpu_type="a100") + with pytest.raises(AttributeError): + nt.cpus = 128 # type: ignore[misc] + + def test_partition_frozen(self): + p = Partition(name="test", state="up", max_time="1-00:00:00", default_time="1:00:00") + with pytest.raises(AttributeError): + p.name = "other" # type: ignore[misc] + + def test_cluster_info_frozen(self): + ci = ClusterInfo() + with pytest.raises(AttributeError): + ci.default_account = "hack" # type: ignore[misc] + + def test_node_type_defaults(self): + nt = NodeType(cpus=32, memory_mb=64000) + assert nt.gpus == 0 + assert nt.gpu_type is None + assert nt.features == () + + def test_node_type_features_immutable(self): + nt = NodeType(cpus=64, memory_mb=256000, features=("a100", "nvlink")) + assert nt.features == ("a100", "nvlink") + with pytest.raises(TypeError): + nt.features[0] = "other" # type: ignore[index] + + def test_cluster_info_defaults(self): + ci = ClusterInfo() + assert ci.partitions == [] + assert ci.accounts == [] + assert ci.qos_policies == [] + assert ci.default_account is None + + +# --------------------------------------------------------------------------- +# Tests: _run_command edge cases +# --------------------------------------------------------------------------- + + +class TestRunCommand: + """Test _run_command timeout and error handling.""" + + def test_returns_none_on_timeout(self): + with patch( + "mdfactory.performance.cluster.subprocess.run", + side_effect=subprocess.TimeoutExpired(cmd=["sinfo"], timeout=30), + ): + result = _run_command(["sinfo", "--version"]) + assert result is None + + def test_returns_none_on_file_not_found(self): + with patch( + "mdfactory.performance.cluster.subprocess.run", + side_effect=FileNotFoundError("No such file"), + ): + result = _run_command(["nonexistent_binary"]) + assert result is None + + def test_returns_none_on_nonzero_exit(self): + with patch( + "mdfactory.performance.cluster.subprocess.run", + return_value=subprocess.CompletedProcess( + args=["sinfo"], returncode=1, stdout="", stderr="error" + ), + ): + result = _run_command(["sinfo", "--bad-flag"]) + assert result is None + + def test_returns_stdout_on_success(self): + with patch( + "mdfactory.performance.cluster.subprocess.run", + return_value=subprocess.CompletedProcess( + args=["echo"], returncode=0, stdout="hello\n", stderr="" + ), + ): + result = _run_command(["echo", "hello"]) + assert result == "hello"