Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xla/backends/gpu/codegen/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,7 @@ cc_library(
srcs = ["collective_emitter.cc"],
hdrs = ["collective_emitter.h"],
deps = [
":lowering_util",
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
Expand All @@ -926,8 +927,10 @@ cc_library(
"//xla/codegen/xtile/ir:xtile",
"//xla/core/collectives:reduction_kind",
"//xla/hlo/ir:hlo",
"//xla/mlir/utils:type_util",
"//xla/service:collective_ops_utils",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:gpu_constants",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
"//xla/stream_executor:device_description",
Expand Down
777 changes: 774 additions & 3 deletions xla/backends/gpu/codegen/triton/collective_emitter.cc

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions xla/backends/gpu/codegen/triton/collective_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,9 @@ absl::StatusOr<std::vector<Shape>> GetCollectiveUnmanagedKernelArguments(
mlir::LogicalResult RewriteAllReduce(mlir::stablehlo::AllReduceOp op,
mlir::PatternRewriter& rewriter);

// Rewrites stablehlo all-gather op to a triton implementation.
mlir::LogicalResult RewriteAllGather(mlir::stablehlo::AllGatherOp op,
mlir::PatternRewriter& rewriter);
Comment thread
mfrancepillois marked this conversation as resolved.

} // namespace xla::gpu
#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_COLLECTIVE_EMITTER_H_
69 changes: 47 additions & 22 deletions xla/backends/gpu/codegen/triton/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,25 +125,36 @@ absl::StatusOr<TritonFusion::EmitResult> TritonFusion::Emit(
absl::Span<const Shape> unmanaged_arguments) const {
std::string suggested_kernel_name = std::string(fusion.name());
llvm::IRBuilder builder(*ir_emitter_context.llvm_context());
VLOG(3) << fusion.ToString();
TF_ASSIGN_OR_RETURN(
auto kernel_arguments,
emitters::KernelArguments::Create(
ir_emitter_context.buffer_assignment(), GetDefaultBufferAlignment(),
instr_override != nullptr ? instr_override : &fusion,
unmanaged_arguments));

// Special handling for AllGather with tuple unpacking
absl::StatusOr<emitters::KernelArguments> kernel_arguments_or;
if (instr_override != nullptr &&
instr_override->opcode() == HloOpcode::kAllGatherStart &&
instr_override->shape().IsTuple() && !fusion.shape().IsTuple()) {
// Use the new overload: fusion for shapes, AllGather for buffers at index
// {1}
kernel_arguments_or = emitters::KernelArguments::Create(
ir_emitter_context.buffer_assignment(), GetDefaultBufferAlignment(),
&fusion, // shape_instruction
instr_override, // buffer_instruction
ShapeIndex{1}, // output_index (element 1 of AllGather tuple)
unmanaged_arguments);
} else {
// Regular path for AllReduce and other operations
kernel_arguments_or = emitters::KernelArguments::Create(
ir_emitter_context.buffer_assignment(), GetDefaultBufferAlignment(),
instr_override != nullptr ? instr_override : &fusion,
unmanaged_arguments);
}
TF_ASSIGN_OR_RETURN(auto kernel_arguments, std::move(kernel_arguments_or));

const HloComputation* hlo_computation =
fusion.fused_instructions_computation();
VLOG(3) << "hlo_computation: " << hlo_computation->ToString();

std::unique_ptr<llvm::Module> local_module;
auto generate = [&]() -> absl::StatusOr<KernelReuseCache::Entry> {
VLOG(3) << "Generating: " << suggested_kernel_name;

const std::string sanitized_kernel_name =
ir_emitter_context.GetSanitizedUniqueName(suggested_kernel_name);

TF_ASSIGN_OR_RETURN(
TritonWrapperResult triton_wrapper_result,
GenerateTritonKernelAndWrapper(
Expand All @@ -152,11 +163,9 @@ absl::StatusOr<TritonFusion::EmitResult> TritonFusion::Emit(
ir_emitter_context.data_layout(), ir_emitter_context.llvm_context(),
ir_emitter_context.mlir_context()));
local_module = std::move(triton_wrapper_result.llvm_module);

auto backend_config =
fusion.backend_config<GpuBackendConfig>()->fusion_backend_config();
absl::string_view fusion_kind = backend_config.kind();

LaunchDimensions launch_dimensions;

// TODO(bchetioui,pifon): this list should be consolidated; why do we need
Expand All @@ -166,11 +175,9 @@ absl::StatusOr<TritonFusion::EmitResult> TritonFusion::Emit(
kTritonNestedGemmFusionKind,
kTritonCollectiveFusionKind,
};

if (!absl::c_linear_search(kSupportedFusionKinds, fusion_kind)) {
return Internal("Unsupported fusion kind: %s", fusion_kind);
}

std::optional<LaunchConfig> launch_config;
// Currently GetLaunchConfig will compute the same value as the extracted
// one. They are different only when warp specialization is enabled.
Expand All @@ -197,11 +204,9 @@ absl::StatusOr<TritonFusion::EmitResult> TritonFusion::Emit(

AnnotateAttrsIfUnset(kernel_arguments, *kernel);
PopulateNvvmAnnotations(local_module.get(), kernel, triton_wrapper_result);

TF_RETURN_IF_ERROR(AnnotateKernelLaunchDimensions(
ir_emitter_context.gpu_device_info(), launch_dimensions, kernel,
local_module.get()));

return {{kernel->getName().str(), launch_dimensions,
/*cluster_dim=*/std::nullopt, triton_wrapper_result.shmem_bytes,
/*binary=*/"", triton_wrapper_result.tma_metadata}};
Expand Down Expand Up @@ -242,25 +247,45 @@ std::optional<TritonFusion::LaunchConfig> TritonFusion::GetLaunchConfig(

// We expect all roots to have the same number of blocks. Otherwise we
// cannot codegen it.
LOG(INFO) << "GetLaunchConfig: fusion_root_count="
Comment thread
mfrancepillois marked this conversation as resolved.
<< analysis_.fusion_root_count();

if (analysis_.fusion_root_count() == 0) {
LOG(ERROR) << "GetLaunchConfig: No fusion roots found!";
return std::nullopt;
}

int64_t num_blocks =
GetNumberOfBlocks(analysis_.fusion_root(0).shape().dimensions(),
block_level_parameters.output_tile_sizes[0]);
for (int64_t i = 1; i < analysis_.fusion_root_count(); ++i) {
CHECK_EQ(GetNumberOfBlocks(analysis_.fusion_root(i).shape().dimensions(),
block_level_parameters.output_tile_sizes[i]),
num_blocks);
}
if (i >= block_level_parameters.output_tile_sizes.size()) {
LOG(ERROR)
<< "GetLaunchConfig: output_tile_sizes index out of bounds! i=" << i
<< ", size=" << block_level_parameters.output_tile_sizes.size();
return std::nullopt;
}

int64_t blocks_for_root =
GetNumberOfBlocks(analysis_.fusion_root(i).shape().dimensions(),
block_level_parameters.output_tile_sizes[i]);

CHECK_EQ(blocks_for_root, num_blocks);
}
LaunchConfig launch_config;
// TODO(b/451901200): We eventually also want to be able to predict this
// value without compiling so the cost model can rely on it. Currently, we
// need the override for auto warp specialization.
LOG(INFO) << "GetLaunchConfig: thread_dims_override.has_value()="
<< thread_dims_override.has_value();

if (thread_dims_override) {
launch_config.launch_dimensions = LaunchDimensions{
se::BlockDim(num_blocks), thread_dims_override.value()};
} else {
int64_t warp_size = WarpSize(analysis_.device_info());
int64_t estimated_threads_per_block =
block_level_parameters.num_warps * WarpSize(analysis_.device_info());
block_level_parameters.num_warps * warp_size;
launch_config.launch_dimensions =
LaunchDimensions{static_cast<uint64_t>(num_blocks),
static_cast<uint64_t>(estimated_threads_per_block)};
Expand Down
51 changes: 51 additions & 0 deletions xla/backends/gpu/codegen/triton/support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,51 @@ CodegenDecision IsTritonSupportedAllReduce(
return CodegenDecision::Allow();
}

CodegenDecision IsTritonSupportedAllGather(
    const HloAllGatherInstruction& all_gather,
    const se::GpuComputeCapability& gpu_version) {
  // The Triton all-gather backend is experimental and gated behind a debug
  // flag; when the flag is off we forbid codegen so the regular NCCL/RCCL
  // path is used instead.
  bool flag_enabled = all_gather.GetModule()
                          ->config()
                          .debug_options()
                          .xla_gpu_unsupported_use_all_gather_triton_backend();

  VLOG(1) << "IsTritonSupportedAllGather called for: " << all_gather.name()
          << ", flag_enabled=" << flag_enabled;

  if (!flag_enabled) {
    VLOG(1) << "AllGather Triton backend is DISABLED for: "
            << all_gather.name();
    return CodegenDecision::Forbid(
        "Triton backend for all-gather is disabled. Enable with "
        "--xla_gpu_unsupported_use_all_gather_triton_backend=true");
  }

  VLOG(1) << "AllGather Triton backend is ENABLED for: " << all_gather.name();

  // Sub-byte (S4) and FP8 element types are not supported by the Triton
  // lowering. The check is on the *operand* element type — the type of the
  // values being gathered.
  PrimitiveType element_type = all_gather.operand(0)->shape().element_type();
  if (element_type == F8E4M3FN || element_type == F8E5M2 ||
      element_type == S4) {
    // Log the operand element type that was actually checked. (Previously
    // this logged the output shape's element type, which may differ for a
    // tuple-shaped all-gather and misled debugging.)
    VLOG(1) << "AllGather rejected due to unsupported data type: "
            << element_type;
    return CodegenDecision::Forbid(
        "S4, F8E4M3FN and F8E5M2 are not supported for all-gathers.");
  }

  // TODO(allgather-triton): Add additional validation checks similar to
  // AllReduce
  // - Check replica groups
  // - Check gather dimension constraints
  // - Check size thresholds

  VLOG(1) << "AllGather Triton backend ALLOWED for: " << all_gather.name()
          << " (but fusion wrapping may not happen - check thunk_emitter.cc)";

  // Allow codegen to proceed - the actual "not implemented" error will be
  // generated in collective_emitter.cc::RewriteAllGather()
  return CodegenDecision::Allow();
}

bool IsInTritonNestedGemmFusion(const HloInstruction& hlo) {
if (!hlo.parent()->IsFusionComputation()) {
return false;
Expand Down Expand Up @@ -792,6 +837,12 @@ CodegenDecision IsTritonSupportedInstructionImpl(
case HloOpcode::kAllReduceDone:
return IsTritonSupportedAllReduce(
*Cast<HloAllReduceInstruction>(instr.operand(0)), gpu_version);
case HloOpcode::kAllGatherStart:
return IsTritonSupportedAllGather(*Cast<HloAllGatherInstruction>(&instr),
gpu_version);
case HloOpcode::kAllGatherDone:
return IsTritonSupportedAllGather(
*Cast<HloAllGatherInstruction>(instr.operand(0)), gpu_version);
default:
// Not all instructions have a special handling.
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,18 @@ class LowerAllReduce : public mlir::OpRewritePattern<stablehlo::AllReduceOp> {
}
};

// Pattern that lowers stablehlo.all_gather ops to their Triton
// implementation by delegating to xla::gpu::RewriteAllGather.
class LowerAllGather : public mlir::OpRewritePattern<stablehlo::AllGatherOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  // Delegates the actual rewrite; success/failure is reported by the callee.
  mlir::LogicalResult matchAndRewrite(
      stablehlo::AllGatherOp op,
      mlir::PatternRewriter& rewriter) const override {
    return ::xla::gpu::RewriteAllGather(op, rewriter);
  }
};

class StableHLOLowerToTritonPass
: public impl::StableHLOLowerToTritonPassBase<StableHLOLowerToTritonPass> {
public:
Expand All @@ -828,7 +840,8 @@ class StableHLOLowerToTritonPass
mlir::MLIRContext* mlir_context = &getContext();
mlir::RewritePatternSet patterns(mlir_context);
patterns.add<LowerTranspose, LowerIotaToMakeRange, LowerBroadcastInDim,
LowerReduce, LowerReshape, LowerAllReduce>(mlir_context);
LowerReduce, LowerReshape, LowerAllReduce, LowerAllGather>(
mlir_context);
patterns.add<LowerDotGeneral>(mlir_context, warp_specialization_allowed_);

if (mlir::failed(
Expand Down
1 change: 1 addition & 0 deletions xla/backends/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,7 @@ cc_library(
srcs = ["all_gather_thunk.cc"],
hdrs = ["all_gather_thunk.h"],
deps = [
":collective_kernel_thunk",
":collective_thunk",
":thunk",
"//xla:future",
Expand Down
63 changes: 60 additions & 3 deletions xla/backends/gpu/runtime/all_gather_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/stream_executor/stream.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/casts.h"
#include "xla/tsl/platform/status_macros.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -91,6 +91,21 @@ AllGatherStartThunk::AllGatherStartThunk(ThunkInfo thunk_info,
CHECK_EQ(config_.config.operand_element_type.size(), buffers_.size());
}

// Overload that additionally takes ownership of an optional
// CollectiveKernelThunk used to run the all-gather as a custom (Triton)
// kernel instead of the NCCL/RCCL library call.
//
// NOTE(review): the fourth CollectiveThunk base argument is hard-coded to
// `false` here (the other constructors may forward a flag instead) — confirm
// this is intentional. `p2p_memcpy_enabled` is accepted for signature parity
// with the other overload but is currently unused in this constructor.
AllGatherStartThunk::AllGatherStartThunk(
    ThunkInfo thunk_info, const HloAllGatherInstruction* inst,
    std::vector<Buffer> buffers,
    std::unique_ptr<CollectiveKernelThunk> collective_kernel_thunk,
    bool p2p_memcpy_enabled)
    : CollectiveThunk(Thunk::kAllGatherStart, thunk_info,
                      IsGPUSyncCollective(*inst), false),
      config_(GetAllGatherConfig(inst)),
      buffers_(std::move(buffers)),
      collective_kernel_thunk_(std::move(collective_kernel_thunk)) {
  // Each operand element type must have a matching source/destination buffer.
  CHECK_EQ(config_.config.operand_element_type.size(), buffers_.size());
  VLOG(3) << "AllGatherStartThunk created with CollectiveKernelThunk for: "
          << inst->name();
}

/*static*/ absl::Status AllGatherStartThunk::CheckImplementable(
const HloAllGatherInstruction* inst, int64_t replica_count,
int64_t partition_count) {
Expand Down Expand Up @@ -153,14 +168,56 @@ absl::StatusOr<ThunkProto> AllGatherStartThunk::ToProto() const {
return proto;
}

// Prepares the base collective thunk and, when a custom collective kernel
// thunk is attached, forwards preparation to it as well.
absl::Status AllGatherStartThunk::Prepare(const PrepareParams& params) {
  TF_RETURN_IF_ERROR(CollectiveThunk::Prepare(params));
  // No custom kernel attached: nothing further to prepare.
  if (collective_kernel_thunk_ == nullptr) {
    return absl::OkStatus();
  }
  return collective_kernel_thunk_->Prepare(params);
}

// Initializes the base collective thunk, then initializes the attached
// collective kernel thunk — but only if one is present AND it reports
// support for this clique/executor combination.
absl::Status AllGatherStartThunk::Initialize(const InitializeParams& params) {
  TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params));
  if (collective_kernel_thunk_ == nullptr) {
    return absl::OkStatus();
  }
  // Derive the clique key for this collective (not a point-to-point op).
  TF_ASSIGN_OR_RETURN(GpuCliqueKey clique_key,
                      GetCollectiveGpuCliqueKey(*params.collective_params,
                                                config(), /*is_p2p=*/false));
  TF_ASSIGN_OR_RETURN(
      bool use_collective_kernel,
      collective_kernel_thunk_->IsSupported(clique_key, *params.executor,
                                            *params.collective_params));
  // Unsupported configurations silently skip kernel initialization; the
  // NCCL/RCCL fallback will be used at execution time.
  if (!use_collective_kernel) {
    return absl::OkStatus();
  }
  return collective_kernel_thunk_->Initialize(params);
}

absl::StatusOr<bool> AllGatherStartThunk::RunCollective(
const ExecuteParams& params, const GpuCliqueKey& clique_key,
se::Stream& stream, Communicator& comm) {
ASSIGN_OR_RETURN(std::vector<DeviceBufferPair> device_buffers,
ConvertToDeviceBuffers(params.buffer_allocations, buffers_,
config_.config.operand_element_type));
RETURN_IF_ERROR(xla::gpu::RunAllGather(device_buffers, stream, comm,
config_.config.use_symmetric_buffer));

// Try to use Triton collective kernel if available
if (collective_kernel_thunk_) {
TF_ASSIGN_OR_RETURN(
bool use_collective_kernel,
collective_kernel_thunk_->IsSupported(
clique_key, *params.stream->parent(), *params.collective_params));

if (use_collective_kernel) {
TF_RETURN_IF_ERROR(collective_kernel_thunk_->ExecuteOnStream(params));
return true;
}
VLOG(3) << "AllGatherStartThunk: Triton kernel not supported, falling "
"back to NCCL/RCCL";
}

// Fallback to NCCL/RCCL
TF_RETURN_IF_ERROR(xla::gpu::RunAllGather(
device_buffers, stream, comm, config_.config.use_symmetric_buffer));
return true;
}

Expand Down
9 changes: 9 additions & 0 deletions xla/backends/gpu/runtime/all_gather_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/backends/gpu/runtime/collective_kernel_thunk.h"
#include "xla/backends/gpu/runtime/collective_thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/hlo/ir/hlo_instructions.h"
Expand All @@ -44,6 +45,11 @@ class AllGatherStartThunk : public CollectiveThunk {
AllGatherStartThunk(ThunkInfo thunk_info, const HloAllGatherInstruction* inst,
std::vector<Buffer> buffers,
bool p2p_memcpy_enabled = false);
AllGatherStartThunk(
ThunkInfo thunk_info, const HloAllGatherInstruction* inst,
std::vector<Buffer> buffers,
std::unique_ptr<CollectiveKernelThunk> collective_kernel_thunk,
bool p2p_memcpy_enabled = false);
AllGatherStartThunk(
ThunkInfo thunk_info,
std::shared_ptr<CollectiveThunk::AsyncEvents> async_events,
Expand Down Expand Up @@ -81,6 +87,8 @@ class AllGatherStartThunk : public CollectiveThunk {
}

protected:
absl::Status Prepare(const PrepareParams& params) override;
absl::Status Initialize(const InitializeParams& params) override;
absl::StatusOr<bool> RunCollective(const ExecuteParams& params,
const GpuCliqueKey& clique_key,
se::Stream& stream,
Expand All @@ -89,6 +97,7 @@ class AllGatherStartThunk : public CollectiveThunk {
private:
const AllGatherConfig config_;
const std::vector<Buffer> buffers_;
std::unique_ptr<CollectiveKernelThunk> collective_kernel_thunk_;
};

absl::Status RunAllGather(std::vector<DeviceBufferPair>& buffers,
Expand Down
Loading
Loading