@@ -90,11 +90,12 @@ static inline void checkMpiInt32Limit(int64_t val, cudecompHaloCommBackend_t bac
 #ifdef ENABLE_NVSHMEM
 #define CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ 8 // max number of intra-group transfers to schedule between team syncs
 template <typename T>
-static void
-nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_desc, T* send_buff,
-                 const std::vector<comm_count_t>& send_counts, const std::vector<comm_count_t>& send_offsets,
-                 T* recv_buff, const std::vector<comm_count_t>& recv_counts,
-                 const std::vector<comm_count_t>& recv_offsets, cudecompCommAxis comm_axis, cudaStream_t stream) {
+static void nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_desc, T* send_buff,
+                             const std::vector<comm_count_t>& send_counts,
+                             const std::vector<comm_count_t>& send_offsets, T* recv_buff,
+                             const std::vector<comm_count_t>& recv_counts,
+                             const std::vector<comm_count_t>& recv_offsets, cudecompCommAxis comm_axis, bool use_sm,
+                             cudaStream_t stream) {
   auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
   auto team = comm_info.nvshmem_team;
   int self_rank = comm_info.rank;
@@ -104,23 +105,34 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
   CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->nvshmem_sync_event));
 
   // Event dependency on external stream for intra-group transfers
-  CHECK_CUDA(cudaEventRecord(grid_desc->events[0], stream));
-  for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
-    CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
+  if (!use_sm) {
+    CHECK_CUDA(cudaEventRecord(grid_desc->events[0], stream));
+    for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
+      CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
+    }
   }
 
+  bool need_barrier = false;
+  bool need_quiet = false;
   cudecompNvshmemA2AParams<T> params;
+  cudecompNvshmemP2PParams<T> p2p_params;
+  p2p_params.send_buff = send_buff;
+  p2p_params.recv_buff = recv_buff;
+  p2p_params.block_counters = grid_desc->nvshmem_block_counters;
 
   // Inter-group transfers (non-blocking)
   params.send_buff = send_buff;
   params.recv_buff = recv_buff;
+
   int count = 0;
   for (int i = 1; i < send_counts.size(); ++i) {
     int src_rank, dst_rank;
     getAlltoallPeerRanks(grid_desc, comm_axis, i, src_rank, dst_rank);
     int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
     if (nvshmem_ptr(recv_buff, dst_rank_global)) { continue; }
 
+    if (!use_sm) need_barrier = true;
+    need_quiet = true;
     params.send_offsets[count] = send_offsets[dst_rank];
     params.recv_offsets[count] = recv_offsets[dst_rank];
     params.send_counts[count] = send_counts[dst_rank];
@@ -129,59 +141,87 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
 
     if (count == CUDECOMP_NVSHMEM_A2A_PARAM_CAPACITY) {
       params.ntransfers = count;
-      cudecomp_nvshmem_alltoallv(params, stream);
+      cudecomp_nvshmem_alltoallv(params, use_sm ? &comm_info.nvshmem_signals[0] : nullptr, stream);
       count = 0;
     }
   }
   if (count != 0) {
     params.ntransfers = count;
-    cudecomp_nvshmem_alltoallv(params, stream);
+    cudecomp_nvshmem_alltoallv(params, use_sm ? &comm_info.nvshmem_signals[0] : nullptr, stream);
   }
 
   // Intra-group transfers (blocking, scheduled after non-blocking inter-group transfers for concurrency)
   count = 0;
-  for (int i = 1; i < send_counts.size(); ++i) {
+  for (int i = (use_sm ? 0 : 1); i < send_counts.size(); ++i) {
     int src_rank, dst_rank;
     getAlltoallPeerRanks(grid_desc, comm_axis, i, src_rank, dst_rank);
     int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
     if (nvshmem_ptr(recv_buff, dst_rank_global)) {
 
-      if (comm_info.ngroups == 1 && handle->device_p2p_ce_count == 1 && count != 0 &&
-          count % CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ == 0) {
-        // For single group, single P2P CE (e.g. NVSwitch), synchronize NVSHMEM team every
-        // CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ transfers. This helps reduce CE contention due to accumulation of
-        // jitter.
-        for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
-          CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
-          CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->events[0], 0));
-        }
+      if (!use_sm) {
+        need_barrier = true;
+        if (comm_info.ngroups == 1 && handle->device_p2p_ce_count == 1 && count != 0 &&
+            count % CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ == 0) {
+          // For single group, single P2P CE (e.g. NVSwitch), synchronize NVSHMEM team every
+          // CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ transfers. This helps reduce CE contention due to accumulation of
+          // jitter.
+          for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
+            CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
+            CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->events[0], 0));
+          }
 
-        nvshmemx_team_sync_on_stream(team, aux_stream);
+          nvshmemx_team_sync_on_stream(team, aux_stream);
 
-        CHECK_CUDA(cudaEventRecord(grid_desc->events[0], aux_stream));
-        for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
-          CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
+          CHECK_CUDA(cudaEventRecord(grid_desc->events[0], aux_stream));
+          for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
+            CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
+          }
         }
-      }
 
-      nvshmemx_putmem_on_stream(recv_buff + recv_offsets[dst_rank], send_buff + send_offsets[dst_rank],
-                                send_counts[dst_rank] * sizeof(T), dst_rank_global,
-                                handle->streams[count % handle->device_p2p_ce_count]);
-      count++;
+        nvshmemx_putmem_on_stream(recv_buff + recv_offsets[dst_rank], send_buff + send_offsets[dst_rank],
+                                  send_counts[dst_rank] * sizeof(T), dst_rank_global,
+                                  handle->streams[count % handle->device_p2p_ce_count]);
+        count++;
+      } else {
+        p2p_params.send_offsets[count] = send_offsets[dst_rank];
+        p2p_params.recv_offsets[count] = recv_offsets[dst_rank];
+        p2p_params.send_counts[count] = send_counts[dst_rank];
+        p2p_params.peer_ranks[count] = dst_rank_global;
+        count++;
+
+        if (count == CUDECOMP_NVSHMEM_P2P_PARAM_CAPACITY) {
+          p2p_params.ntransfers = count;
+          cudecomp_nvshmem_alltoallv_p2p(handle, p2p_params, &comm_info.nvshmem_signals[0], stream);
+          count = 0;
+        }
+      }
     }
   }
 
-  // Self-copy with cudaMemcpy
-  CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
-                             send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  if (use_sm) {
+    if (count != 0) {
+      p2p_params.ntransfers = count;
+      cudecomp_nvshmem_alltoallv_p2p(handle, p2p_params, &comm_info.nvshmem_signals[0], stream);
+    }
+
+    if (need_quiet) { nvshmemx_quiet_on_stream(stream); }
+    nvshmemx_signal_wait_until_on_stream(&comm_info.nvshmem_signals[0], NVSHMEM_CMP_EQ,
+                                         static_cast<uint64_t>(comm_info.nranks), stream);
+  } else {
+    // Self-copy with cudaMemcpy
+    CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets[self_rank], send_buff + send_offsets[self_rank],
+                               send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+  }
 
-  // Event dependency on internal streams for completion of intra-group transfers
-  for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
-    CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
-    CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[0], 0));
+  // Event dependency on internal streams for completion of intra-group transfers (not needed for SM path)
+  if (!use_sm) {
+    for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
+      CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
+      CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[0], 0));
+    }
   }
 
-  nvshmemx_barrier_on_stream(team, stream);
+  if (need_barrier) { nvshmemx_barrier_on_stream(team, stream); }
 }
 #endif
 
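For reference, the completion handshake the `use_sm` path builds on can be written with stock NVSHMEM host APIs. The sketch below is not the commit's code: the helper name `sm_style_exchange` and the flat one-block-per-peer layout are illustrative only, and `sig` stands in for an entry of `comm_info.nvshmem_signals` (assumed reset to 0 before each exchange). Each rank deposits its block into the peer's receive slot and atomically bumps the peer's signal; the receiving stream then waits until all peers have signaled.

#include <nvshmem.h>
#include <nvshmemx.h>

// Hypothetical helper, not library code. send/recv are symmetric buffers
// partitioned as one block of "count" elements per peer; sig is a symmetric
// uint64_t that the caller has reset to 0 before this call.
template <typename T>
void sm_style_exchange(const T* send, T* recv, size_t count, uint64_t* sig,
                       int nranks, int self, cudaStream_t stream) {
  for (int p = 0; p < nranks; ++p) {
    // Put our block into peer p's slot for "self" and add 1 to its signal.
    nvshmemx_putmem_signal_nbi_on_stream(recv + self * count, send + p * count,
                                         count * sizeof(T), sig, 1,
                                         NVSHMEM_SIGNAL_ADD, p, stream);
  }
  nvshmemx_quiet_on_stream(stream); // complete the nbi puts before send is reused
  // Stream blocks until every peer (self included) has signaled arrival.
  nvshmemx_signal_wait_until_on_stream(sig, NVSHMEM_CMP_EQ,
                                       static_cast<uint64_t>(nranks), stream);
}

The commit batches the same idea through `cudecompNvshmemP2PParams` so one kernel launch can cover many intra-group peers, and guards the `nvshmemx_quiet_on_stream` behind `need_quiet` so it is only paid when inter-group traffic was actually issued.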
@@ -213,11 +253,13 @@ static void cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridD
 
   std::vector<MPI_Request> reqs;
   switch (grid_desc->config.transpose_comm_backend) {
-  case CUDECOMP_TRANSPOSE_COMM_NVSHMEM: {
+  case CUDECOMP_TRANSPOSE_COMM_NVSHMEM:
+  case CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM: {
 #ifdef ENABLE_NVSHMEM
     if (nvshmem_ptr(send_buff, handle->rank) && nvshmem_ptr(recv_buff, handle->rank)) {
       nvshmemAlltoallV(handle, grid_desc, send_buff, send_counts, send_offsets, recv_buff, recv_counts,
-                       recv_offsets_nvshmem, comm_axis, stream);
+                       recv_offsets_nvshmem, comm_axis,
+                       grid_desc->config.transpose_comm_backend == CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM, stream);
       break;
     } else {
       THROW_INVALID_USAGE("NVSHMEM communication backends require workspace allocated via cudecompMalloc.");
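A hedged usage sketch of selecting the new backend, assuming `CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM` is the public enum value this commit adds; the other calls follow the published cuDecomp API, and `work_bytes` is a placeholder the caller must compute (e.g. from the workspace-size queries).

#include <cudecomp.h>
#include <mpi.h>

// ... inside an MPI program, after MPI_Init ...
cudecompHandle_t handle;
cudecompInit(&handle, MPI_COMM_WORLD);

cudecompGridDescConfig_t config;
cudecompGridDescConfigSetDefaults(&config);
config.transpose_comm_backend = CUDECOMP_TRANSPOSE_COMM_NVSHMEM_SM; // new SM-based backend
// ... set config.gdims / config.pdims as usual ...

cudecompGridDesc_t grid_desc;
cudecompGridDescCreate(handle, &grid_desc, &config, /*autotune_opts=*/nullptr);

// Workspace must come from cudecompMalloc so nvshmem_ptr() can resolve it;
// otherwise cudecompAlltoall hits the THROW_INVALID_USAGE branch above.
double* work;
cudecompMalloc(handle, grid_desc, reinterpret_cast<void**>(&work), work_bytes);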