Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions mlpstorage_py/cluster_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,11 +1128,19 @@ def main():
json.dump(error_output, f, indent=2)
sys.exit(1)

# Collect local info
local_info = collect_local_info()
local_info['mpi_rank'] = rank
# Collect local info — wrap in try/except so every rank always reaches
# comm.gather(); an early exit from any rank would deadlock all others.
try:
local_info = collect_local_info()
local_info['mpi_rank'] = rank
except Exception as e:
local_info = {
'hostname': socket.gethostname(),
'mpi_rank': rank,
'_collection_error': str(e),
}

# Gather all info to rank 0
# Gather all info to rank 0 — every rank must reach this call
all_info = comm.gather(local_info, root=0)

if rank == 0:
Expand Down
44 changes: 28 additions & 16 deletions mlpstorage_py/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,12 @@ def execute(self,
)

for stream in readable:
line = stream.readline()
if not line: # EOF
# read1() returns whatever bytes are in the pipe buffer without
# blocking for '\n', preventing a hang on \r-terminated output.
raw = stream.buffer.read1(65536)
if not raw: # EOF
continue
line = raw.decode('utf-8', errors='replace')

if stream.fileno() == stdout_fd:
stdout_buffer.write(line)
Expand All @@ -402,20 +405,29 @@ def execute(self,
sys.stderr.write(line)
sys.stderr.flush()

# Read any remaining output
stdout_remainder = self.process.stdout.read()
if stdout_remainder:
stdout_buffer.write(stdout_remainder)
if print_stdout:
sys.stdout.write(stdout_remainder)
sys.stdout.flush()

stderr_remainder = self.process.stderr.read()
if stderr_remainder:
stderr_buffer.write(stderr_remainder)
if print_stderr:
sys.stderr.write(stderr_remainder)
sys.stderr.flush()
# Drain any remaining output. TextIOWrapper.read() blocks until
# EOF, which never arrives if orphaned grandchild processes
# (e.g. PyTorch DataLoader workers forked inside a DLIO MPI rank
# *after* MPI_Init) still hold the pipe write-end open.
# select() + read1() with a short timeout avoids that hang:
# when the write-end is fully closed select() returns immediately
# (EOF is readable), so the normal path has no added latency.
for stream, buf, print_flag, sys_out in [
(self.process.stdout, stdout_buffer, print_stdout, sys.stdout),
(self.process.stderr, stderr_buffer, print_stderr, sys.stderr),
]:
while True:
ready, _, _ = select.select([stream], [], [], 0.5)
if not ready:
break
chunk = stream.buffer.read1(65536)
if not chunk:
break
text = chunk.decode('utf-8', errors='replace')
buf.write(text)
if print_flag:
sys_out.write(text)
sys_out.flush()

# Get the return code
return_code = self.process.poll()
Expand Down
144 changes: 144 additions & 0 deletions tests/unit/test_cluster_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,3 +1210,147 @@ def test_handles_unreachable_host_gracefully(self):
if bad_host_samples:
# If we got samples, they should have errors
assert any('errors' in s for s in bad_host_samples)


class TestMPICollectorScriptMain:
"""Tests for the main() function embedded in MPI_COLLECTOR_SCRIPT.

Verifies that every rank always calls comm.gather() even when
collect_local_info() raises, preventing a deadlock on surviving ranks.
"""

@staticmethod
def _load_script_ns():
"""Exec MPI_COLLECTOR_SCRIPT into a fresh namespace and return it."""
from mlpstorage_py.cluster_collector import MPI_COLLECTOR_SCRIPT
ns = {'__name__': 'mlps_collector'}
exec(MPI_COLLECTOR_SCRIPT, ns)
return ns

@staticmethod
def _mock_mpi(mock_comm):
"""Return a sys.modules patch dict wiring mock_comm as MPI.COMM_WORLD."""
mock_mpi = MagicMock()
mock_mpi.COMM_WORLD = mock_comm
mock_mpi4py = MagicMock()
mock_mpi4py.MPI = mock_mpi
return {'mpi4py': mock_mpi4py}

def test_gather_called_on_successful_collection(self, tmp_path):
"""Normal path: gather is called with local info dict including mpi_rank."""
output_file = str(tmp_path / 'out.json')
mock_comm = MagicMock()
mock_comm.Get_rank.return_value = 1
mock_comm.Get_size.return_value = 2
mock_comm.gather.return_value = None

ns = self._load_script_ns()
ns['collect_local_info'] = MagicMock(return_value={'hostname': 'node1'})

with patch.dict('sys.modules', self._mock_mpi(mock_comm)), \
patch('sys.argv', ['script', output_file]):
ns['main']()

mock_comm.gather.assert_called_once()
gathered_info = mock_comm.gather.call_args[0][0]
assert gathered_info['hostname'] == 'node1'
assert gathered_info['mpi_rank'] == 1
assert '_collection_error' not in gathered_info

def test_gather_still_called_when_collection_raises(self, tmp_path):
"""Error path: gather is called even when collect_local_info() raises."""
output_file = str(tmp_path / 'out.json')
mock_comm = MagicMock()
mock_comm.Get_rank.return_value = 1
mock_comm.Get_size.return_value = 2
mock_comm.gather.return_value = None

ns = self._load_script_ns()
ns['collect_local_info'] = MagicMock(side_effect=RuntimeError('disk read failed'))

with patch.dict('sys.modules', self._mock_mpi(mock_comm)), \
patch('sys.argv', ['script', output_file]):
ns['main']()

mock_comm.gather.assert_called_once()

def test_sentinel_has_collection_error_key(self, tmp_path):
"""Error sentinel must contain _collection_error so callers can detect failures."""
output_file = str(tmp_path / 'out.json')
mock_comm = MagicMock()
mock_comm.Get_rank.return_value = 1
mock_comm.Get_size.return_value = 2
mock_comm.gather.return_value = None

ns = self._load_script_ns()
ns['collect_local_info'] = MagicMock(side_effect=RuntimeError('disk read failed'))

with patch.dict('sys.modules', self._mock_mpi(mock_comm)), \
patch('sys.argv', ['script', output_file]):
ns['main']()

gathered_info = mock_comm.gather.call_args[0][0]
assert '_collection_error' in gathered_info
assert 'disk read failed' in gathered_info['_collection_error']

def test_sentinel_has_hostname_and_rank(self, tmp_path):
"""Error sentinel must carry hostname and mpi_rank so rank 0 can identify the source."""
output_file = str(tmp_path / 'out.json')
mock_comm = MagicMock()
mock_comm.Get_rank.return_value = 2
mock_comm.Get_size.return_value = 4
mock_comm.gather.return_value = None

ns = self._load_script_ns()
ns['collect_local_info'] = MagicMock(side_effect=OSError('permission denied'))

with patch.dict('sys.modules', self._mock_mpi(mock_comm)), \
patch('sys.argv', ['script', output_file]):
ns['main']()

gathered_info = mock_comm.gather.call_args[0][0]
assert gathered_info['mpi_rank'] == 2
assert 'hostname' in gathered_info

def test_rank_zero_writes_output_file(self, tmp_path):
"""Rank 0 writes the JSON output file when collection succeeds."""
output_file = str(tmp_path / 'out.json')
mock_comm = MagicMock()
mock_comm.Get_rank.return_value = 0
mock_comm.Get_size.return_value = 1
mock_comm.gather.return_value = [{'hostname': 'node0', 'mpi_rank': 0}]

ns = self._load_script_ns()
ns['collect_local_info'] = MagicMock(return_value={'hostname': 'node0'})

with patch.dict('sys.modules', self._mock_mpi(mock_comm)), \
patch('sys.argv', ['script', output_file]):
ns['main']()

with open(output_file) as f:
data = json.load(f)
assert 'node0' in data

def test_rank_zero_writes_output_when_another_rank_sent_sentinel(self, tmp_path):
"""Rank 0 writes JSON even when another rank's payload is an error sentinel."""
output_file = str(tmp_path / 'out.json')
mock_comm = MagicMock()
mock_comm.Get_rank.return_value = 0
mock_comm.Get_size.return_value = 2
mock_comm.gather.return_value = [
{'hostname': 'node0', 'mpi_rank': 0},
{'hostname': 'node1', 'mpi_rank': 1, '_collection_error': 'disk read failed'},
]

ns = self._load_script_ns()
ns['collect_local_info'] = MagicMock(return_value={'hostname': 'node0'})

with patch.dict('sys.modules', self._mock_mpi(mock_comm)), \
patch('sys.argv', ['script', output_file]):
ns['main']()

with open(output_file) as f:
data = json.load(f)
assert 'node0' in data
assert 'node1' in data
assert '_collection_error' in data['node1']
Loading