Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/cudecomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ typedef enum {
CUDECOMP_TRANSPOSE_COMM_NCCL = 4, ///< NCCL backend
CUDECOMP_TRANSPOSE_COMM_NCCL_PL = 5, ///< NCCL backend with pipelining
CUDECOMP_TRANSPOSE_COMM_NVSHMEM = 6, ///< NVSHMEM backend
CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL = 7 ///< NVSHMEM backend with pipelining
CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL = 7, ///< NVSHMEM backend with pipelining
CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM = 8 ///< NVSHMEM backend using SM-based P2P transfers
} cudecompTransposeCommBackend_t;

/**
Expand Down
120 changes: 81 additions & 39 deletions include/internal/comm_routines.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ static inline void checkMpiInt32Limit(int64_t val, cudecompHaloCommBackend_t bac
#ifdef ENABLE_NVSHMEM
#define CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ 8 // max number of intra-group transfers to schedule between team syncs
template <typename T>
static void
nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_desc, T* send_buff,
const std::vector<comm_count_t>& send_counts, const std::vector<comm_count_t>& send_offsets,
T* recv_buff, const std::vector<comm_count_t>& recv_counts,
const std::vector<comm_count_t>& recv_offsets, cudecompCommAxis comm_axis, cudaStream_t stream) {
static void nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_desc, T* send_buff,
const std::vector<comm_count_t>& send_counts,
const std::vector<comm_count_t>& send_offsets, T* recv_buff,
const std::vector<comm_count_t>& recv_counts,
const std::vector<comm_count_t>& recv_offsets, cudecompCommAxis comm_axis, bool use_sm,
cudaStream_t stream) {
auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
auto team = comm_info.nvshmem_team;
int self_rank = comm_info.rank;
Expand All @@ -104,23 +105,34 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->nvshmem_sync_event));

// Event dependency on external stream for intra-group transfers
CHECK_CUDA(cudaEventRecord(grid_desc->events[0], stream));
for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
if (!use_sm) {
CHECK_CUDA(cudaEventRecord(grid_desc->events[0], stream));
for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
}
}

bool need_barrier = false;
bool need_quiet = false;
cudecompNvshmemA2AParams<T> params;
cudecompNvshmemP2PParams<T> p2p_params;
p2p_params.send_buff = send_buff;
p2p_params.recv_buff = recv_buff;
p2p_params.block_counters = grid_desc->nvshmem_block_counters;

// Inter-group transfers (non-blocking)
params.send_buff = send_buff;
params.recv_buff = recv_buff;

int count = 0;
for (int i = 1; i < send_counts.size(); ++i) {
int src_rank, dst_rank;
getAlltoallPeerRanks(grid_desc, comm_axis, i, src_rank, dst_rank);
int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
if (nvshmem_ptr(recv_buff, dst_rank_global)) { continue; }

if (!use_sm) need_barrier = true;
need_quiet = true;
params.send_offsets[count] = send_offsets[dst_rank];
params.recv_offsets[count] = recv_offsets[dst_rank];
params.send_counts[count] = send_counts[dst_rank];
Expand All @@ -129,59 +141,87 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_

if (count == CUDECOMP_NVSHMEM_A2A_PARAM_CAPACITY) {
params.ntransfers = count;
cudecomp_nvshmem_alltoallv(params, stream);
cudecomp_nvshmem_alltoallv(params, use_sm ? &comm_info.nvshmem_signals[0] : nullptr, stream);
count = 0;
}
}
if (count != 0) {
params.ntransfers = count;
cudecomp_nvshmem_alltoallv(params, stream);
cudecomp_nvshmem_alltoallv(params, use_sm ? &comm_info.nvshmem_signals[0] : nullptr, stream);
}

// Intra-group transfers (blocking, scheduled after non-blocking inter-group transfers for concurrency)
count = 0;
for (int i = 1; i < send_counts.size(); ++i) {
for (int i = (use_sm ? 0 : 1); i < send_counts.size(); ++i) {
int src_rank, dst_rank;
getAlltoallPeerRanks(grid_desc, comm_axis, i, src_rank, dst_rank);
int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
if (nvshmem_ptr(recv_buff, dst_rank_global)) {

if (comm_info.ngroups == 1 && handle->device_p2p_ce_count == 1 && count != 0 &&
count % CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ == 0) {
// For single group, single P2P CE (e.g. NVSwitch), synchronize NVSHMEM team every
// CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ transfers This helps reduce CE contention due to accumulation of
// jitter.
for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->events[0], 0));
}
if (!use_sm) {
need_barrier = true;
if (comm_info.ngroups == 1 && handle->device_p2p_ce_count == 1 && count != 0 &&
count % CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ == 0) {
// For single group, single P2P CE (e.g. NVSwitch), synchronize NVSHMEM team every
// CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ transfers. This helps reduce CE contention due to accumulation of
// jitter.
for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->events[0], 0));
}

nvshmemx_team_sync_on_stream(team, aux_stream);
nvshmemx_team_sync_on_stream(team, aux_stream);

CHECK_CUDA(cudaEventRecord(grid_desc->events[0], aux_stream));
for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
CHECK_CUDA(cudaEventRecord(grid_desc->events[0], aux_stream));
for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
}
}
}

nvshmemx_putmem_on_stream(recv_buff + recv_offsets[dst_rank], send_buff + send_offsets[dst_rank],
send_counts[dst_rank] * sizeof(T), dst_rank_global,
handle->streams[count % handle->device_p2p_ce_count]);
count++;
nvshmemx_putmem_on_stream(recv_buff + recv_offsets[dst_rank], send_buff + send_offsets[dst_rank],
send_counts[dst_rank] * sizeof(T), dst_rank_global,
handle->streams[count % handle->device_p2p_ce_count]);
count++;
} else {
p2p_params.send_offsets[count] = send_offsets[dst_rank];
p2p_params.recv_offsets[count] = recv_offsets[dst_rank];
p2p_params.send_counts[count] = send_counts[dst_rank];
p2p_params.peer_ranks[count] = dst_rank_global;
count++;

if (count == CUDECOMP_NVSHMEM_P2P_PARAM_CAPACITY) {
p2p_params.ntransfers = count;
cudecomp_nvshmem_alltoallv_p2p(handle, p2p_params, &comm_info.nvshmem_signals[0], stream);
count = 0;
}
}
}
}

// Self-copy with cudaMemcpy
CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
if (use_sm) {
if (count != 0) {
p2p_params.ntransfers = count;
cudecomp_nvshmem_alltoallv_p2p(handle, p2p_params, &comm_info.nvshmem_signals[0], stream);
}

if (need_quiet) { nvshmemx_quiet_on_stream(stream); }
nvshmemx_signal_wait_until_on_stream(&comm_info.nvshmem_signals[0], NVSHMEM_CMP_EQ,
static_cast<uint64_t>(comm_info.nranks), stream);
} else {
// Self-copy with cudaMemcpy
CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
}

// Event dependency on internal streams for completion of intra-group transfers
for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[0], 0));
// Event dependency on internal streams for completion of intra-group transfers (not needed for SM path)
if (!use_sm) {
for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[0], 0));
}
}

nvshmemx_barrier_on_stream(team, stream);
if (need_barrier) { nvshmemx_barrier_on_stream(team, stream); }
}
#endif

Expand Down Expand Up @@ -213,11 +253,13 @@ static void cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridD

std::vector<MPI_Request> reqs;
switch (grid_desc->config.transpose_comm_backend) {
case CUDECOMP_TRANSPOSE_COMM_NVSHMEM: {
case CUDECOMP_TRANSPOSE_COMM_NVSHMEM:
case CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM: {
#ifdef ENABLE_NVSHMEM
if (nvshmem_ptr(send_buff, handle->rank) && nvshmem_ptr(recv_buff, handle->rank)) {
nvshmemAlltoallV(handle, grid_desc, send_buff, send_counts, send_offsets, recv_buff, recv_counts,
recv_offsets_nvshmem, comm_axis, stream);
recv_offsets_nvshmem, comm_axis,
grid_desc->config.transpose_comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM, stream);
break;
} else {
THROW_INVALID_USAGE("NVSHMEM communication backends require workspace allocated via cudecompMalloc.");
Expand Down
13 changes: 10 additions & 3 deletions include/internal/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ struct cudecompHandle {

// Miscellaneous
int32_t device_p2p_ce_count = 0; // number of P2P CEs available
int32_t device_num_sms = 0; // number of SMs on the device
int32_t device_max_threads_per_sm = 0; // maximum threads per SM
bool use_col_major_rank_order = false; // Flag to control whether to use column-major rank order
};

Expand Down Expand Up @@ -183,6 +185,10 @@ struct cudecompGridDesc {
std::vector<cudaEvent_t> events{nullptr}; // CUDA events used for scheduling
cudaEvent_t nvshmem_sync_event = nullptr; // NVSHMEM event used for synchronization

#ifdef ENABLE_NVSHMEM
int* nvshmem_block_counters = nullptr; // device memory counters for SM alltoallv last-block detection
#endif

cudecomp::graphCache graph_cache; // CUDA graph cache

cudecomp::ncclComm nccl_comm; // NCCL communicator (global), shared from handle
Expand Down Expand Up @@ -292,7 +298,8 @@ static inline bool haloBackendRequiresNccl(cudecompHaloCommBackend_t comm_backen
}

static inline bool transposeBackendRequiresNvshmem(cudecompTransposeCommBackend_t comm_backend) {
return (comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM || comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL);
return (comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM || comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL ||
comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM);
}

static inline bool haloBackendRequiresNvshmem(cudecompHaloCommBackend_t comm_backend) {
Expand Down Expand Up @@ -380,8 +387,8 @@ static inline void getAlltoallPeerRanks(cudecompGridDesc_t grid_desc, cudecompCo

const auto& info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;

// Quick return for single rank case
if (info.nranks == 1) {
// Return self for single rank case or when iter is zero
if (info.nranks == 1 || iter == 0) {
src_rank = info.rank;
dst_rank = info.rank;
return;
Expand Down
71 changes: 64 additions & 7 deletions include/internal/cudecomp_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#ifndef CUDECOMP_KERNELS_CUH
#define CUDECOMP_KERNELS_CUH

#include <algorithm>

#ifdef ENABLE_NVSHMEM
#include <nvshmem.h>
#endif
Expand All @@ -28,6 +30,8 @@
#define CUDECOMP_UNROLL_FACTOR (4)
#define CUDECOMP_MIN_BLOCKS_PER_SM (16)

#define CUDECOMP_NVSHMEM_NTHREADS (1024)

namespace cudecomp {

#ifdef ENABLE_NVSHMEM
Expand All @@ -47,6 +51,61 @@ __launch_bounds__(CUDECOMP_CUDA_NTHREADS) __global__

nvshmem_putmem_nbi(recv_buff + recv_offset, send_buff + send_offset, send_count * sizeof(T), peer_rank);
}

template <typename T>
__launch_bounds__(CUDECOMP_CUDA_NTHREADS) __global__
void cudecomp_nvshmem_alltoallv_signal_k(cudecompNvshmemA2AParams<T> params, uint64_t* sig_addr) {
  // All-to-all-v push kernel with completion signaling: one thread per transfer
  // descriptor. Thread `tid` issues a single non-blocking put of its slice and
  // atomically adds 1 to the destination PE's signal word so the receiver can
  // wait for the expected number of incoming transfers.

  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // Grid may be over-provisioned relative to ntransfers; surplus threads exit.
  if (tid >= params.ntransfers) return;

  // Per-transfer parameters (element offsets/counts; converted to bytes below).
  int peer_rank = params.peer_ranks[tid];
  T* send_buff = params.send_buff;
  T* recv_buff = params.recv_buff;
  size_t send_offset = params.send_offsets[tid];
  size_t recv_offset = params.recv_offsets[tid];
  size_t send_count = params.send_counts[tid];

  // Non-blocking put fused with a remote NVSHMEM_SIGNAL_ADD(+1) on sig_addr at
  // peer_rank; the signal update is ordered after delivery of the put data.
  nvshmem_putmem_signal_nbi(recv_buff + recv_offset, send_buff + send_offset, send_count * sizeof(T), sig_addr, 1,
                            NVSHMEM_SIGNAL_ADD, peer_rank);
}

template <typename T>
__launch_bounds__(CUDECOMP_NVSHMEM_NTHREADS) __global__
void cudecomp_nvshmem_alltoallv_p2p_k(cudecompNvshmemP2PParams<T> params, uint64_t* sig_addr) {
  // SM-based all-to-all-v for P2P-reachable peers: each transfer (copy) is
  // striped across `blocks_per_copy` thread blocks, each block pushing a
  // contiguous chunk via the block-collective nvshmemx_putmem_block. The last
  // block to finish a copy signals the destination PE (+1 on sig_addr).

  T* send_buff = params.send_buff;
  T* recv_buff = params.recv_buff;
  int bid = blockIdx.x;

  if (params.ntransfers > 0) {
    // NOTE(review): assumes gridDim.x is a multiple of ntransfers (launcher's
    // responsibility) — otherwise trailing blocks would compute
    // copyid >= ntransfers and index past the parameter arrays. TODO confirm
    // the host-side launch enforces this.
    int blocks_per_copy = gridDim.x / params.ntransfers;
    int copyid = bid / blocks_per_copy;           // which transfer this block serves
    int block_within_copy = bid % blocks_per_copy; // this block's stripe index
    int peer_rank = params.peer_ranks[copyid];
    size_t send_offset = params.send_offsets[copyid];
    size_t recv_offset = params.recv_offsets[copyid];
    size_t send_count = params.send_counts[copyid];

    // Ceil-divide the copy across its blocks; the final stripe may be short,
    // and stripes entirely past send_count do no put (but still participate in
    // the last-block accounting below).
    size_t nelems_per_block = (send_count + blocks_per_copy - 1) / blocks_per_copy;
    size_t block_offset = (size_t)block_within_copy * nelems_per_block;
    if (block_offset < send_count) {
      size_t block_count = min(nelems_per_block, send_count - block_offset);
      nvshmemx_putmem_block(recv_buff + recv_offset + block_offset, send_buff + send_offset + block_offset,
                            block_count * sizeof(T), peer_rank);
    }

    // Last block to finish this copy signals the destination PE.
    // Fence orders this PE's prior put ahead of the signal op; __syncthreads
    // ensures the whole block's collective put has been issued before thread 0
    // runs the completion accounting.
    nvshmem_fence();
    __syncthreads();
    if (threadIdx.x == 0) {
      // atomicAdd returns the old count; +1 yields the new count, so exactly
      // one block (the last) takes this branch per copy.
      if (atomicAdd(&params.block_counters[peer_rank], 1) + 1 == blocks_per_copy) {
        // Plain store is safe here: all other blocks of this copy have already
        // passed the atomicAdd, so no concurrent writers remain. Resetting
        // leaves the counter ready for the next kernel invocation.
        params.block_counters[peer_rank] = 0;
        nvshmemx_signal_op(sig_addr, 1, NVSHMEM_SIGNAL_ADD, peer_rank);
      }
    }
  }
}
#endif

template <int src_nd, int dest_nd, typename T>
Expand Down Expand Up @@ -107,7 +166,8 @@ __launch_bounds__(CUDECOMP_CUDA_NTHREADS) __global__
}

template <typename T>
void cudecomp_batched_d2d_memcpy_3d_nd_dispatch(const cudecompBatchedD2DMemcpy3DParams<T>& params,
void cudecomp_batched_d2d_memcpy_3d_nd_dispatch(cudecompHandle_t handle,
const cudecompBatchedD2DMemcpy3DParams<T>& params,
cudaStream_t stream) {
size_t N = params.extents[0][0] * params.extents[1][0] * params.extents[2][0];

Expand Down Expand Up @@ -138,12 +198,9 @@ void cudecomp_batched_d2d_memcpy_3d_nd_dispatch(const cudecompBatchedD2DMemcpy3D
int blocks_per_copy_unroll = (blocks_per_copy + CUDECOMP_UNROLL_FACTOR - 1) / CUDECOMP_UNROLL_FACTOR;
size_t total_blocks_unroll = params.ncopies * blocks_per_copy_unroll;

// Clamp minimum number of blocks from unrolling
int dev, num_sms;
CHECK_CUDA(cudaGetDevice(&dev));
CHECK_CUDA(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev));

if (total_blocks_unroll > CUDECOMP_MIN_BLOCKS_PER_SM * num_sms) { blocks_per_copy = blocks_per_copy_unroll; }
if (total_blocks_unroll > CUDECOMP_MIN_BLOCKS_PER_SM * handle->device_num_sms) {
blocks_per_copy = blocks_per_copy_unroll;
}

switch (src_nd) {
case 1:
Expand Down
Loading
Loading