
Commit 90a3bc1

Add new NVSHMEM transpose communication backend with SM-based P2P copies. (#114)
* Add new NVSHMEM backend with SM-based NVLink transfers.
* Adjust block rounding logic. Run self-copy sequentially.
* Increase threads per block to 1024. Fuse self-copy into kernel.
* Fuse barrier into kernel.
* Modify block assignments in kernel.
* Simplify self-copy fusion.
* Update intergroup transfers to use signals.
* Fix block counter sizing.
* Preserve existing behavior with non-SM NVSHMEM backend.
* Cache device attributes.
* Use separate P2P params struct for new kernel.
* Add new backend to test_config.yaml.
* Formatting.
* Update Fortran transpose test to handle new backend.

Signed-off-by: Josh Romero <joshr@nvidia.com>
1 parent 8e4a7e0 commit 90a3bc1

14 files changed

Lines changed: 304 additions & 90 deletions

include/cudecomp.h

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ typedef enum {
   CUDECOMP_TRANSPOSE_COMM_NCCL = 4,       ///< NCCL backend
   CUDECOMP_TRANSPOSE_COMM_NCCL_PL = 5,    ///< NCCL backend with pipelining
   CUDECOMP_TRANSPOSE_COMM_NVSHMEM = 6,    ///< NVSHMEM backend
-  CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL = 7  ///< NVSHMEM backend with pipelining
+  CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL = 7, ///< NVSHMEM backend with pipelining
+  CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM = 8  ///< NVSHMEM backend using SM-based P2P transfers
 } cudecompTransposeCommBackend_t;
 
 /**
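
Note: selecting the new backend from application code would look like the minimal sketch below. The transpose_comm_backend field of cudecompGridDescConfig_t appears in the comm_routines.h changes further down; the surrounding setup calls are assumptions based on the cuDecomp public API and are not part of this commit.

// Minimal usage sketch (assumed API surface; error handling omitted).
// handle: an existing cudecompHandle_t from cudecompInit.
cudecompGridDescConfig_t config;
cudecompGridDescConfigSetDefaults(&config);                         // assumed default-init helper
config.transpose_comm_backend = CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM; // new enum value (8)
cudecompGridDesc_t grid_desc;
cudecompGridDescCreate(handle, &grid_desc, &config, /*autotune_options=*/nullptr);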

include/internal/comm_routines.h

Lines changed: 81 additions & 39 deletions
@@ -90,11 +90,12 @@ static inline void checkMpiInt32Limit(int64_t val, cudecompHaloCommBackend_t bac
 #ifdef ENABLE_NVSHMEM
 #define CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ 8 // max number of intra-group transfers to schedule between team syncs
 template <typename T>
-static void
-nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_desc, T* send_buff,
-                 const std::vector<comm_count_t>& send_counts, const std::vector<comm_count_t>& send_offsets,
-                 T* recv_buff, const std::vector<comm_count_t>& recv_counts,
-                 const std::vector<comm_count_t>& recv_offsets, cudecompCommAxis comm_axis, cudaStream_t stream) {
+static void nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_desc, T* send_buff,
+                             const std::vector<comm_count_t>& send_counts,
+                             const std::vector<comm_count_t>& send_offsets, T* recv_buff,
+                             const std::vector<comm_count_t>& recv_counts,
+                             const std::vector<comm_count_t>& recv_offsets, cudecompCommAxis comm_axis, bool use_sm,
+                             cudaStream_t stream) {
   auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
   auto team = comm_info.nvshmem_team;
   int self_rank = comm_info.rank;
@@ -104,23 +105,34 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
   CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->nvshmem_sync_event));
 
   // Event dependency on external stream for intra-group transfers
-  CHECK_CUDA(cudaEventRecord(grid_desc->events[0], stream));
-  for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
-    CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
+  if (!use_sm) {
+    CHECK_CUDA(cudaEventRecord(grid_desc->events[0], stream));
+    for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
+      CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
+    }
   }
 
+  bool need_barrier = false;
+  bool need_quiet = false;
   cudecompNvshmemA2AParams<T> params;
+  cudecompNvshmemP2PParams<T> p2p_params;
+  p2p_params.send_buff = send_buff;
+  p2p_params.recv_buff = recv_buff;
+  p2p_params.block_counters = grid_desc->nvshmem_block_counters;
 
   // Inter-group transfers (non-blocking)
   params.send_buff = send_buff;
   params.recv_buff = recv_buff;
+
   int count = 0;
   for (int i = 1; i < send_counts.size(); ++i) {
     int src_rank, dst_rank;
     getAlltoallPeerRanks(grid_desc, comm_axis, i, src_rank, dst_rank);
     int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
     if (nvshmem_ptr(recv_buff, dst_rank_global)) { continue; }
 
+    if (!use_sm) need_barrier = true;
+    need_quiet = true;
     params.send_offsets[count] = send_offsets[dst_rank];
     params.recv_offsets[count] = recv_offsets[dst_rank];
     params.send_counts[count] = send_counts[dst_rank];
@@ -129,59 +141,87 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
 
     if (count == CUDECOMP_NVSHMEM_A2A_PARAM_CAPACITY) {
       params.ntransfers = count;
-      cudecomp_nvshmem_alltoallv(params, stream);
+      cudecomp_nvshmem_alltoallv(params, use_sm ? &comm_info.nvshmem_signals[0] : nullptr, stream);
       count = 0;
     }
   }
   if (count != 0) {
     params.ntransfers = count;
-    cudecomp_nvshmem_alltoallv(params, stream);
+    cudecomp_nvshmem_alltoallv(params, use_sm ? &comm_info.nvshmem_signals[0] : nullptr, stream);
   }
 
   // Intra-group transfers (blocking, scheduled after non-blocking inter-group transfers for concurrency)
   count = 0;
-  for (int i = 1; i < send_counts.size(); ++i) {
+  for (int i = (use_sm ? 0 : 1); i < send_counts.size(); ++i) {
     int src_rank, dst_rank;
     getAlltoallPeerRanks(grid_desc, comm_axis, i, src_rank, dst_rank);
     int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
     if (nvshmem_ptr(recv_buff, dst_rank_global)) {
 
-      if (comm_info.ngroups == 1 && handle->device_p2p_ce_count == 1 && count != 0 &&
-          count % CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ == 0) {
-        // For single group, single P2P CE (e.g. NVSwitch), synchronize NVSHMEM team every
-        // CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ transfers This helps reduce CE contention due to accumulation of
-        // jitter.
-        for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
-          CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
-          CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->events[0], 0));
-        }
+      if (!use_sm) {
+        need_barrier = true;
+        if (comm_info.ngroups == 1 && handle->device_p2p_ce_count == 1 && count != 0 &&
+            count % CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ == 0) {
+          // For single group, single P2P CE (e.g. NVSwitch), synchronize NVSHMEM team every
+          // CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ transfers This helps reduce CE contention due to accumulation of
+          // jitter.
+          for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
+            CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
+            CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->events[0], 0));
+          }
 
-        nvshmemx_team_sync_on_stream(team, aux_stream);
+          nvshmemx_team_sync_on_stream(team, aux_stream);
 
-        CHECK_CUDA(cudaEventRecord(grid_desc->events[0], aux_stream));
-        for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
-          CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
+          CHECK_CUDA(cudaEventRecord(grid_desc->events[0], aux_stream));
+          for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
+            CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
+          }
         }
-      }
 
-      nvshmemx_putmem_on_stream(recv_buff + recv_offsets[dst_rank], send_buff + send_offsets[dst_rank],
-                                send_counts[dst_rank] * sizeof(T), dst_rank_global,
-                                handle->streams[count % handle->device_p2p_ce_count]);
-      count++;
+        nvshmemx_putmem_on_stream(recv_buff + recv_offsets[dst_rank], send_buff + send_offsets[dst_rank],
+                                  send_counts[dst_rank] * sizeof(T), dst_rank_global,
+                                  handle->streams[count % handle->device_p2p_ce_count]);
+        count++;
+      } else {
+        p2p_params.send_offsets[count] = send_offsets[dst_rank];
+        p2p_params.recv_offsets[count] = recv_offsets[dst_rank];
+        p2p_params.send_counts[count] = send_counts[dst_rank];
+        p2p_params.peer_ranks[count] = dst_rank_global;
+        count++;
+
+        if (count == CUDECOMP_NVSHMEM_P2P_PARAM_CAPACITY) {
+          p2p_params.ntransfers = count;
+          cudecomp_nvshmem_alltoallv_p2p(handle, p2p_params, &comm_info.nvshmem_signals[0], stream);
+          count = 0;
+        }
+      }
     }
   }
 
-  // Self-copy with cudaMemcpy
-  CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
-                             send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  if (use_sm) {
+    if (count != 0) {
+      p2p_params.ntransfers = count;
+      cudecomp_nvshmem_alltoallv_p2p(handle, p2p_params, &comm_info.nvshmem_signals[0], stream);
+    }
+
+    if (need_quiet) { nvshmemx_quiet_on_stream(stream); }
+    nvshmemx_signal_wait_until_on_stream(&comm_info.nvshmem_signals[0], NVSHMEM_CMP_EQ,
+                                         static_cast<uint64_t>(comm_info.nranks), stream);
+  } else {
+    // Self-copy with cudaMemcpy
+    CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
+                               send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  }
 
-  // Event dependency on internal streams for completion of intra-group transfers
-  for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
-    CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
-    CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[0], 0));
+  // Event dependency on internal streams for completion of intra-group transfers (not needed for SM path)
+  if (!use_sm) {
+    for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
+      CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
+      CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[0], 0));
+    }
   }
 
-  nvshmemx_barrier_on_stream(team, stream);
+  if (need_barrier) { nvshmemx_barrier_on_stream(team, stream); }
 }
 #endif
 
@@ -213,11 +253,13 @@ static void cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridD
 
   std::vector<MPI_Request> reqs;
   switch (grid_desc->config.transpose_comm_backend) {
-  case CUDECOMP_TRANSPOSE_COMM_NVSHMEM: {
+  case CUDECOMP_TRANSPOSE_COMM_NVSHMEM:
+  case CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM: {
 #ifdef ENABLE_NVSHMEM
     if (nvshmem_ptr(send_buff, handle->rank) && nvshmem_ptr(recv_buff, handle->rank)) {
       nvshmemAlltoallV(handle, grid_desc, send_buff, send_counts, send_offsets, recv_buff, recv_counts,
-                       recv_offsets_nvshmem, comm_axis, stream);
+                       recv_offsets_nvshmem, comm_axis,
+                       grid_desc->config.transpose_comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM, stream);
       break;
     } else {
       THROW_INVALID_USAGE("NVSHMEM communication backends require workspace allocated via cudecompMalloc.");
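
Note: the use_sm path above replaces the CE-driven putmem/barrier scheme with signal counting: every peer (including the rank itself, via the fused self-copy) adds 1 to the receiver's signal word as its payload lands, and the receiving stream waits until the count reaches comm_info.nranks. A standalone sketch of that pattern follows; the buffer, offset, and count variables are assumed, and this is an illustration rather than the library code.

// Illustrative put-with-signal completion counting (host-side, on-stream; T assumed float).
uint64_t* sig = (uint64_t*)nvshmem_calloc(1, sizeof(uint64_t)); // symmetric signal word
for (int peer = 0; peer < nranks; ++peer) {
  // Deliver this rank's payload to peer and atomically add 1 to peer's signal word.
  nvshmemx_putmem_signal_nbi_on_stream(recv_buff + recv_offsets[peer], send_buff + send_offsets[peer],
                                       send_counts[peer] * sizeof(float), sig, 1, NVSHMEM_SIGNAL_ADD,
                                       peer, stream);
}
// Block the stream until all nranks payloads (including our own) have arrived.
nvshmemx_signal_wait_until_on_stream(sig, NVSHMEM_CMP_EQ, (uint64_t)nranks, stream);
// The signal word must be reset to 0 before it can be reused for the next alltoall.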

include/internal/common.h

Lines changed: 10 additions & 3 deletions
@@ -114,6 +114,8 @@ struct cudecompHandle {
 
   // Miscellaneous
   int32_t device_p2p_ce_count = 0;       // number of P2P CEs available
+  int32_t device_num_sms = 0;            // number of SMs on the device
+  int32_t device_max_threads_per_sm = 0; // maximum threads per SM
   bool use_col_major_rank_order = false; // Flag to control whether to use column-major rank order
 };
@@ -183,6 +185,10 @@ struct cudecompGridDesc {
   std::vector<cudaEvent_t> events{nullptr}; // CUDA events used for scheduling
   cudaEvent_t nvshmem_sync_event = nullptr; // NVSHMEM event used for synchronization
 
+#ifdef ENABLE_NVSHMEM
+  int* nvshmem_block_counters = nullptr; // device memory counters for SM alltoallv last-block detection
+#endif
+
   cudecomp::graphCache graph_cache; // CUDA graph cache
 
   cudecomp::ncclComm nccl_comm; // NCCL communicator (global), shared from handle
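
Note: the kernel added in cudecomp_kernels.cuh indexes these counters as block_counters[peer_rank] with global ranks, so the array presumably holds one int per rank and starts zeroed (the kernel resets each slot after signaling). A sketch of the corresponding allocation at grid descriptor creation; the nranks field name and location are assumptions, and the "Fix block counter sizing" commit step suggests the real sizing has subtleties not captured here.

// Assumed one-time allocation of the per-peer block counters (not shown in this diff).
CHECK_CUDA(cudaMalloc(&grid_desc->nvshmem_block_counters, handle->nranks * sizeof(int)));
CHECK_CUDA(cudaMemset(grid_desc->nvshmem_block_counters, 0, handle->nranks * sizeof(int)));
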
@@ -292,7 +298,8 @@ static inline bool haloBackendRequiresNccl(cudecompHaloCommBackend_t comm_backen
 }
 
 static inline bool transposeBackendRequiresNvshmem(cudecompTransposeCommBackend_t comm_backend) {
-  return (comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM || comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL);
+  return (comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM || comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL ||
+          comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM);
 }
 
 static inline bool haloBackendRequiresNvshmem(cudecompHaloCommBackend_t comm_backend) {
@@ -380,8 +387,8 @@ static inline void getAlltoallPeerRanks(cudecompGridDesc_t grid_desc, cudecompCo
 
   const auto& info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
 
-  // Quick return for single rank case
-  if (info.nranks == 1) {
+  // Return self for single rank case or when iter is zero
+  if (info.nranks == 1 || iter == 0) {
     src_rank = info.rank;
     dst_rank = info.rank;
     return;
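
Note: the iter == 0 change above is what allows the SM path in comm_routines.h to start its intra-group loop at i = 0, folding the self-copy into the same P2P kernel as the other intra-group transfers. A hypothetical schedule with this property is sketched below; the actual mapping used by cuDecomp may differ.

// Hypothetical ring-offset peer schedule where iteration 0 maps to self.
void peer_ranks(int rank, int nranks, int iter, int& src_rank, int& dst_rank) {
  dst_rank = (rank + iter) % nranks;          // iter == 0 -> dst_rank == rank
  src_rank = (rank - iter + nranks) % nranks; // iter == 0 -> src_rank == rank
}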

include/internal/cudecomp_kernels.cuh

Lines changed: 64 additions & 7 deletions
@@ -18,6 +18,8 @@
 #ifndef CUDECOMP_KERNELS_CUH
 #define CUDECOMP_KERNELS_CUH
 
+#include <algorithm>
+
 #ifdef ENABLE_NVSHMEM
 #include <nvshmem.h>
 #endif
@@ -28,6 +30,8 @@
 #define CUDECOMP_UNROLL_FACTOR (4)
 #define CUDECOMP_MIN_BLOCKS_PER_SM (16)
 
+#define CUDECOMP_NVSHMEM_NTHREADS (1024)
+
 namespace cudecomp {
 
 #ifdef ENABLE_NVSHMEM
@@ -47,6 +51,61 @@ __launch_bounds__(CUDECOMP_CUDA_NTHREADS) __global__
 
   nvshmem_putmem_nbi(recv_buff + recv_offset, send_buff + send_offset, send_count * sizeof(T), peer_rank);
 }
+
+template <typename T>
+__launch_bounds__(CUDECOMP_CUDA_NTHREADS) __global__
+    void cudecomp_nvshmem_alltoallv_signal_k(cudecompNvshmemA2AParams<T> params, uint64_t* sig_addr) {
+
+  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid >= params.ntransfers) return;
+
+  int peer_rank = params.peer_ranks[tid];
+  T* send_buff = params.send_buff;
+  T* recv_buff = params.recv_buff;
+  size_t send_offset = params.send_offsets[tid];
+  size_t recv_offset = params.recv_offsets[tid];
+  size_t send_count = params.send_counts[tid];
+
+  nvshmem_putmem_signal_nbi(recv_buff + recv_offset, send_buff + send_offset, send_count * sizeof(T), sig_addr, 1,
+                            NVSHMEM_SIGNAL_ADD, peer_rank);
+}
+
+template <typename T>
+__launch_bounds__(CUDECOMP_NVSHMEM_NTHREADS) __global__
+    void cudecomp_nvshmem_alltoallv_p2p_k(cudecompNvshmemP2PParams<T> params, uint64_t* sig_addr) {
+
+  T* send_buff = params.send_buff;
+  T* recv_buff = params.recv_buff;
+  int bid = blockIdx.x;
+
+  if (params.ntransfers > 0) {
+    int blocks_per_copy = gridDim.x / params.ntransfers;
+    int copyid = bid / blocks_per_copy;
+    int block_within_copy = bid % blocks_per_copy;
+    int peer_rank = params.peer_ranks[copyid];
+    size_t send_offset = params.send_offsets[copyid];
+    size_t recv_offset = params.recv_offsets[copyid];
+    size_t send_count = params.send_counts[copyid];
+
+    size_t nelems_per_block = (send_count + blocks_per_copy - 1) / blocks_per_copy;
+    size_t block_offset = (size_t)block_within_copy * nelems_per_block;
+    if (block_offset < send_count) {
+      size_t block_count = min(nelems_per_block, send_count - block_offset);
+      nvshmemx_putmem_block(recv_buff + recv_offset + block_offset, send_buff + send_offset + block_offset,
+                            block_count * sizeof(T), peer_rank);
+    }
+
+    // Last block to finish this copy signals the destination PE.
+    nvshmem_fence();
+    __syncthreads();
+    if (threadIdx.x == 0) {
+      if (atomicAdd(&params.block_counters[peer_rank], 1) + 1 == blocks_per_copy) {
+        params.block_counters[peer_rank] = 0;
+        nvshmemx_signal_op(sig_addr, 1, NVSHMEM_SIGNAL_ADD, peer_rank);
+      }
+    }
+  }
+}
 #endif
 
 template <int src_nd, int dest_nd, typename T>
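
Note: cudecomp_nvshmem_alltoallv_p2p_k derives blocks_per_copy as gridDim.x / params.ntransfers, so its launcher (cudecomp_nvshmem_alltoallv_p2p, defined elsewhere in this commit) must pick a grid size that is a multiple of ntransfers. A plausible launch-configuration sketch using the device attributes cached on the handle is shown below; the actual launch logic is an assumption, and ntransfers >= 1 is assumed.

// Plausible launcher sketch (assumed logic; uses std::max from <algorithm>, included above).
template <typename T>
void launch_alltoallv_p2p(const cudecompHandle_t& handle, cudecompNvshmemP2PParams<T>& params,
                          uint64_t* sig_addr, cudaStream_t stream) {
  // Resident blocks at 1024 threads each, across all SMs, for full occupancy.
  int max_blocks = handle->device_num_sms * (handle->device_max_threads_per_sm / CUDECOMP_NVSHMEM_NTHREADS);
  // Round down to a multiple of ntransfers so each copy gets an equal block share.
  int blocks_per_copy = std::max(1, max_blocks / params.ntransfers);
  cudecomp_nvshmem_alltoallv_p2p_k<T><<<blocks_per_copy * params.ntransfers, CUDECOMP_NVSHMEM_NTHREADS, 0, stream>>>(
      params, sig_addr);
}
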
@@ -107,7 +166,8 @@ __launch_bounds__(CUDECOMP_CUDA_NTHREADS) __global__
 }
 
 template <typename T>
-void cudecomp_batched_d2d_memcpy_3d_nd_dispatch(const cudecompBatchedD2DMemcpy3DParams<T>& params,
+void cudecomp_batched_d2d_memcpy_3d_nd_dispatch(cudecompHandle_t handle,
+                                                const cudecompBatchedD2DMemcpy3DParams<T>& params,
                                                 cudaStream_t stream) {
   size_t N = params.extents[0][0] * params.extents[1][0] * params.extents[2][0];
 
@@ -138,12 +198,9 @@ void cudecomp_batched_d2d_memcpy_3d_nd_dispatch(const cudecompBatchedD2DMemcpy3D
   int blocks_per_copy_unroll = (blocks_per_copy + CUDECOMP_UNROLL_FACTOR - 1) / CUDECOMP_UNROLL_FACTOR;
   size_t total_blocks_unroll = params.ncopies * blocks_per_copy_unroll;
 
-  // Clamp minimum number of blocks from unrolling
-  int dev, num_sms;
-  CHECK_CUDA(cudaGetDevice(&dev));
-  CHECK_CUDA(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev));
-
-  if (total_blocks_unroll > CUDECOMP_MIN_BLOCKS_PER_SM * num_sms) { blocks_per_copy = blocks_per_copy_unroll; }
+  if (total_blocks_unroll > CUDECOMP_MIN_BLOCKS_PER_SM * handle->device_num_sms) {
+    blocks_per_copy = blocks_per_copy_unroll;
+  }
 
   switch (src_nd) {
   case 1:
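
Note: this dispatch change relies on the "Cache device attributes" step from the commit message: device_num_sms and device_max_threads_per_sm are presumably populated once at handle creation, along the lines of the sketch below (the initialization code itself is not part of this excerpt).

// One-time device attribute caching at handle creation (assumed location).
int dev, num_sms, max_threads_per_sm;
CHECK_CUDA(cudaGetDevice(&dev));
CHECK_CUDA(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev));
CHECK_CUDA(cudaDeviceGetAttribute(&max_threads_per_sm, cudaDevAttrMaxThreadsPerMultiProcessor, dev));
handle->device_num_sms = num_sms;
handle->device_max_threads_per_sm = max_threads_per_sm;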
