diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index effecbd4..cc6f9a80 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -384,36 +384,48 @@ def validate(self): self.logger.warning( f"Running DLIO with {self.read_threads} threads for I/O but core available {cores_available} " f"are insufficient and can lead to lower performance.") - # Memory budget guard: spawned worker processes must not exhaust system RAM. - # Each worker loads Python + framework + reader libraries (~512 MB RSS minimum). - # The hard cap is 32 GB so these benchmarks run on any compliant system. - # This check runs on all ranks so every rank refuses before workers are spawned. + # Memory budget guard: spawned worker processes must not exhaust this + # node's RAM. Worker counts are scoped to *this* node — not the world + # communicator — because RAM is a per-node resource. Prior code multiplied + # by self.comm_size and compared against one node's RAM, which produced + # false positives at multi-node scale (issue mlcommons/storage#448: + # the suggested max_threads collapsed to 0 at 100+ nodes even though the + # per-node load was well under capacity). if self.read_threads > 0 and self.data_loader in [ DataLoaderType.PYTORCH, DataLoaderType.DALI ]: + import platform + import socket import psutil - total_workers = self.read_threads * self.comm_size - # 512 MB per spawned worker is the minimum observed RSS (framework imports only). + ranks_per_node = DLIOMPI.get_instance().ranks_per_node() + local_workers = self.read_threads * ranks_per_node + # 512 MB per spawned worker is the minimum observed RSS (framework + # imports only). Compare against psutil.virtual_memory().available + # with a 90% safety margin so already-used RAM is respected. per_worker_mb = 512 - # Use actual installed RAM so large machines aren't blocked (#372). - # Spawning more workers than can fit in RAM is still an error. - BUDGET_MB = psutil.virtual_memory().total // (1024 * 1024) - estimated_mb = per_worker_mb * total_workers + vm = psutil.virtual_memory() + available_mb = vm.available // (1024 * 1024) + BUDGET_MB = int(available_mb * 0.9) + estimated_mb = per_worker_mb * local_workers + hostname = socket.gethostname() if estimated_mb > BUDGET_MB: - max_threads = BUDGET_MB // per_worker_mb // max(1, self.comm_size) + max_threads = BUDGET_MB // per_worker_mb // max(1, ranks_per_node) raise Exception( - f"Memory budget exceeded: reader.read_threads={self.read_threads} " - f"x comm_size={self.comm_size} = {total_workers} worker processes, " - f"estimated ~{estimated_mb // 1024} GB (host RAM: {BUDGET_MB // 1024} GB). " - f"Reduce reader.read_threads to at most {max_threads} for this run." + f"Per-node memory budget exceeded on host {hostname}: " + f"reader.read_threads={self.read_threads} x local_ranks=" + f"{ranks_per_node} = {local_workers} worker processes, " + f"estimated ~{estimated_mb // 1024} GB " + f"(available RAM on this node: {available_mb // 1024} GB; " + f"total: {vm.total // (1024**3)} GB). " + f"Reduce reader.read_threads to at most {max_threads}." ) - # Also warn if estimated usage exceeds 50% of available RAM on this machine - available_mb = psutil.virtual_memory().available // (1024 * 1024) + # Also warn if estimated usage exceeds 50% of available RAM on this node if estimated_mb > available_mb * 0.5: self.logger.warning( - f"reader.read_threads={self.read_threads} x comm_size={self.comm_size} " - f"= {total_workers} workers, estimated ~{estimated_mb // 1024} GB — " - f"exceeds 50% of available RAM ({available_mb // 1024} GB). " + f"On host {hostname}: reader.read_threads={self.read_threads} " + f"x local_ranks={ranks_per_node} = {local_workers} workers, " + f"estimated ~{estimated_mb // 1024} GB — exceeds 50% of " + f"available RAM on this node ({available_mb // 1024} GB). " f"Consider reducing read_threads to avoid OOM." ) if self.num_layers > 0 and self.num_layers < self.pipeline_parallelism: