Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 14 additions & 108 deletions xla/backends/gpu/runtime/command_buffer_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ limitations under the License.
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/reduction_kind.h"
#include "xla/debug_options_flags.h"
#include "xla/service/rendezvous.h"
#include "xla/executable_run_options.h"
#include "xla/ffi/call_frame.h"
#include "xla/ffi/ffi_api.h"
Expand Down Expand Up @@ -996,23 +995,6 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer(
return shift_right(capacity_ - 1).command_buffer.get();
}

bool TracedCommandBuffer::HasEntry(
const BufferAllocations* buffer_allocation) const {
absl::InlinedVector<se::DeviceAddressBase, 4> allocs;
allocs.reserve(allocs_indices_.size());
for (auto& index : allocs_indices_) {
allocs.emplace_back(buffer_allocation->GetDeviceAddress(index));
}

for (size_t i = 0; i < capacity_; ++i) {
if (absl::c_equal(entries_[i].recorded_allocs, allocs) &&
entries_[i].command_buffer) {
return true;
}
}
return false;
}

//===----------------------------------------------------------------------===//
// TracedCommandBufferCmd
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2140,101 +2122,31 @@ absl::Status CollectiveCmd::Prepare(const Thunk::PrepareParams& params) {
return params.clique_requests->RequestClique(clique_key);
}

namespace {

// Identifies a per-clique, per-command trace-cache decision: two ranks vote
// on the same rendezvous only when they execute the same collective command
// instance over the same GPU clique.
struct CollectiveTraceCacheKey {
  GpuCliqueKey clique_key;
  const CollectiveCmd* cmd;

  template <typename H>
  friend H AbslHashValue(H h, const CollectiveTraceCacheKey& key) {
    // Field order must stay stable so hashes agree across call sites.
    return H::combine(std::move(h), key.clique_key, key.cmd);
  }

  friend bool operator==(const CollectiveTraceCacheKey& lhs,
                         const CollectiveTraceCacheKey& rhs) {
    // Compare the cheap pointer first; same fields as the hash above.
    return lhs.cmd == rhs.cmd && lhs.clique_key == rhs.clique_key;
  }
};

}  // namespace

absl::StatusOr<const se::CommandBuffer::Command*>
CollectiveCmd::RecordTracedCommand(
const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer,
absl::FunctionRef<absl::Status(se::Stream*)> trace,
const GpuCliqueKey& clique_key) {

se::CommandBuffer* nested_cmd_ptr = nullptr;
std::unique_ptr<se::CommandBuffer> nested_cmd_owned;

if (clique_key.is_local()) {
auto traced_cmd = record_params.state.GetOrCreate<TracedCommandBuffer>(
this, command_buffer, [&] {
const auto& debug_options = xla::GetDebugOptionsFromFlags();
return std::make_unique<TracedCommandBuffer>(
this, buffers(),
debug_options.xla_cmd_buffer_trace_cache_size());
});

bool local_hit = traced_cmd->HasEntry(execute_params.buffer_allocations);

CollectiveTraceCacheKey rendezvous_key{clique_key, this};
TF_ASSIGN_OR_RETURN(
std::shared_ptr<bool> all_hit,
xla::Rendezvous<bool>(
"collective_trace_cache", rendezvous_key, local_hit,
clique_key.num_local_participants(),
[](absl::Span<const bool*> votes) {
return std::all_of(votes.begin(), votes.end(),
[](const bool* v) { return *v; });
},
/*warn_stuck_timeout=*/absl::Seconds(10),
/*terminate_timeout=*/absl::Seconds(30)));

if (*all_hit) {
VLOG(5) << "Collective trace cache: all ranks hit, using cached graph";
TF_ASSIGN_OR_RETURN(
nested_cmd_ptr,
traced_cmd->GetOrTraceCommandBuffer(
execute_params.buffer_allocations,
execute_params.stream->parent(),
execute_params.command_buffer_trace_stream, trace, priority()));
} else {
VLOG(5) << "Collective trace cache: not all ranks hit, all retracing";
TF_ASSIGN_OR_RETURN(
nested_cmd_owned,
se::TraceCommandBufferFactory::Create(
execute_params.stream->parent(),
execute_params.command_buffer_trace_stream, trace));
nested_cmd_ptr = nested_cmd_owned.get();
}
} else {
TF_ASSIGN_OR_RETURN(
nested_cmd_owned,
se::TraceCommandBufferFactory::Create(
execute_params.stream->parent(),
execute_params.command_buffer_trace_stream, trace));
nested_cmd_ptr = nested_cmd_owned.get();
}
absl::FunctionRef<absl::Status(se::Stream*)> trace) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<se::CommandBuffer> nested_cmd,
se::TraceCommandBufferFactory::Create(
execute_params.stream->parent(),
execute_params.command_buffer_trace_stream, trace));

if (priority() != se::StreamPriority::Default) {
TF_RETURN_IF_ERROR(nested_cmd_ptr->SetPriority(priority()));
TF_RETURN_IF_ERROR(nested_cmd->SetPriority(priority()));
}

