diff --git a/mlpstorage_py/cluster_collector.py b/mlpstorage_py/cluster_collector.py
index a4a1ea4e..bc1a12b9 100644
--- a/mlpstorage_py/cluster_collector.py
+++ b/mlpstorage_py/cluster_collector.py
@@ -1128,11 +1128,19 @@ def main():
             json.dump(error_output, f, indent=2)
         sys.exit(1)
 
-    # Collect local info
-    local_info = collect_local_info()
-    local_info['mpi_rank'] = rank
+    # Collect local info — wrap in try/except so every rank always reaches
+    # comm.gather(); an early exit from any rank would deadlock all others.
+    try:
+        local_info = collect_local_info()
+        local_info['mpi_rank'] = rank
+    except Exception as e:
+        local_info = {
+            'hostname': socket.gethostname(),
+            'mpi_rank': rank,
+            '_collection_error': f'{type(e).__name__}: {e}',
+        }
 
-    # Gather all info to rank 0
+    # Gather all info to rank 0 — every rank must reach this call
     all_info = comm.gather(local_info, root=0)
 
     if rank == 0:
diff --git a/mlpstorage_py/utils.py b/mlpstorage_py/utils.py
index 7c4c80b8..065f893a 100755
--- a/mlpstorage_py/utils.py
+++ b/mlpstorage_py/utils.py
@@ -387,9 +387,15 @@ def execute(self,
                 )
 
                 for stream in readable:
-                    line = stream.readline()
-                    if not line:  # EOF
+                    # read1() returns whatever bytes are in the pipe buffer without
+                    # blocking for '\n', preventing a hang on \r-terminated output.
+                    # NOTE(review): a multi-byte UTF-8 sequence split across
+                    # two chunks decodes to U+FFFD here; a per-stream
+                    # codecs incremental decoder would avoid that.
+                    raw = stream.buffer.read1(65536)
+                    if not raw:  # EOF
                         continue
+                    line = raw.decode('utf-8', errors='replace')
 
                     if stream.fileno() == stdout_fd:
                         stdout_buffer.write(line)
@@ -402,20 +408,29 @@ def execute(self,
                             sys.stderr.write(line)
                             sys.stderr.flush()
 
-            # Read any remaining output
-            stdout_remainder = self.process.stdout.read()
-            if stdout_remainder:
-                stdout_buffer.write(stdout_remainder)
-                if print_stdout:
-                    sys.stdout.write(stdout_remainder)
-                    sys.stdout.flush()
-
-            stderr_remainder = self.process.stderr.read()
-            if stderr_remainder:
-                stderr_buffer.write(stderr_remainder)
-                if print_stderr:
-                    sys.stderr.write(stderr_remainder)
-                    sys.stderr.flush()
+            # Drain any remaining output. TextIOWrapper.read() blocks until
+            # EOF, which never arrives if orphaned grandchild processes
+            # (e.g. PyTorch DataLoader workers forked inside a DLIO MPI rank
+            # *after* MPI_Init) still hold the pipe write-end open.
+            # select() + read1() with a short timeout avoids that hang:
+            # when the write-end is fully closed select() returns immediately
+            # (EOF is readable), so the normal path has no added latency.
+            for stream, buf, print_flag, sys_out in [
+                (self.process.stdout, stdout_buffer, print_stdout, sys.stdout),
+                (self.process.stderr, stderr_buffer, print_stderr, sys.stderr),
+            ]:
+                while True:
+                    ready, _, _ = select.select([stream], [], [], 0.5)
+                    if not ready:
+                        break
+                    chunk = stream.buffer.read1(65536)
+                    if not chunk:
+                        break
+                    text = chunk.decode('utf-8', errors='replace')
+                    buf.write(text)
+                    if print_flag:
+                        sys_out.write(text)
+                        sys_out.flush()
 
             # Get the return code
             return_code = self.process.poll()
diff --git a/tests/unit/test_cluster_collector.py b/tests/unit/test_cluster_collector.py
index b8042e0c..6feca9ed 100755
--- a/tests/unit/test_cluster_collector.py
+++ b/tests/unit/test_cluster_collector.py
@@ -1210,3 +1210,116 @@ def test_handles_unreachable_host_gracefully(self):
         if bad_host_samples:
             # If we got samples, they should have errors
             assert any('errors' in s for s in bad_host_samples)
+
+
+class TestMPICollectorScriptMain:
+    """Tests for the main() function embedded in MPI_COLLECTOR_SCRIPT.
+
+    Verifies that every rank always calls comm.gather() even when
+    collect_local_info() raises, preventing a deadlock on surviving ranks.
+    """
+
+    @staticmethod
+    def _load_script_ns():
+        """Exec MPI_COLLECTOR_SCRIPT into a fresh namespace and return it."""
+        from mlpstorage_py.cluster_collector import MPI_COLLECTOR_SCRIPT
+        ns = {'__name__': 'mlps_collector'}
+        exec(MPI_COLLECTOR_SCRIPT, ns)
+        return ns
+
+    @staticmethod
+    def _mock_mpi(mock_comm):
+        """Return a sys.modules patch dict wiring mock_comm as MPI.COMM_WORLD."""
+        mock_mpi = MagicMock()
+        mock_mpi.COMM_WORLD = mock_comm
+        mock_mpi4py = MagicMock()
+        mock_mpi4py.MPI = mock_mpi
+        return {'mpi4py': mock_mpi4py}
+
+    def _run_main(self, tmp_path, *, rank, size, collect, gathered=None):
+        """Run the script's main() on one mocked rank.
+
+        ``collect`` configures the collect_local_info() stand-in: an
+        exception instance is raised, any other value is returned.
+        ``gathered`` is what the mocked comm.gather() hands back.
+        Returns ``(mock_comm, output_file)``.
+        """
+        output_file = str(tmp_path / 'out.json')
+        mock_comm = MagicMock()
+        mock_comm.Get_rank.return_value = rank
+        mock_comm.Get_size.return_value = size
+        mock_comm.gather.return_value = gathered
+
+        ns = self._load_script_ns()
+        if isinstance(collect, BaseException):
+            ns['collect_local_info'] = MagicMock(side_effect=collect)
+        else:
+            ns['collect_local_info'] = MagicMock(return_value=collect)
+
+        with patch.dict('sys.modules', self._mock_mpi(mock_comm)), \
+             patch('sys.argv', ['script', output_file]):
+            ns['main']()
+        return mock_comm, output_file
+
+    def test_gather_called_on_successful_collection(self, tmp_path):
+        """Normal path: gather is called with local info dict including mpi_rank."""
+        mock_comm, _ = self._run_main(
+            tmp_path, rank=1, size=2, collect={'hostname': 'node1'})
+
+        mock_comm.gather.assert_called_once()
+        gathered_info = mock_comm.gather.call_args[0][0]
+        assert gathered_info['hostname'] == 'node1'
+        assert gathered_info['mpi_rank'] == 1
+        assert '_collection_error' not in gathered_info
+
+    def test_gather_still_called_when_collection_raises(self, tmp_path):
+        """Error path: gather is called even when collect_local_info() raises."""
+        mock_comm, _ = self._run_main(
+            tmp_path, rank=1, size=2, collect=RuntimeError('disk read failed'))
+
+        mock_comm.gather.assert_called_once()
+
+    def test_sentinel_has_collection_error_key(self, tmp_path):
+        """Error sentinel must contain _collection_error so callers can detect failures."""
+        mock_comm, _ = self._run_main(
+            tmp_path, rank=1, size=2, collect=RuntimeError('disk read failed'))
+
+        gathered_info = mock_comm.gather.call_args[0][0]
+        assert '_collection_error' in gathered_info
+        assert 'disk read failed' in gathered_info['_collection_error']
+        assert 'RuntimeError' in gathered_info['_collection_error']
+
+    def test_sentinel_has_hostname_and_rank(self, tmp_path):
+        """Error sentinel must carry hostname and mpi_rank so rank 0 can identify the source."""
+        mock_comm, _ = self._run_main(
+            tmp_path, rank=2, size=4, collect=OSError('permission denied'))
+
+        gathered_info = mock_comm.gather.call_args[0][0]
+        assert gathered_info['mpi_rank'] == 2
+        assert 'hostname' in gathered_info
+
+    def test_rank_zero_writes_output_file(self, tmp_path):
+        """Rank 0 writes the JSON output file when collection succeeds."""
+        _, output_file = self._run_main(
+            tmp_path, rank=0, size=1, collect={'hostname': 'node0'},
+            gathered=[{'hostname': 'node0', 'mpi_rank': 0}])
+
+        with open(output_file) as f:
+            data = json.load(f)
+        assert 'node0' in data
+
+    def test_rank_zero_writes_output_when_another_rank_sent_sentinel(self, tmp_path):
+        """Rank 0 writes JSON even when another rank's payload is an error sentinel."""
+        gathered = [
+            {'hostname': 'node0', 'mpi_rank': 0},
+            {'hostname': 'node1', 'mpi_rank': 1, '_collection_error': 'disk read failed'},
+        ]
+        _, output_file = self._run_main(
+            tmp_path, rank=0, size=2, collect={'hostname': 'node0'},
+            gathered=gathered)
+
+        with open(output_file) as f:
+            data = json.load(f)
+        assert 'node0' in data
+        assert 'node1' in data
+        assert '_collection_error' in data['node1']