
Skip command buffer re-tracing via in-place HIP graph node patching#785

Open
phambinhfin wants to merge 9 commits into main from
phambinh/skip-retrace-kernel-node-patching

Conversation

@phambinhfin commented Apr 7, 2026

Problem

When XLA's BFC (Best-Fit with Coalescing) memory allocator reassigns device buffers between steps, command buffers must be updated with the new addresses. The current path re-records the entire command buffer from scratch:

alloc changed → Record() → Finalize() → Instantiate() → Submit()

This full re-record walks every command in the sequence, calls HIP APIs to recreate each graph node, finalizes the graph, and re-instantiates the executable. For large models with hundreds of graph nodes, this overhead is significant and happens on every training step because the BFC allocator frequently moves buffers.

Additionally, collectives (AllReduce, ReduceScatter, etc.) suffer a severe regression when captured into HIP graphs. CollectiveCmd::RecordTracedCommand was calling TraceCommandBufferFactory::Create on every Record invocation — performing a full hipStreamBeginCapture → RCCL → hipStreamEndCapture cycle each time. For non-power-of-2 element counts, this stream capture takes ~250ms due to RCCL's internal polling protocol, making COLLECTIVES in command buffers 60-150x slower than direct execution.

Idea

1. Fast kernel node address patching

Instead of re-recording the entire graph, directly patch the device addresses inside the existing executable graph (hipGraphExec_t):

  1. Switch kernel nodes from kernelParams-style to extra-style argument passing. Instead of hipKernelNodeParams.kernelParams (array of pointers to each argument), we pack all arguments into a single contiguous buffer and use HIP_LAUNCH_PARAM_BUFFER_POINTER via the extra field. This gives us an owned, mutable buffer containing all the device pointers for each kernel.

  2. Patch addresses in-place. When allocations change, scan the owned arg buffers for old device addresses and replace them with new ones. Then call hipGraphExecKernelNodeSetParams to push the updated args to the executable graph — no re-recording, no re-instantiation.

  3. Cache node handles and params at creation time. Each CreateKernelNode call stores the hipGraphNode_t handle, the hipKernelNodeParams, and the owned arg buffer. This eliminates the need to call hipGraphGetNodes + hipGraphNodeGetType + hipGraphKernelNodeGetParams during the update — we iterate our cached list directly.

Why extra-style?

The kernelParams approach has a known HIP bug (ROCm/clr#138): hipGraphKernelNodeGetParams returns dangling pointers for kernelParams-style nodes, making it impossible to read or patch argument values from the graph. By switching to extra-style with HIP_LAUNCH_PARAM_BUFFER_POINTER, we own the argument buffer and can safely read/modify it.
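The packing half of this scheme can be sketched without any HIP dependency. The snippet below is an illustrative sketch only (PackArgs is a hypothetical helper, not the PR's code): arguments are copied into one owned, contiguous buffer at naturally aligned offsets, which is what makes later in-place patching possible.

```cpp
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

// Pack kernel arguments into one owned, contiguous buffer at naturally
// aligned offsets (power-of-two argument sizes assumed). Owning this
// buffer is what allows addresses inside it to be patched in place later.
std::vector<uint8_t> PackArgs(
    const std::vector<std::pair<const void*, size_t>>& args) {
  std::vector<uint8_t> buf;
  for (const auto& [ptr, size] : args) {
    size_t offset = (buf.size() + size - 1) / size * size;  // align up
    buf.resize(offset + size);
    std::memcpy(buf.data() + offset, ptr, size);
  }
  return buf;
}
```

At launch time, such a buffer would be handed to HIP via the extra field, roughly as `void* extra[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, buf.data(), HIP_LAUNCH_PARAM_BUFFER_SIZE, &size, HIP_LAUNCH_PARAM_END};` rather than through per-argument kernelParams pointers.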

Update flow comparison

Before (full re-record, ~200us per command buffer):

for each command:
    CreateKernelNode() or UpdateKernelNode()  ← HIP API call per node
Finalize()
Instantiate() / hipGraphExecUpdate()

After (in-place patch, ~165us and improving with optimization):

for each owned kernel node:
    scan arg buffer for changed addresses     ← pure memcpy, no HIP API
    if modified: hipGraphExecKernelNodeSetParams()  ← only for changed nodes
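The "scan arg buffer" step above is a word-by-word pass over each owned buffer. A minimal sketch of that scan (PatchAddresses is illustrative, not the PR's function; the real path would follow a successful patch with hipGraphExecKernelNodeSetParams):

```cpp
#include <cstdint>
#include <cstring>
#include <unordered_map>
#include <vector>

// Scan an owned argument buffer at pointer-aligned offsets; wherever a
// word equals a known old device address, overwrite it with the new one.
// Returns the number of words patched, so the caller knows whether the
// updated buffer must be pushed to the executable graph.
size_t PatchAddresses(
    uint8_t* buf, size_t size,
    const std::unordered_map<uintptr_t, uintptr_t>& old_to_new) {
  size_t patched = 0;
  for (size_t off = 0; off + sizeof(uintptr_t) <= size;
       off += sizeof(uintptr_t)) {
    uintptr_t word;
    std::memcpy(&word, buf + off, sizeof(word));
    auto it = old_to_new.find(word);
    if (it != old_to_new.end()) {
      std::memcpy(buf + off, &it->second, sizeof(uintptr_t));
      ++patched;
    }
  }
  return patched;
}
```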

2. Collective trace cache (fixes COLLECTIVES regression)

CollectiveCmd::RecordTracedCommand was calling TraceCommandBufferFactory::Create on every record — doing a full hipStreamBeginCapture → RCCL ncclAllReduce → hipStreamEndCapture each time. Unlike TracedCommandBufferCmd (used by GEMM, CublasLt, etc.) which caches traced graphs via TracedCommandBuffer, collectives had no cache at all.

Fix: Use the same TracedCommandBuffer cache for CollectiveCmd. The cache stores traced command buffers keyed by buffer addresses. On cache hit, the previously-traced RCCL graph is reused directly. On miss, it traces once and caches.

This is the same caching mechanism that VMM/VA-remapping achieves implicitly (addresses never change → graph never re-traced), but works without requiring VMM support.
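The cache-by-addresses idea can be sketched in a few lines (TraceCache and GraphHandle are hypothetical stand-ins for the real TracedCommandBuffer machinery): the expensive trace runs only on a miss, and identical address sets reuse the cached graph.

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Stand-in for a traced command buffer / executable graph handle.
struct GraphHandle { int id; };

class TraceCache {
 public:
  // Returns the cached graph for these buffer addresses; on a miss,
  // invokes the (expensive) trace function once and caches the result.
  template <typename TraceFn>
  GraphHandle GetOrTrace(const std::vector<uintptr_t>& addrs, TraceFn trace) {
    auto it = cache_.find(addrs);
    if (it != cache_.end()) return it->second;  // hit: reuse traced graph
    GraphHandle g = trace();                    // miss: stream capture
    cache_.emplace(addrs, g);
    return g;
  }

 private:
  std::map<std::vector<uintptr_t>, GraphHandle> cache_;
};
```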

What this PR contains

Core changes

  • rocm_command_buffer.cc / CreateKernelNode: Packs kernel arguments into an owned contiguous buffer using extra-style (HIP_LAUNCH_PARAM_BUFFER_POINTER) instead of kernelParams. Caches the node handle and params in OwnedKernelNode.

  • rocm_command_buffer.cc / UpdateKernelNode: Also uses extra-style for consistency with CreateKernelNode.

  • rocm_command_buffer.cc / UpdateNodeAddresses: New fast-update path. Iterates cached kernel_nodes_, patches owned arg buffers, calls hipGraphExecKernelNodeSetParams per modified node. Also handles memcpy/memset nodes via graph walk.

  • rocm_command_buffer.h: Adds OwnedKernelNode struct (node handle + cached params + owned arg buffer), has_kernelparams_nodes_ safety flag, SupportsNodeAddressUpdate(), UpdateNodeAddresses().

  • command_buffer_thunk.cc: Adds the fast-update decision logic in ExecuteOnStream. When XLA_GPU_GRAPH_FAST_UPDATE=1, attempts UpdateNodeAddresses before falling back to full Record. Manages prev_allocs buffer (mmap-backed) for tracking old addresses.

  • command_buffer_cmd.cc / CollectiveCmd::RecordTracedCommand: Replaced uncached TraceCommandBufferFactory::Create with TracedCommandBuffer cache (same pattern as TracedCommandBufferCmd). Eliminates the 250ms per-record RCCL re-capture.

Supporting changes

  • HIP graph node flattening: FlattenChildGraphNodes, UpdateFlattenedChildNodes, BuildPatchTable, PatchFlattenedNodes — flatten child graphs into the parent for better patching coverage.
  • UpdateKernelNodes: In-place kernel node patching for traced command buffers.
  • DumpGraphKernelNodes: Debug utility for inspecting graph node state.
  • Profiling: XLA_PROFILE_CMD_BUFFER=1 env var enables detailed timing logs.

Profiling results

Hardware: 8x AMD Instinct MI308X (gfx942)
Workload: 8-layer MLP, 4096 input dim, 41.9M params/replica

A. Single-device (jit) — kernel node patching

                      Baseline (FAST_UPDATE=0)   Fast Update (FAST_UPDATE=1)
Median step time      19.55 ms                   19.47 ms
Re-traces per step    ~1 (full Record)           0 (patch only)
Update overhead       220-300 us                 200-260 us

B. Multi-device (pmap 8 GPU) — without COLLECTIVES

                      Baseline (FAST_UPDATE=0)   Fast Update (FAST_UPDATE=1)
Median step time      8.10 ms                    7.95 ms

C. Multi-device (pmap 8 GPU) — with COLLECTIVES (the main fix)

                Before this PR   After this PR
FAST_UPDATE=0   551.70 ms        8.66 ms (63x faster)
FAST_UPDATE=1   547.54 ms        8.22 ms (67x faster)

D. Standalone allreduce (non-power-of-2 elements, graph-captured)

Elements                Before    After     Speedup
1,000                   254 ms    1.63 ms   156x
1,048,575 (2^20 - 1)    250 ms    1.62 ms   154x
1,048,577 (2^20 + 1)    253 ms    1.62 ms   156x
5,000,000               250 ms    3.06 ms   82x

Root cause analysis

The 250ms regression was caused by hipStreamBeginCapture → RCCL → hipStreamEndCapture being repeated on every graph record. For non-power-of-2 allreduce sizes, RCCL's kernel protocol involves internal polling that takes ~250ms during stream capture. With the trace cache, this capture happens only once; subsequent records reuse the cached graph.

Power-of-2 sizes were fast even without caching because RCCL uses a simpler single-pass MSCCL algorithm for those sizes (the captured graph already completes quickly). With the cache, all sizes are uniformly fast.

Correctness

All configurations produce identical final loss values:

Without COLLECTIVES:  loss = 0.834596
With COLLECTIVES:     loss = 0.834596 (both FAST_UPDATE=0 and FAST_UPDATE=1)

How to enable

# Fast kernel node patching (optional, small additional speedup)
export XLA_GPU_GRAPH_FAST_UPDATE=1

# Enable COLLECTIVES in command buffers (now safe and fast)
export XLA_FLAGS="--xla_gpu_graph_min_graph_size=1 --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUBLASLT,CUSTOM_CALL,COLLECTIVES"

# (optional) profiling diagnostics
export XLA_PROFILE_CMD_BUFFER=1

The collective trace cache is always active (no flag needed). The fast-update path (XLA_GPU_GRAPH_FAST_UPDATE=1) automatically falls back to full Record when:

  • First execution (no previous allocations to diff against)
  • SupportsNodeAddressUpdate() returns false
  • UpdateNodeAddresses fails for any reason
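The fallback conditions above amount to a simple try-then-fall-back decision. A minimal sketch under those assumptions (ChooseUpdatePath and its parameters are hypothetical names, not the PR's API):

```cpp
#include <functional>

enum class Path { kFastUpdate, kFullRecord };

// Attempt the in-place patch only when the feature is enabled, a previous
// allocation snapshot exists, and the command buffer supports node-address
// updates; any patch failure falls back to a full Record.
Path ChooseUpdatePath(bool fast_update_enabled, bool has_prev_allocs,
                      bool supports_node_update,
                      const std::function<bool()>& try_patch) {
  if (fast_update_enabled && has_prev_allocs && supports_node_update &&
      try_patch()) {
    return Path::kFastUpdate;
  }
  return Path::kFullRecord;  // first run, unsupported graph, or patch failed
}
```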

Remove two overly-conservative checks from rocm_command_buffer.cc
(ported from phambinh/full_vmm_solution branch):

1. Trace(): Remove rejection of empty traced HIP graphs, which caused
   segfaults when custom calls don't launch GPU ops.
2. PrepareFinalization(): Remove empty node insertion into empty graphs,
   which could crash with conditional nodes.

Baseline for command buffer performance profiling on main branch.
Flatten traced child graph nodes into the parent HIP graph as
individual kernel/memcpy/memset nodes instead of embedding them as
child graph nodes.  This enables per-node parameter updates via
hipGraphExecKernelNodeSetParams (~1 us/node) instead of full child
graph re-tracing (~60-70 us per sub-graph), providing up to 14.6x
faster update throughput.

Key changes:
- Add FlattenChildGraphNodes/UpdateFlattenedChildNodes virtual
  methods to GpuCommandBuffer with ROCm implementation
- Add GpuFlattenedCommand type to track flattened node handles
- Modify RecordTracedCommand to use flattening path when
  xla_gpu_graph_enable_node_flattening flag is set
- Add HIP graph introspection wrappers (hipGraphGetEdges,
  hipGraphNodeGetDependencies, hipGraphKernelNodeGetParams, etc.)
- Move AppendCommand to protected for subclass access
- Add xla_gpu_graph_enable_node_flattening proto flag (field 502)

Activate via: --xla_gpu_graph_enable_node_flattening=true
- Add GpuFlattenedCommand handling in ToGraphNodeDependencies to prevent
  crash when flattened commands are used as dependencies
- Add graceful fallback to child graph path when kernel nodes use HIP's
  opaque `extra` arg-packing (hipGraphKernelNodeGetParams returns
  kernelParams=null, extra=non-null), which can't be re-packed for a
  new graph node
- Register xla_gpu_graph_enable_node_flattening flag in
  debug_options_flags.cc
- Add XLA_GPU_GRAPH_ENABLE_NODE_FLATTENING env var as alternative
  activation mechanism when jaxlib doesn't have the proto flag

Known limitation: ROCm/HIP stores all kernel args internally via the
opaque `extra` mechanism. hipGraphKernelNodeGetParams does not populate
`kernelParams`, so individual kernel node flattening cannot re-pack
args for a new parent graph. All traced commands currently fall back to
the child graph path on ROCm.
…al format

hipGraphKernelNodeGetParams on ROCm 7.2 returns kernel arguments via
the internal `extra` pointer (not as HIP_LAUNCH_PARAM_* arrays and
not as kernelParams) for kernels captured via hipModuleLaunchKernel.
This internal format cannot be used with hipGraphAddKernelNode to
create equivalent nodes in a different graph context.

Changes:
- Detect kernels with opaque extra args (kernelParams=null,
  extra=non-null) and return InternalError to trigger graceful
  fallback to the standard child-graph path
- Improve fallback in RecordTracedCommand to catch any error status
  (not just kUnimplemented) and fall back to CreateChildCommand
- Add VLOG diagnostics for kernel node parameters during flattening

C++ benchmarks confirm the HIP graph node update APIs work correctly
when kernelParams is properly populated (14.6x faster than re-tracing).
The limitation is specifically in how hipGraphKernelNodeGetParams
returns arguments for module-loaded kernels on ROCm 7.2.
Status: Work in progress - functional and correct, performance optimization ongoing

== Problem ==
XLA's TracedCommandBuffer re-traces HIP graphs (via hipStreamBeginCapture/
hipStreamEndCapture) every time buffer addresses change due to BFC allocator
reassignment. This causes severe performance degradation especially with
command buffers enabled for GEMM/CublasLt operations.

== Approach ==
Instead of re-tracing, patch kernel nodes directly in the cached child graph
using hipGraphKernelNodeSetParams. This avoids the expensive stream capture
entirely when only buffer addresses change.

== How it works ==
1. TracedCommandBuffer now tracks slice-level addresses (not just allocation
   indices) per cache entry
2. On cache miss, instead of re-tracing via TraceCommandBufferFactory::Create,
   we call UpdateKernelNodes() on the existing cached graph
3. UpdateKernelNodes scans each kernel node's arguments:
   - For 'extra' packed buffers (rocBLAS/hipBLASLt): decode HIP_LAUNCH_PARAM
     buffer, scan at pointer-aligned offsets for old addresses, replace with new
   - For kernelParams arrays: scan pointer values directly
4. hipGraphKernelNodeSetParams commits changes to the graph definition
5. The existing UpdateChildCommand path propagates to the executable graph
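Step 3's decoding of the extra array can be sketched as follows. This is an illustrative sketch only: the sentinel constants mirror HIP's HIP_LAUNCH_PARAM_BUFFER_POINTER / _BUFFER_SIZE / _END macros (small integer sentinels on HIP/CUDA-style runtimes), and DecodeExtra is a hypothetical helper, not the PR's code.

```cpp
#include <cstddef>

// Stand-ins for HIP's launch-parameter sentinels (assumed values).
static void* const kBufferPointer = reinterpret_cast<void*>(0x01);
static void* const kBufferSize    = reinterpret_cast<void*>(0x02);
static void* const kEnd           = reinterpret_cast<void*>(0x03);

// Walk the key/value pairs of an extra[] array until the end sentinel,
// extracting the packed argument buffer and its size so the caller can
// scan it for device addresses.
bool DecodeExtra(void** extra, void** out_buf, size_t* out_size) {
  *out_buf = nullptr;
  *out_size = 0;
  for (int i = 0; extra[i] != kEnd; i += 2) {
    if (extra[i] == kBufferPointer) {
      *out_buf = extra[i + 1];
    } else if (extra[i] == kBufferSize) {
      *out_size = *static_cast<size_t*>(extra[i + 1]);
    }
  }
  return *out_buf != nullptr;
}
```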

== Enable ==
  export XLA_GPU_GRAPH_SKIP_RETRACE=1

== Results (Llama FSDP 8-layer, 8x MI308X) ==

Without command buffers (CB off):         41.6 ms/step
With CB (no collectives), baseline:       97.4 ms/step (after warmup)
With CB (no collectives), skip-retrace:   97.2 ms/step (after warmup)
  -> During warmup: 896ms baseline vs 864ms skip-retrace (3.5% improvement)
  -> 728 traces reduced to only initial creates, 1276 patches avoided retraces

With CB + collectives, baseline:          53,862 ms/step
With CB + collectives, skip-retrace:      52,810 ms/step (~2% improvement)
  -> Collectives mode now runs correctly (previously problematic)
  -> Remaining cost is from collective command re-tracing (not GEMMs)

Correctness: verified element-wise gradient match across all configurations.
Loss values stable and matching baseline throughout.

== Files changed ==
- command_buffer_cmd.cc/h: TracedCommandBuffer tracks buffer_slices_ and
  recorded_slice_addrs; GetOrTraceCommandBuffer attempts UpdateKernelNodes
  before falling back to retrace
- gpu_command_buffer.h: Added UpdateKernelNodes virtual method and
  DumpGraphKernelNodes for debugging
- rocm_command_buffer.cc/h: Implemented UpdateKernelNodes (scans extra/
  kernelParams, patches addresses, calls hipGraphKernelNodeSetParams)
- rocm_driver_wrapper.h: Added hipGraphKernelNodeSetParams wrapper

== Next steps ==
- Extend UpdateKernelNodes approach to collective command types to reduce
  the 53s/step cost with collectives enabled
- Investigate VA remapping as alternative to avoid address changes entirely
- Profile to identify which specific retrace operations dominate with collectives
@phambinhfin added the claude-review label (Request a Claude AI code review for this PR) Apr 7, 2026
#include "xla/status_macros.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"

nit: Include order — gpu_command_buffer.h is inserted between device_address.h and device_address_handle.h, breaking alphabetical order within the xla/stream_executor/ group. It should come after device_address_handle.h (or be grouped with the gpu/ headers at line 106).

Suggested change:

- #include "xla/stream_executor/gpu/gpu_command_buffer.h"
  #include "xla/stream_executor/device_address_handle.h"
+ #include "xla/stream_executor/gpu/gpu_command_buffer.h"

Comment on lines +283 to +291
static const bool skip_retrace = [] {
  const char* env = std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE");
  bool val = env != nullptr && std::string(env) == "1";
  if (val) {
    LOG(INFO) << "XLA_GPU_GRAPH_SKIP_RETRACE enabled: kernel node "
                 "params will be patched in-place to avoid re-tracing";
  }
  return val;
}();

The skip-retrace feature is gated via raw std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE"), while the flattening feature in the same PR uses a proper DebugOptions proto flag (xla_gpu_graph_enable_node_flattening). Using two different configuration mechanisms within the same feature set is inconsistent and makes the feature harder to discover/manage. Consider using DebugOptions for both, keeping them aligned with XLA's standard configuration pattern.

Comment on lines +338 to +352
  }
} else if (kp.kernelParams != nullptr) {
  for (int a = 0; a < 64 && kp.kernelParams[a] != nullptr; ++a) {
    uintptr_t val;
    memcpy(&val, kp.kernelParams[a], sizeof(uintptr_t));
    auto it = old_addr_map.find(val);
    if (it != old_addr_map.end()) {
      uintptr_t new_val = reinterpret_cast<uintptr_t>(
          new_addresses[it->second].opaque());
      memcpy(kp.kernelParams[a], &new_val, sizeof(uintptr_t));
      modified = true;
    }
  }
  if (modified) {
    TF_RETURN_IF_ERROR(ToStatus(

Correctness risk: The kernelParams patching reads sizeof(uintptr_t) bytes from each kernelParams[a], which works for pointer arguments but will misinterpret scalar arguments (int, float, etc.) as addresses. If a scalar value happens to have the same bit pattern as an old buffer address, it will be silently overwritten with a new address, causing data corruption.

The same risk exists in the extra/packed-buffer code path above (scanning at pointer-aligned offsets and matching raw values), though it's somewhat mitigated there by scanning a known packed buffer.

Additionally, the hardcoded limit of 64 here is inconsistent with the limit of 16 in DumpGraphKernelNodes at line 399 — a kernel with 17–63 params would be correctly patched but incompletely dumped, making debugging misleading. Consider using a named constant for both.
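The false-match risk described here is easy to demonstrate in isolation. The sketch below (NaivePatch is a hypothetical reduction of the scan-and-replace logic, not the PR's code) shows a scalar argument being silently rewritten because its bit pattern equals an old buffer address:

```cpp
#include <cstdint>
#include <cstring>

// Naive scan-and-replace over a packed arg buffer at pointer-aligned
// offsets. It has no type information, so it cannot distinguish a pointer
// argument from a scalar that happens to hold the same bits.
void NaivePatch(uint8_t* buf, size_t size, uintptr_t old_v, uintptr_t new_v) {
  for (size_t off = 0; off + sizeof(uintptr_t) <= size;
       off += sizeof(uintptr_t)) {
    uintptr_t w;
    std::memcpy(&w, buf + off, sizeof(w));
    if (w == old_v) std::memcpy(buf + off, &new_v, sizeof(new_v));
  }
}
```

Running this over a buffer whose first word is a genuine device pointer and whose second word is a scalar with the same value patches both words, which is exactly the silent corruption the comment warns about.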

Comment on lines +841 to +843
auto* flat_cmd = const_cast<GpuFlattenedCommand*>(
    dynamic_cast<const GpuFlattenedCommand*>(command));
if (!flat_cmd) {

const_cast on a const Command* parameter is a code smell. Both BuildPatchTable and PatchFlattenedNodes take const Command* but mutate the object through const_cast. If mutation is intended, the virtual method signatures in gpu_command_buffer.h should take non-const Command*, rather than casting away const (which is UB if the object was originally declared const).

Comment on lines +514 to +516
auto* gpu_cmd = static_cast<const GpuCommand*>(d);
dep_handles.push_back(gpu_cmd->handle);
}

Unsafe cast: static_cast<const GpuCommand*>(d) is incorrect when d could be a GpuFlattenedCommand, GpuChildCommand, GpuCaseCommand, or GpuWhileCommand — none of which inherit from GpuCommand. This is undefined behavior. The existing ToGraphNodeDependencies helper (gpu_command_buffer.cc:108-128) handles this correctly with dynamic_cast chains. Consider reusing that helper or at minimum using dynamic_cast with a null check here.

Comment on lines +670 to +672
continue;
}


Silently skipping unsupported node types with continue can lead to incorrect dependency resolution. The flattened command's node_handles/node_infos will have fewer entries than the child graph, and if a downstream node depends on the skipped node, its dependencies won't be resolved in child_to_parent, causing it to fall back to dep_handles (external deps) instead of the correct internal predecessor. This could produce incorrect execution ordering or race conditions.

Consider either returning an error for unsupported types, or inserting an empty node to preserve the dependency chain.

Comment on lines +148 to +162
// Deep-copied extra arg buffers for kernel nodes that use the
// HIP_LAUNCH_PARAM packed-buffer launch convention (e.g. rocBLAS).
// Kept alive for the lifetime of the parent graph.
std::vector<std::unique_ptr<uint8_t[]>> extra_arg_buffers;
std::vector<std::unique_ptr<size_t>> extra_arg_sizes;

// Patch table for skip-retrace updates: records where each known
// buffer address appears inside the packed args of flattened
// extra-style kernel nodes.
std::vector<ArgPatchEntry> patch_table;

// Per-node deep-copied arg buffer and its size (only for extra-style
// kernel nodes that have patch entries).
struct NodeArgBuffer {
std::unique_ptr<uint8_t[]> data;

extra_arg_buffers (line 155) appears to be unused dead code — it is declared but never written to anywhere in this PR. Only node_arg_buffers (line 164) is populated in FlattenChildGraphNodes. Consider removing extra_arg_buffers if it's leftover from an earlier design iteration.

Comment on lines +555 to +557
// allows per-node address updates via hipGraphExecKernelNodeSetParams
// instead of full child graph re-tracing on buffer address changes.
optional bool xla_gpu_graph_enable_node_flattening = 502;

This flag is registered in debug_options_flags.cc but is never actually read anywhere in the codebase — the flattening methods (FlattenChildGraphNodes, UpdateFlattenedChildNodes, BuildPatchTable, PatchFlattenedNodes) are defined but never called from any code path. This makes the entire flattening infrastructure dead code in the current state of the PR. Is this intentional for a WIP, with the caller integration coming in a follow-up?

Comment on lines -403 to -414
size_t num_root_nodes = 0;
TF_RETURN_IF_ERROR(
ToStatus(wrap::hipGraphGetRootNodes(graph_, nullptr, &num_root_nodes),
"Failed to get HIP graph root node count"));

if (num_root_nodes == 0) {
return absl::InternalError(
"Traced HIP graph is empty. Traced function (custom call) did not "
"launch any HIP operations on the captured HIP stream. Instantiating "
"empty child nodes leads to crashes.");
}


This removes the safety check that guarded against empty traced graphs. The original code explicitly warned: "Instantiating empty child nodes leads to crashes." The corresponding PrepareFinalization empty-node insertion is also removed below. Could you explain why these removals are safe? If the underlying HIP issue has been fixed, it would be good to note that in a comment.

claude bot commented Apr 8, 2026

Review Summary

This WIP PR adds a skip-retrace optimization for HIP graph command buffers that patches kernel node arguments in-place when buffer addresses change, plus a node-flattening mechanism to extract child graph nodes into the parent graph for individual updates. ~700 lines of new ROCm-specific code.

Key concerns (see inline comments):

  • Correctness risk in kernel argument patching: Scanning kernel params by raw pointer-sized value can false-match scalar arguments (int/float) that happen to share a bit pattern with a buffer address, causing silent data corruption.
  • Unsafe static_cast in FlattenChildGraphNodes: Dependencies are static_cast to GpuCommand* but may be other command types — this is undefined behavior.
  • Removed empty-graph safety checks: The original code explicitly guarded against empty traced graphs to prevent crashes; these guards are removed without explanation.
  • const_cast usage in BuildPatchTable/PatchFlattenedNodes: Methods take const Command* but mutate through const_cast — signatures should be non-const if mutation is intended.
  • Dead code: The entire flattening infrastructure (FlattenChildGraphNodes, UpdateFlattenedChildNodes, BuildPatchTable, PatchFlattenedNodes) is defined but never called. The extra_arg_buffers field is declared but never written to.
  • Inconsistent configuration: Skip-retrace uses std::getenv while flattening uses DebugOptions — these should use the same mechanism.
  • No tests for ~700 lines of new pointer-manipulation code with correctness risks.

🤖 Generated with Claude Code

@github-actions bot removed the claude-review label Apr 8, 2026
Previously, UpdateKernelNodes only patched hipGraphNodeTypeKernel nodes.
This extends the patching to also handle:
- hipGraphNodeTypeMemcpy: patch srcPtr.ptr and dstPtr.ptr in hipMemcpy3DParms
- hipGraphNodeTypeMemset: patch dst pointer in hipMemsetParams

This is needed for traced sub-graphs that contain memory transfer operations
alongside kernel launches (e.g., rocBLAS workspace copies, DNN scratch
buffer fills). Without this, the skip-retrace mechanism would return a
graph with stale memcpy/memset addresses, causing incorrect results or
crashes.

Added HIP API wrappers:
- hipGraphMemcpyNodeSetParams
- hipGraphMemsetNodeSetParams

Verified correct loss values on Llama FSDP benchmark (8xMI308X):
- 7 benchmark steps completed, all loss values 10.7748-10.7753
- Step times: 657-754ms (consistent with baseline)
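The memcpy/memset extension above can be reduced to a small sketch. MemcpyParams below is a hypothetical stand-in for hipMemcpy3DParms (only the fields being patched); in the real path, a modified struct would be committed via hipGraphMemcpyNodeSetParams.

```cpp
#include <cstdint>

// Minimal stand-in for the memcpy node parameter struct.
struct MemcpyParams {
  void* src;
  void* dst;
};

// Replace a stale device address in the node's src/dst pointers.
// Returns true if anything changed, so the caller knows whether the
// updated params must be pushed back to the graph node.
bool PatchMemcpy(MemcpyParams& p, void* old_addr, void* new_addr) {
  bool modified = false;
  if (p.src == old_addr) { p.src = new_addr; modified = true; }
  if (p.dst == old_addr) { p.dst = new_addr; modified = true; }
  return modified;
}
```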
… patching

Switch ROCm CreateKernelNode and UpdateKernelNode from kernelParams-style
to extra-style (HIP_LAUNCH_PARAM_BUFFER_POINTER) argument passing. This
avoids the HIP bug (ROCm/clr#138) where hipGraphKernelNodeGetParams
returns dangling kernelParams pointers, and enables the UpdateNodeAddresses
fast-update path to bypass full command buffer re-recording when only
buffer addresses change (BFC allocator recapture).

Key changes:
- CreateKernelNode packs arguments into an owned contiguous buffer and
  uses extra-style, caching the node handle and params for fast patching.
- UpdateKernelNode also uses extra-style for consistency.
- UpdateNodeAddresses patches owned arg buffers directly and calls
  hipGraphExecKernelNodeSetParams per modified node, avoiding the
  expensive hipGraphExecUpdate full-graph sync.
- has_kernelparams_nodes_ flag now stays false for XLA-recorded kernels,
  allowing SupportsNodeAddressUpdate() to return true.
- Added profiling instrumentation behind XLA_PROFILE_CMD_BUFFER env var.
Resolved ToString() -> ToString(0) signature change from upstream.
Kept profiling instrumentation behind XLA_PROFILE_CMD_BUFFER env var.
@phambinhfin changed the title from "WIP: Skip-retrace for HIP graph command buffers - in-place kernel node patching" to "Skip command buffer re-tracing via in-place HIP graph node patching" Apr 9, 2026
@phambinhfin added the claude-review label Apr 9, 2026
#include "xla/backends/gpu/runtime/command_buffer_thunk.h"

#include <algorithm>
#include <sys/mman.h>

<sys/mman.h> is a POSIX system header and should be placed after the C++ standard library headers (or in a separate group), per Google C++ style. It also breaks alphabetical order within the <c*> group. Additionally, this is a Linux-only header -- using mmap makes this code non-portable to Windows (relevant if XLA ever supports non-ROCm/non-CUDA targets). Consider using aligned_alloc or a platform-abstracted allocator instead.

Comment on lines +417 to +434
if (!cmd_buffer->prev_allocs) {
  size_t alloc_bytes =
      ((n * sizeof(se::DeviceAddressBase)) + 4095) & ~4095uL;
  void* p = mmap(nullptr, alloc_bytes, PROT_READ | PROT_WRITE,
                 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (p != MAP_FAILED) {
    cmd_buffer->prev_allocs = static_cast<se::DeviceAddressBase*>(p);
    cmd_buffer->prev_allocs_capacity =
        alloc_bytes / sizeof(se::DeviceAddressBase);
  }
}
if (cmd_buffer->prev_allocs) {
  size_t copy_n = std::min(n, cmd_buffer->prev_allocs_capacity);
  memcpy(cmd_buffer->prev_allocs, cmd_buffer->recorded_allocs.data(),
         copy_n * sizeof(se::DeviceAddressBase));
  cmd_buffer->prev_allocs_size = copy_n;
}

Memory leak: prev_allocs is allocated via mmap but never freed. ExecutorCommandBuffer has no destructor, so this memory is leaked when the object is destroyed. Either add a destructor that calls munmap(prev_allocs, prev_allocs_capacity * sizeof(se::DeviceAddressBase)), or use a simpler allocation strategy (e.g., std::vector or std::unique_ptr<se::DeviceAddressBase[]>) that cleans up automatically.

The comment says mmap is used "to avoid system allocator (heap) interference with the HIP graph runtime," but this rationale should be explained more concretely -- what specific interference occurs? If it is a real concern, the leak still needs to be fixed.

bool did_fast_update = false;

if (can_fast_update && cmd_buffer->prev_allocs_size > 0) {
auto* gpu_cmd_buf = static_cast<se::gpu::GpuCommandBuffer*>(

Unsafe downcast: static_cast<se::gpu::GpuCommandBuffer*> is used instead of dynamic_cast. If command_buffer is not actually a GpuCommandBuffer (e.g., on a different platform or future backend), this is undefined behavior. The null check on line 323 (gpu_cmd_buf &&) only guards against a null command_buffer.get(), not against an incorrect type. The same issue appears at line 412. Use dynamic_cast with a null check for safety.

Comment on lines 253 to +300
@@ -277,49 +284,109 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) {
return absl::OkStatus();
}

static const bool fast_update = [] {
  const char* env = std::getenv("XLA_GPU_GRAPH_FAST_UPDATE");
  return env != nullptr && std::string(env) == "1";
}();

uint64_t t_alloc_start = 0, t_alloc_end = 0;
uint64_t t_record_start = 0, t_record_end = 0;
uint64_t t_submit_start = 0, t_submit_end = 0;

if (profile_steps) t_alloc_start = tsl::Env::Default()->NowMicros();

auto updated_allocs = cmd_buffer->UpdateBufferAllocations(commands_, params);

// Determine whether to (re-)record the command buffer and whether this is a
// first-time initialization recording (VA remapping path).
if (profile_steps) t_alloc_end = tsl::Env::Default()->NowMicros();

Three new std::getenv-based feature flags are introduced in this file (XLA_PROFILE_CMD_BUFFER, XLA_GPU_GRAPH_FAST_UPDATE), adding to the XLA_GPU_GRAPH_SKIP_RETRACE in command_buffer_cmd.cc. These should use XLA's DebugOptions proto mechanism for consistency with the rest of the codebase. Environment variables are harder to discover, undocumented, and cannot be set per-compilation (only per-process). The static const bool + std::getenv pattern also means the value is read once at first invocation and cannot be changed, unlike DebugOptions which can be set per-HLO module.

Comment on lines +440 to +451
LOG(WARNING) << "CmdBufProfile dev=" << dev
             << " alloc_check=" << (t_alloc_end - t_alloc_start) << "us"
             << " record=" << (t_record_end - t_record_start) << "us"
             << " submit=" << (t_submit_end - t_submit_start) << "us"
             << " total=" << (t_submit_end - t_alloc_start) << "us"
             << " updated=" << needs_update
             << " fast_update=" << fast_update
             << " can_fast=" << can_fast_update
             << " prev_sz=" << cmd_buffer->prev_allocs_size
             << " num_cmds=" << commands_.size()
             << " num_allocs_changed=" << num_allocs_changed;
}

Profiling output uses LOG(WARNING) for purely informational/diagnostic data. This will pollute warning logs in production even when intentionally enabled. Consider using LOG(INFO) or VLOG(1) instead. The same pattern appears in command_buffer_cmd.cc (lines 306, 331, 347) and earlier in this file (lines 324, 344, 358, 378).

Comment on lines +397 to +404
"Failed to set memset node params after patching"));
}
}
}

return absl::OkStatus();
}


The PatchGraphNodes helper function (used by UpdateNodeAddresses) lacks a recursion depth limit. It recurses into child graphs via hipGraphNodeTypeGraph nodes. If the graph structure were cyclic (which shouldn't happen but could due to a HIP bug), or very deeply nested, this could cause a stack overflow. Consider adding a maximum depth guard (e.g., if (depth > 32) return InternalError(...)).

Comment on lines +536 to +560

bool RocmCommandBuffer::SupportsNodeAddressUpdate() const {
return exec_ != nullptr && graph_ != nullptr && !has_kernelparams_nodes_;
}

absl::StatusOr<bool> RocmCommandBuffer::UpdateNodeAddresses(
absl::Span<const DeviceAddressBase> old_addresses,
absl::Span<const DeviceAddressBase> new_addresses) {
if (exec_ == nullptr || graph_ == nullptr) {
return false;
}

if (has_kernelparams_nodes_) return false;

size_t common_size = std::min(old_addresses.size(), new_addresses.size());
absl::flat_hash_map<uintptr_t, size_t> old_addr_map;
for (size_t i = 0; i < common_size; ++i) {
auto old_val = reinterpret_cast<uintptr_t>(old_addresses[i].opaque());
auto new_val = reinterpret_cast<uintptr_t>(new_addresses[i].opaque());
if (old_val != 0 && old_val != new_val) {
old_addr_map[old_val] = i;
}
}

if (old_addr_map.empty()) return true;

In UpdateNodeAddresses, memcpy/memset nodes are patched via the graph definition (hipGraphMemcpyNodeSetParams) and then pushed to the executable graph (hipGraphExecMemcpyNodeSetParams1D). However, the definition-graph update uses hipGraphMemcpyNodeSetParams with the full 3D params, while the exec update uses hipGraphExecMemcpyNodeSetParams1D which assumes 1D. If the original memcpy was actually a 2D/3D copy, the exec-level update will be incorrect -- it flattens height/depth into a single size, changing the memory layout.

Additionally, hipMemcpyDeviceToDevice is hardcoded on line 548, but the original copy kind could be host-to-device, device-to-host, etc.

Comment on lines +1298 to +1310
// Pack all argument values into a contiguous buffer and use extra-style
// launch (HIP_LAUNCH_PARAM_BUFFER_POINTER) instead of kernelParams.
// This avoids the HIP bug (ROCm/clr#138) where hipGraphKernelNodeGetParams
// returns dangling kernelParams pointers, and enables UpdateNodeAddresses.
auto arg_addrs = args.argument_addresses();
size_t num_args = arg_addrs.size();

// Each argument is a pointer-sized value; pack them contiguously.
size_t buf_size = num_args * sizeof(void*);
auto packed_buf = std::make_unique<uint8_t[]>(buf_size);
for (size_t i = 0; i < num_args; ++i) {
memcpy(packed_buf.get() + i * sizeof(void*), arg_addrs[i], sizeof(void*));
}

The CreateKernelNode change packs all kernel arguments as pointer-sized values (sizeof(void*) each), but not all kernel arguments are necessarily pointer-sized. For example, a kernel taking an int32_t argument would have sizeof(int32_t) == 4 bytes, but this code packs it as 8 bytes (on 64-bit), adding 4 bytes of uninitialized padding. The comment on line 1305 acknowledges the assumption ("Each argument is a pointer-sized value") but this is not guaranteed by the KernelArgs API.

The same assumption is made in UpdateKernelNode below. If XLA's kernel argument packing does indeed guarantee pointer-sized values, this should be documented/asserted at the KernelArgs level rather than assumed here.

#include "xla/backends/gpu/runtime/command_buffer_thunk.h"

#include <algorithm>
#include <sys/mman.h>

issue (portability): <sys/mman.h> is a POSIX-specific header placed between C++ standard library headers, violating Google C++ include ordering (C system headers should be grouped separately). More importantly, this is platform-generic code shared across backends — this include and the mmap call at line 420 will fail to compile on Windows. Either guard with #ifdef __linux__ / platform ifdefs, or use a portable allocator (std::aligned_alloc, operator new).

Comment on lines +417 to +434
if (!cmd_buffer->prev_allocs) {
size_t alloc_bytes =
((n * sizeof(se::DeviceAddressBase)) + 4095) & ~4095uL;
void* p = mmap(nullptr, alloc_bytes, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (p != MAP_FAILED) {
cmd_buffer->prev_allocs = static_cast<se::DeviceAddressBase*>(p);
cmd_buffer->prev_allocs_capacity =
alloc_bytes / sizeof(se::DeviceAddressBase);
}
}
if (cmd_buffer->prev_allocs) {
size_t copy_n = std::min(n, cmd_buffer->prev_allocs_capacity);
memcpy(cmd_buffer->prev_allocs, cmd_buffer->recorded_allocs.data(),
copy_n * sizeof(se::DeviceAddressBase));
cmd_buffer->prev_allocs_size = copy_n;
}
}

bug (memory leak): prev_allocs is allocated via mmap(MAP_PRIVATE | MAP_ANONYMOUS) but is never freed with munmap. When the ExecutorCommandBuffer is destroyed, this memory leaks. The comment says this avoids "system allocator (heap) interference with the HIP graph runtime" but provides no justification for why std::vector or std::unique_ptr<T[]> would interfere. The mmap approach also skips the RAII pattern expected by XLA conventions. If mmap is truly required, the ExecutorCommandBuffer destructor should call munmap.

bool did_fast_update = false;

if (can_fast_update && cmd_buffer->prev_allocs_size > 0) {
auto* gpu_cmd_buf = static_cast<se::gpu::GpuCommandBuffer*>(

nit (safety): static_cast<se::gpu::GpuCommandBuffer*> is undefined behavior if the actual runtime type is not GpuCommandBuffer. In command_buffer_cmd.cc (same PR), you use dynamic_cast for the same downcast (line 311). This should be dynamic_cast with a null check for consistency and safety — especially since the supports check on the next line already handles the null case.

Comment on lines 250 to 256
absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) {
static const bool profile_steps = [] {
const char* env = std::getenv("XLA_PROFILE_CMD_BUFFER");
return env != nullptr && std::string(env) == "1";
}();

// We might end up with empty command sequence if all of the captured fusions

suggestion (consistency): This PR introduces three std::getenv-based feature flags (XLA_PROFILE_CMD_BUFFER, XLA_GPU_GRAPH_FAST_UPDATE here, plus XLA_GPU_GRAPH_SKIP_RETRACE in command_buffer_cmd.cc), while the flattening feature uses DebugOptions (xla_gpu_graph_enable_node_flattening). These should all use DebugOptions for consistency — env vars are harder to discover, not documented in the proto, and bypass XLA's standard configuration surface.

Comment on lines +440 to +451
LOG(WARNING) << "CmdBufProfile dev=" << dev
<< " alloc_check=" << (t_alloc_end - t_alloc_start) << "us"
<< " record=" << (t_record_end - t_record_start) << "us"
<< " submit=" << (t_submit_end - t_submit_start) << "us"
<< " total=" << (t_submit_end - t_alloc_start) << "us"
<< " updated=" << needs_update
<< " fast_update=" << fast_update
<< " can_fast=" << can_fast_update
<< " prev_sz=" << cmd_buffer->prev_allocs_size
<< " num_cmds=" << commands_.size()
<< " num_allocs_changed=" << num_allocs_changed;
}

nit: All profiling messages use LOG(WARNING), which will pollute warning-level logs in production when XLA_PROFILE_CMD_BUFFER=1 is set. This can trigger warning-level log monitoring/alerting. Profiling output should use VLOG(1) or LOG(INFO) instead.

Comment on lines +397 to +404
"Failed to set memset node params after patching"));
}
}
}

return absl::OkStatus();
}


issue (robustness): PatchGraphNodes recurses into child graphs via hipGraphChildGraphNodeGetGraph but has no depth limit. A graph with deeply nested children (or a cycle due to HIP API bug) would cause a stack overflow. Consider adding a max-depth guard (e.g., if (depth > 32) return InternalError(...)).

Comment on lines +536 to +560

bool RocmCommandBuffer::SupportsNodeAddressUpdate() const {
return exec_ != nullptr && graph_ != nullptr && !has_kernelparams_nodes_;
}

absl::StatusOr<bool> RocmCommandBuffer::UpdateNodeAddresses(
absl::Span<const DeviceAddressBase> old_addresses,
absl::Span<const DeviceAddressBase> new_addresses) {
if (exec_ == nullptr || graph_ == nullptr) {
return false;
}

if (has_kernelparams_nodes_) return false;

size_t common_size = std::min(old_addresses.size(), new_addresses.size());
absl::flat_hash_map<uintptr_t, size_t> old_addr_map;
for (size_t i = 0; i < common_size; ++i) {
auto old_val = reinterpret_cast<uintptr_t>(old_addresses[i].opaque());
auto new_val = reinterpret_cast<uintptr_t>(new_addresses[i].opaque());
if (old_val != 0 && old_val != new_val) {
old_addr_map[old_val] = i;
}
}

if (old_addr_map.empty()) return true;

bug (correctness): The memcpy size computation width * height * depth doesn't account for element size or pitch, and the semantics of hipMemcpy3DParms.extent.width differ between 1D and 2D/3D copies. Additionally, hipMemcpyDeviceToDevice is always hardcoded via hipGraphExecMemcpyNodeSetParams1D, which ignores the original copy direction (the source memcpy could be D2H or H2D). This means host-to-device or device-to-host memcpy nodes will silently produce incorrect behavior after patching.

Comment on lines +1298 to +1310
// Pack all argument values into a contiguous buffer and use extra-style
// launch (HIP_LAUNCH_PARAM_BUFFER_POINTER) instead of kernelParams.
// This avoids the HIP bug (ROCm/clr#138) where hipGraphKernelNodeGetParams
// returns dangling kernelParams pointers, and enables UpdateNodeAddresses.
auto arg_addrs = args.argument_addresses();
size_t num_args = arg_addrs.size();

// Each argument is a pointer-sized value; pack them contiguously.
size_t buf_size = num_args * sizeof(void*);
auto packed_buf = std::make_unique<uint8_t[]>(buf_size);
for (size_t i = 0; i < num_args; ++i) {
memcpy(packed_buf.get() + i * sizeof(void*), arg_addrs[i], sizeof(void*));
}

issue (correctness assumption): The packing logic assumes every kernel argument is exactly sizeof(void*) bytes, copying sizeof(void*) from each argument_addresses() element. However, KernelArgsPackedArrayBase::argument_addresses() returns pointers to argument data of potentially different sizes (scalars, structs). If any argument is smaller than sizeof(void*), this reads beyond its storage; if larger, it truncates. While XLA-generated GPU kernels likely only pass device pointers, this assumption is undocumented and could break with library kernels. Consider adding a DCHECK or static assertion that all argument sizes equal sizeof(void*).

@claude

claude bot commented Apr 9, 2026

Re-review Summary

Reviewed the updated diff. The 9 previous inline findings remain unaddressed; no new comments were needed on those threads.

8 new findings posted inline, primarily in command_buffer_thunk.cc (5) and rocm_command_buffer.cc (3):

  • 1 bug: mmap-allocated prev_allocs is never freed (munmap) — memory leak on every ExecutorCommandBuffer destruction
  • 1 correctness bug: UpdateNodeAddresses hardcodes hipMemcpyDeviceToDevice and computes copy size incorrectly for 2D/3D memcpy nodes
  • 2 correctness concerns: kernel arg packing assumes all args are pointer-sized; recursive PatchGraphNodes has no depth limit
  • 1 portability issue: <sys/mman.h> / mmap in platform-generic code
  • 1 safety nit: static_cast downcast should be dynamic_cast (matches pattern used elsewhere in same PR)
  • 1 consistency suggestion: mixed env vars vs DebugOptions for feature flags
  • 1 logging nit: profiling uses LOG(WARNING) instead of VLOG/LOG(INFO)

Previous findings (empty-graph safety check removal, const_cast UB, dead flattening code, kernelParams limit inconsistency, etc.) are still open.

@github-actions bot removed the claude-review label Apr 9, 2026
@phambinhfin added the claude-review label Apr 9, 2026
@claude

claude bot commented Apr 9, 2026

Re-review Summary

Reviewed latest changes including commit 711e998 ("Add trace cache for CollectiveCmd to fix COLLECTIVES graph capture regression"). The CollectiveCmd trace cache correctly applies the TracedCommandBuffer pattern and introduces no new issues.

All 25 previously posted inline findings remain unaddressed — no code changes or replies have been made to resolve them. The key concerns are unchanged:

  • Correctness: Kernel argument packing assumes pointer-sized args (UB for scalar args); false-positive address patching via raw pointer scanning; memcpy node patching loses stride/direction info
  • Safety: static_cast downcasts where dynamic_cast is needed; empty-graph safety checks removed without replacement; const_cast to mutate through const params
  • Portability/Quality: <sys/mman.h> / mmap in platform-generic code with no munmap (memory leak); std::getenv feature flags instead of DebugOptions; LOG(WARNING) for profiling output
  • Dead code: ~600 lines of unreachable flattening infrastructure and unused proto flag
  • No tests for any of the new pointer-manipulation or graph-patching code

No new inline comments posted — all issues were already covered in previous review rounds.

🤖 Generated with Claude Code

@github-actions bot removed the claude-review label Apr 9, 2026
When COLLECTIVES are graph-captured, RCCL operations produce child graph
nodes (hipGraphNodeTypeGraph). Previously UpdateNodeAddresses only
patched kernel/memcpy/memset nodes in the parent graph, skipping child
graphs entirely.

This extends the fast-update path to:
1. Detect child graph nodes in the parent graph walk
2. Use PatchGraphNodes to recursively patch kernel/memcpy/memset nodes
   inside the child graph, including kernelParams-style nodes from
   stream capture (RCCL collectives)
3. Push the patched child graph to the exec graph via
   hipGraphExecChildGraphNodeSetParams

Also extends PatchGraphNodes to handle kernelParams-style kernel nodes
(not just extra-style), since stream-captured RCCL kernels use
kernelParams where HIP owns the parameter storage.
@phambinhfin force-pushed the phambinh/skip-retrace-kernel-node-patching branch from 79728ba to 0182f04 on April 10, 2026, 12:54