return Handle(
std::move(record_action),
[&](absl::Span<const se::CommandBuffer::Command* const> dependencies) {
return command_buffer->CreateChildCommand(
se::CommandBuffer::ChildCommandType::kCloned, *nested_cmd_ptr,
se::CommandBuffer::ChildCommandType::kCloned, *nested_cmd,
dependencies);
},
[&](const se::CommandBuffer::Command* command) {
return command_buffer->UpdateChildCommand(
se::CommandBuffer::ChildCommandType::kCloned, command,
*nested_cmd_ptr);
se::CommandBuffer::ChildCommandType::kCloned, command, *nested_cmd);
});
}

Expand Down Expand Up @@ -2294,8 +2206,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllReduceCmd::Record(
[&](se::Stream* stream) {
return RunAllReduce(reduction_kind_, device_buffers, *stream, *comm,
config().use_symmetric_buffer);
},
clique_key);
});
}

CommandBufferCmd::BufferUseVector AllReduceCmd::buffers() const {
Expand Down Expand Up @@ -2363,8 +2274,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> ReduceScatterCmd::Record(
return RunReduceScatter(
reduction_kind_, device_buffers, *stream,
*comm, config().use_symmetric_buffer);
},
clique_key);
});
}

CommandBufferCmd::BufferUseVector ReduceScatterCmd::buffers() const {
Expand Down Expand Up @@ -2433,8 +2343,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllToAllCmd::Record(
[&](se::Stream* stream) {
return RunAllToAll(has_split_dimension_, device_buffers, *stream, *comm,
config().use_symmetric_buffer);
},
clique_key);
});
}

CommandBufferCmd::BufferUseVector AllToAllCmd::buffers() const {
Expand Down Expand Up @@ -2499,8 +2408,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllGatherCmd::Record(
[&](se::Stream* stream) {
return RunAllGather(device_buffers, *stream, *comm,
config().use_symmetric_buffer);
},
clique_key);
});
}

CommandBufferCmd::BufferUseVector AllGatherCmd::buffers() const {
Expand Down Expand Up @@ -2565,8 +2473,7 @@ CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params,
execute_params, record_params, std::move(record_action), command_buffer,
[&](se::Stream* stream) {
return RunCollectiveBroadcast(device_buffers, *stream, *comm);
},
clique_key);
});
}

CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() const {
Expand Down Expand Up @@ -2648,8 +2555,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> CollectivePermuteCmd::Record(
/*use_memcpy=*/false,
/*recv_ptr_map=*/nullptr,
use_symmetric_buffer);
},
clique_key);
});
}

