Cherry-pick -adaptation: [ROCm] Fix profiler leaking stale hipErrorInvalidDevice#795
Cherry-pick -adaptation: [ROCm] Fix profiler leaking stale hipErrorInvalidDevice#795magaonka-amd wants to merge 3 commits intoROCm:rocm-jaxlib-v0.8.2from
Conversation
Adapted from upstream openxla/xla@01a20b15d6 (PR openxla#40199) for the rocm-jaxlib-v0.8.2 profiler code structure (rocm_profiler_sdk.h/cc instead of rocm_tracer.h/cc). - Replace hipGetDeviceProperties in GetDeviceCapabilities with rocprofiler agent data already available from RocmTracer. This eliminates HIP runtime calls that fail for non-visible devices when ROCR_VISIBLE_DEVICES restricts GPU visibility (e.g. pytest-xdist workers). - Since ROCm 7, hipGetLastError() is sticky: it retains errors even after subsequent successful HIP API calls. The stale hipErrorInvalidDevice from the profiler leaked into unrelated GPU operations, causing flaky test failures in JAX FFI handlers. - Agent clock rates are in MHz (vs KHz in hipDeviceProp_t); memory is filtered to VRAM-only banks. - Add unit test verifying agent data matches hipGetDeviceProperties.
fe708ab to
4f2532c
Compare
|
@magaonka-amd I think 0.8.2 still allows building with roctracer support. Not sure if this builds if that is turned on. @gulsumgudukbay Do we still need building with roctracer? I've noticed 0.9.0 does not seem to have that support. |
The preceding commit added RocmTracer::GpuAgents() and a matching test, both of which only exist on the rocprofiler-sdk (v3) tracer backend. The roctracer (v1) backend (--define=xla_rocm_profiler=v1) does not expose these APIs, breaking //pjrt/tools:build_gpu_plugin_wheel. Guard the v3-only call in device_tracer_rocm.cc and the v3-only tests in rocm_tracer_test.cc with the XLA_GPU_ROCM_TRACER_BACKEND macro. SetGpuAgents() is a no-op on the base collector, so the v1 path preserves pre-cherrypick behavior.
|
Dragan thanks for pointing out I added guard for v3 and v1 because rocm_trace_collector_->SetGpuAgents only works for v3. |
0.9.1 uses v1 (see rocm-jax/jax_rocm_plugin/build/rocm/tools/build_wheels.py line 177 ("--bazel_options=--define=xla_rocm_profiler=v1"). We did not update to v3 as it was crashing. |
Adapted from upstream openxla/xla@01a20b15d6 (PR openxla#40199) for the rocm-jaxlib-v0.8.2 profiler code structure (rocm_profiler_sdk.h/cc instead of rocm_tracer.h/cc).