Enable AllGather Triton Backend #799

Draft
mfrancepillois wants to merge 10 commits into ci_maxime_allreduce_triton_rocm_elementwise_rocm from ci_rocm_maxime_allgather_triton

@mfrancepillois

This PR enables the AllGather Triton backend:

  • Enhances the common collective emitter to handle the AllGather op (which returns a tuple).
  • Adds two Triton kernel implementations: a default one and one that uses swizzling when loading data.
  • Adds e2e tests.
    (This support requires the triton-xla atomic operations to be implemented, which is why the PR is based on top of the branch ci_maxime_allreduce_triton_rocm_elementwise_rocm.)

@mfrancepillois mfrancepillois added the claude-review Request a Claude AI code review for this PR label Apr 13, 2026
@claude

claude bot commented Apr 13, 2026

Review Summary

This PR extends the collective emitter infrastructure (originally built for AllReduce) to support AllGather via the Triton backend. It adds two kernel implementations (default and swizzled), tuple unpacking for AllGatherStart's (input, output) shape, a new KernelArguments::Create overload, and comprehensive e2e tests. The feature is gated behind xla_gpu_unsupported_use_all_gather_triton_backend (default: false).

Key issues found:

  • Bug — dtype check bypassed for AllGatherStart: all_gather.shape().element_type() returns TUPLE for AllGatherStart, so the F8/S4 unsupported-type guard never fires. Should use the operand's element type instead.
  • Bug — inconsistent kMaxBlocksPerGrid constants: Launch dimensions use 32, but signal buffer shape uses 24. Should be a single shared constant.
  • Correctness — GetTupleElement identity indexing map: The new GTE case in ComputeOutputToInputIndexing maps output shape to the tuple operand, which is semantically incorrect for general GTE operations. Needs a guard or scoping.
  • Dead code with latent bugs: EmitAllGatherSwizzled hardcodes gather dim=0 and has a potential division-by-zero. Both will break if the swizzled path is enabled.
  • Logging noise: ~24 new LOG(INFO) calls on common code paths (including non-collective Triton fusions). Should be VLOG(n).

Details in inline comments.
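
The first issue above (the dtype check seeing TUPLE) can be illustrated with a small self-contained sketch. It does not use the real XLA classes: `ElementType`, `AllGatherStartModel`, `IsUnsupported`, `RejectByResultType`, and `RejectByOperandType` are hypothetical stand-ins chosen for illustration, and `kF8` stands for the whole F8 family. The sketch only assumes what the review states, namely that the result of AllGatherStart is an (input, output) tuple.

```cpp
#include <cassert>

// Simplified stand-ins for illustration only; these are NOT the XLA types.
enum class ElementType { kF32, kF8, kS4, kTuple };

// Hypothetical model of an AllGatherStart instruction.
struct AllGatherStartModel {
  ElementType operand_type;  // element type of the array being gathered
  ElementType result_type;   // kTuple, because the op returns (input, output)
};

// The element types that the guard is meant to reject.
bool IsUnsupported(ElementType t) {
  return t == ElementType::kF8 || t == ElementType::kS4;
}

// Pattern flagged in the review: the result type is a tuple, so this check
// can never match F8 or S4 and the guard never fires.
bool RejectByResultType(const AllGatherStartModel& op) {
  return IsUnsupported(op.result_type);
}

// Suggested direction: inspect the operand's element type instead.
bool RejectByOperandType(const AllGatherStartModel& op) {
  assert(op.result_type == ElementType::kTuple);  // invariant of this model
  return IsUnsupported(op.operand_type);
}
```

With an S4 operand, `RejectByResultType` returns false while `RejectByOperandType` returns true, which is the behavioral difference the review describes.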

Automated review by Claude

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 13, 2026
@mfrancepillois mfrancepillois force-pushed the ci_rocm_maxime_allgather_triton branch from 3a41014 to 5b6054a Compare April 13, 2026 15:31
@mfrancepillois mfrancepillois force-pushed the ci_rocm_maxime_allgather_triton branch from 5b6054a to 0db520d Compare April 13, 2026 15:39
@i-chaochen
Collaborator

Wondering if this branch is based on upstream or xla-0.9.1?

@mfrancepillois
Author

Wondering if this branch is based on upstream or xla-0.9.1?

This branch is based on ci_maxime_allreduce_triton_rocm_elementwise_rocm because we need the triton_xla atomic operations to be implemented. That branch, in turn, is based on rocm-jaxlib-v0.9.1.

// group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
// pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
// pid_n = (tile_id % num_pid_in_group) // group_size_m
mlir::LogicalResult EmitAllGatherSwizzled(int64_t group_size_m) {
Author

Currently, the swizzled kernel is not called but I'm keeping it until the performance evaluation is complete.
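
For readers unfamiliar with the index mapping in the code comment above, here is a host-side sketch of the same arithmetic. It is not the emitter code from this PR: `SwizzleTileId` is a hypothetical helper, and the first three derived quantities (`num_pid_in_group`, `group_id`, `first_pid_m`) are assumed to follow the usual Triton grouped-ordering scheme, since the quoted comment shows only the last three lines. The input validation also shows one way to rule out the division by zero mentioned in the review.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// Hypothetical helper: maps a linear tile_id to (pid_m, pid_n) using grouped
// (swizzled) ordering. Returns std::nullopt for inputs that would otherwise
// divide by zero or fall outside the num_pid_m x num_pid_n tile grid.
std::optional<std::pair<int64_t, int64_t>> SwizzleTileId(
    int64_t tile_id, int64_t num_pid_m, int64_t num_pid_n,
    int64_t max_group_size_m) {
  if (num_pid_m <= 0 || num_pid_n <= 0 || max_group_size_m <= 0 ||
      tile_id < 0 || tile_id >= num_pid_m * num_pid_n) {
    return std::nullopt;
  }
  // Assumed preamble (standard grouped ordering).
  const int64_t num_pid_in_group = max_group_size_m * num_pid_n;
  const int64_t group_id = tile_id / num_pid_in_group;
  const int64_t first_pid_m = group_id * max_group_size_m;
  // The three lines quoted in the code comment.
  const int64_t group_size_m =
      std::min(num_pid_m - first_pid_m, max_group_size_m);
  assert(group_size_m > 0);  // holds because tile_id is inside the grid
  const int64_t local_id = tile_id % num_pid_in_group;
  const int64_t pid_m = first_pid_m + (local_id % group_size_m);
  const int64_t pid_n = local_id / group_size_m;
  return std::make_pair(pid_m, pid_n);
}
```

For a 3 x 2 tile grid with a maximum group size of 2, tile ids 0 through 5 map to (0,0), (1,0), (0,1), (1,1), (2,0), (2,1): rows are visited in groups of two before moving to the last, smaller group.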