CommandBufferCmd::BufferUseVector CollectivePermuteCmd::buffers() const {
Expand Down
8 changes: 1 addition & 7 deletions xla/backends/gpu/runtime/command_buffer_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/backends/gpu/runtime/collective_permute_thunk.h"
#include "xla/backends/gpu/runtime/collective_thunk.h"
#include "xla/backends/gpu/runtime/copy_thunk.h"
Expand Down Expand Up @@ -612,8 +611,6 @@ class TracedCommandBuffer : public CommandBufferCmd::State {
se::Stream* stream, absl::FunctionRef<absl::Status(se::Stream*)> trace,
se::StreamPriority priority = se::StreamPriority::Default);

bool HasEntry(const BufferAllocations* buffer_allocation) const;

private:
std::vector<BufferAllocation::Index> allocs_indices_;

Expand Down Expand Up @@ -1120,16 +1117,13 @@ class CollectiveCmd : public CommandBufferCmd {

bool requires_initialization() override { return true; }

bool force_update() override { return true; }

bool IsNestedCommandBuffer() const final { return true; }

absl::StatusOr<const se::CommandBuffer::Command*> RecordTracedCommand(
const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer,
absl::FunctionRef<absl::Status(se::Stream*)> trace,
const GpuCliqueKey& clique_key);
absl::FunctionRef<absl::Status(se::Stream*)> trace);

bool IsAsync() const { return async_events_ != nullptr; }
std::shared_ptr<CollectiveThunk::AsyncEvents> async_events() const {
Expand Down
116 changes: 0 additions & 116 deletions xla/backends/gpu/runtime/command_buffer_cmd_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -903,120 +903,4 @@ static void BM_GetOrTraceCommandBuffer(benchmark::State& state) {

BENCHMARK(BM_GetOrTraceCommandBuffer);

// Verifies HasEntry's lookup semantics: no hit on an empty cache, a hit only
// after GetOrTraceCommandBuffer has recorded an entry for the exact set of
// device addresses, and coexistence of multiple distinct entries.
TEST(TracedCommandBuffer, HasEntry) {
se::StreamExecutor* executor = GpuExecutor();
auto stream = executor->CreateStream().value();
auto traced_cmd = FakeCmd();

BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0);
BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0);

// The traced buffer tracks one read and one write slice, so cache keys are
// derived from the device addresses of allocations 0 and 1.
CommandBufferCmd::BufferUseVector buffers = {
BufferUse::Read(BufferAllocation::Slice(&alloc0, 0, 1024)),
BufferUse::Write(BufferAllocation::Slice(&alloc1, 0, 1024))};

TracedCommandBuffer traced_cmd_buffer(&traced_cmd, buffers,
/*capacity=*/4);

// Fake device addresses; only their identity matters for cache matching.
se::DeviceAddressBase mem0(reinterpret_cast<void*>(0x01234567));
se::DeviceAddressBase mem1(reinterpret_cast<void*>(0x12345670));
se::DeviceAddressBase mem2(reinterpret_cast<void*>(0x23456701));

se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations({mem0, mem1}, 0, &allocator);

// Empty cache should report no entry.
EXPECT_FALSE(traced_cmd_buffer.HasEntry(&allocations));

// Trace a command buffer for {mem0, mem1}.
se::DeviceAddress<int32_t> mem = executor->AllocateArray<int32_t>(16, 0);
auto trace = [&](se::Stream* stream) -> absl::Status {
TF_RETURN_IF_ERROR(stream->Memset32(&mem, 42, 16));
return absl::OkStatus();
};

TF_ASSERT_OK(traced_cmd_buffer
.GetOrTraceCommandBuffer(&allocations, executor,
stream.get(), trace)
.status());

// Now HasEntry should find {mem0, mem1}.
EXPECT_TRUE(traced_cmd_buffer.HasEntry(&allocations));

// Different addresses should not be found.
BufferAllocations different_allocs({mem0, mem2}, 0, &allocator);
EXPECT_FALSE(traced_cmd_buffer.HasEntry(&different_allocs));

// Trace for {mem0, mem2} and verify both entries exist.
TF_ASSERT_OK(traced_cmd_buffer
.GetOrTraceCommandBuffer(&different_allocs, executor,
stream.get(), trace)
.status());

// Both address sets are now cached (capacity 4 holds both entries).
EXPECT_TRUE(traced_cmd_buffer.HasEntry(&allocations));
EXPECT_TRUE(traced_cmd_buffer.HasEntry(&different_allocs));
}

// Verifies HasEntry is a pure read: repeated queries must not touch the LRU
// order, so previously traced entries are never evicted by lookups alone.
// Evidence: `num_traces` stays at 2 after many HasEntry calls.
TEST(TracedCommandBuffer, HasEntryDoesNotModifyCache) {
se::StreamExecutor* executor = GpuExecutor();
auto stream = executor->CreateStream().value();
auto traced_cmd = FakeCmd();

BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0);

CommandBufferCmd::BufferUseVector buffers = {
BufferUse::Read(BufferAllocation::Slice(&alloc0, 0, 1024))};

// Capacity 2 means a third distinct key would evict the least-recently-used
// entry — exactly what this test proves HasEntry cannot cause.
TracedCommandBuffer traced_cmd_buffer(&traced_cmd, buffers,
/*capacity=*/2);

// Fake device addresses; identity is all that matters for cache keys.
se::DeviceAddressBase mem0(reinterpret_cast<void*>(0x01234567));
se::DeviceAddressBase mem1(reinterpret_cast<void*>(0x12345670));
se::DeviceAddressBase mem2(reinterpret_cast<void*>(0x23456701));

se::StreamExecutorMemoryAllocator allocator(executor);

// Count how many times the trace callback actually runs; a re-trace after
// an unexpected eviction would increment this.
se::DeviceAddress<int32_t> mem = executor->AllocateArray<int32_t>(16, 0);
int64_t num_traces = 0;
auto trace = [&](se::Stream* stream) -> absl::Status {
TF_RETURN_IF_ERROR(stream->Memset32(&mem, 42, 16));
num_traces++;
return absl::OkStatus();
};

// Fill cache with {mem0} and {mem1}.
BufferAllocations allocs0({mem0}, 0, &allocator);
BufferAllocations allocs1({mem1}, 0, &allocator);
BufferAllocations allocs2({mem2}, 0, &allocator);

TF_ASSERT_OK(traced_cmd_buffer
.GetOrTraceCommandBuffer(&allocs0, executor, stream.get(),
trace)
.status());
TF_ASSERT_OK(traced_cmd_buffer
.GetOrTraceCommandBuffer(&allocs1, executor, stream.get(),
trace)
.status());
EXPECT_EQ(num_traces, 2);

// HasEntry should not affect the LRU order -- calling it many times
// for allocs0 should not evict allocs1.
for (int i = 0; i < 10; i++) {
EXPECT_TRUE(traced_cmd_buffer.HasEntry(&allocs0));
EXPECT_TRUE(traced_cmd_buffer.HasEntry(&allocs1));
EXPECT_FALSE(traced_cmd_buffer.HasEntry(&allocs2));
}

// Both should still be in cache (no re-trace needed).
TF_ASSERT_OK(traced_cmd_buffer
.GetOrTraceCommandBuffer(&allocs0, executor, stream.get(),
trace)
.status());
TF_ASSERT_OK(traced_cmd_buffer
.GetOrTraceCommandBuffer(&allocs1, executor, stream.get(),
trace)
.status());
EXPECT_EQ(num_traces, 2);
}

} // namespace xla::gpu
Loading