diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 502309b..a3d5671 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -14,12 +14,13 @@ """Profiling Utilities.""" import asyncio +from collections.abc import Mapping import dataclasses import json import logging import os import threading -from typing import Any, Mapping +from typing import Any import urllib.parse import fastapi @@ -35,14 +36,17 @@ class _ProfileState: executable: plugin_executable.PluginExecutable | None = None + profile_request: Mapping[str, Any] | None = None lock: threading.Lock def __init__(self) -> None: self.executable = None + self.profile_request = None self.lock = threading.Lock() def reset(self) -> None: self.executable = None + self.profile_request = None _first_profile_start = True @@ -153,6 +157,7 @@ def _start_pathways_trace_from_profile_request( _profile_state.executable = plugin_executable.PluginExecutable( json.dumps({"profileRequest": profile_request}) ) + _profile_state.profile_request = profile_request try: _, result_future = _profile_state.executable.call() result_future.result() @@ -233,7 +238,19 @@ def stop_trace() -> None: if _profile_state.executable is None: raise ValueError("stop_trace called before a trace is being taken!") try: - _, result_future = _profile_state.executable.call() + if ( + _profile_state.profile_request is not None + and "xprofTraceOptions" in _profile_state.profile_request + ): + out_avals = [jax.core.ShapedArray((1,), jnp.object_)] + out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])] + else: + out_avals = () + out_shardings = () + + _, result_future = _profile_state.executable.call( + out_avals=out_avals, out_shardings=out_shardings + ) result_future.result() finally: _profile_state.reset() diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index 7d524f8..66e6d57 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -14,12 +14,12 @@ import json import logging -import unittest from unittest import mock from absl.testing import absltest from absl.testing import parameterized import jax +from jax import numpy as jnp from pathwaysutils import profiling import requests @@ -213,10 +213,13 @@ def test_lock_released_on_stop_failure(self): """Tests that the lock is released if stop_trace fails.""" profiling.start_trace("gs://test_bucket/test_dir3") self.assertFalse(profiling._profile_state.lock.locked()) - mock_result = ( - self.mock_plugin_executable_cls.return_value.call.return_value[1] + mock_result_fail = mock.MagicMock() + mock_result_fail.result.side_effect = RuntimeError("stop failed") + self.mock_plugin_executable_cls.return_value.call.return_value = ( + mock.MagicMock(), + mock_result_fail, ) - mock_result.result.side_effect = RuntimeError("stop failed") + self.mock_plugin_executable_cls.return_value.call.side_effect = None with self.assertRaisesRegex(RuntimeError, "stop failed"): profiling.stop_trace() self.assertFalse(profiling._profile_state.lock.locked()) @@ -277,6 +280,44 @@ def test_stop_trace_success(self): with self.subTest("executable_is_none"): self.assertIsNone(profiling._profile_state.executable) + @absltest.skipIf( + jax.version.__version_info__ < (0, 9, 2), + "ProfileOptions requires JAX 0.9.2 or newer", + ) + def test_stop_trace_with_xprof_options_passes_out_avals(self): + options = jax.profiler.ProfileOptions() + options.duration_ms = 2000 + + with mock.patch.object( + profiling, "_profile_state", autospec=True + ) as mock_profile_state: + request = profiling._create_profile_request( + "gs://test_bucket/test_dir", options + ) + mock_profile_state.profile_request = request + mock_profile_state.executable = ( + self.mock_plugin_executable_cls.return_value + ) + mock_profile_state.lock = mock.MagicMock() + mock_profile_state.lock.locked.return_value = True + mock_profile_state.lock.__enter__.return_value = None + mock_profile_state.lock.__exit__.return_value = None + + profiling.stop_trace() + + with self.subTest("plugin_executable_called"): + self.mock_plugin_executable_cls.return_value.call.assert_called_once() + _, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args + self.assertIn("out_avals", kwargs) + self.assertIn("out_shardings", kwargs) + + with self.subTest("out_avals_properties"): + _, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args + self.assertLen(kwargs["out_avals"], 1) + (out_aval,) = kwargs["out_avals"] + self.assertEqual(out_aval.shape, (1,)) + self.assertEqual(out_aval.dtype, jnp.object_) + def test_stop_trace_before_start_error(self): with self.assertRaisesRegex( ValueError, "stop_trace called before a trace is being taken!" @@ -406,7 +447,7 @@ def test_create_profile_request_default_options(self, profiler_options): }, ) - @unittest.skipIf( + @absltest.skipIf( jax.version.__version_info__ < (0, 9, 2), "ProfileOptions requires JAX 0.9.2 or newer", ) @@ -444,41 +485,45 @@ def test_create_profile_request_with_options(self): }, ) - @unittest.skipIf( + @absltest.skipIf( jax.version.__version_info__ < (0, 9, 2), "ProfileOptions requires JAX 0.9.2 or newer", ) @parameterized.parameters( ({"traceLocation": "gs://test_bucket/test_dir"},), - ({ - "traceLocation": "gs://test_bucket/test_dir", - "blockUntilStart": True, - "maxDurationSecs": 10.0, - "devices": {"deviceIds": [1, 2]}, - "includeResourceManagers": True, - "maxNumHosts": 5, - "xprofTraceOptions": { + ( + { + "traceLocation": "gs://test_bucket/test_dir", "blockUntilStart": True, - "traceDirectory": "gs://test_bucket/test_dir", + "maxDurationSecs": 10.0, + "devices": {"deviceIds": [1, 2]}, + "includeResourceManagers": True, + "maxNumHosts": 5, + "xprofTraceOptions": { + "blockUntilStart": True, + "traceDirectory": "gs://test_bucket/test_dir", + }, }, - },), - ({ - "traceLocation": "gs://bucket/dir", - "xprofTraceOptions": { - "hostTraceLevel": 0, - "traceOptions": { - "traceMode": "TRACE_COMPUTE", - "numSparseCoresToTrace": 1, - "numSparseCoreTilesToTrace": 2, - "numChipsToProfilePerTask": 3, - "powerTraceLevel": 4, - "enableFwThrottleEvent": True, - "enableFwPowerLevelEvent": True, - "enableFwThermalEvent": True, + ), + ( + { + "traceLocation": "gs://bucket/dir", + "xprofTraceOptions": { + "hostTraceLevel": 0, + "traceOptions": { + "traceMode": "TRACE_COMPUTE", + "numSparseCoresToTrace": 1, + "numSparseCoreTilesToTrace": 2, + "numChipsToProfilePerTask": 3, + "powerTraceLevel": 4, + "enableFwThrottleEvent": True, + "enableFwPowerLevelEvent": True, + "enableFwThermalEvent": True, + }, + "traceDirectory": "gs://bucket/dir", }, - "traceDirectory": "gs://bucket/dir", }, - },), + ), ) def test_start_pathways_trace_from_profile_request(self, profile_request): @@ -496,10 +541,9 @@ def test_original_stop_trace_called_on_stop_failure(self): """Tests that original_stop_trace is called if pathways stop_trace fails.""" profiling.start_trace("gs://test_bucket/test_dir") self.assertFalse(profiling._profile_state.lock.locked()) - mock_result = ( - self.mock_plugin_executable_cls.return_value.call.return_value[1] + self.mock_plugin_executable_cls.return_value.call.side_effect = ( + RuntimeError("stop failed") ) - mock_result.result.side_effect = RuntimeError("stop failed") with self.assertRaisesRegex(RuntimeError, "stop failed"): profiling.stop_trace() self.mock_original_stop_trace.assert_called_once()