@@ -95,23 +95,26 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
                  const std::vector<comm_count_t>& send_counts, const std::vector<comm_count_t>& send_offsets,
                  T* recv_buff, const std::vector<comm_count_t>& recv_counts,
                  const std::vector<comm_count_t>& recv_offsets, cudecompCommAxis comm_axis, cudaStream_t stream) {
-  auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
+  auto comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
+  auto comm = comm_info.mpi_comm;
   auto team = comm_info.nvshmem_team;
   int self_rank = comm_info.rank;
-  auto aux_stream = handle->streams[handle->device_p2p_ce_count];
-
-  // Enforce sync dependency between transpose operations
-  CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->nvshmem_sync_event));
 
   // Event dependency on external stream for intra-group transfers
   CHECK_CUDA(cudaEventRecord(grid_desc->events[0], stream));
   for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
     CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
   }
 
+  // Using cudaEventSynchronize + barrier instead of nvshmemx_team_sync_on_stream for lower latency
+  CHECK_CUDA(cudaEventSynchronize(grid_desc->nvshmem_sync_event));
+  CHECK_MPI(MPI_Barrier(comm));
+  // nvshmemx_team_sync_on_stream(team, stream);
+
   cudecompNvshmemA2AParams<T> params;
 
   // Inter-group transfers (non-blocking)
+  bool need_quiet = false;
   params.send_buff = send_buff;
   params.recv_buff = recv_buff;
   int count = 0;
@@ -131,11 +134,13 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
       params.ntransfers = count;
       cudecomp_nvshmem_alltoallv(params, stream);
       count = 0;
+      need_quiet = true;
     }
   }
   if (count != 0) {
     params.ntransfers = count;
     cudecomp_nvshmem_alltoallv(params, stream);
+    need_quiet = true;
   }
 
   // Intra-group transfers (blocking, scheduled after non-blocking inter-group transfers for concurrency)
@@ -146,19 +151,19 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
     int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
     if (nvshmem_ptr(recv_buff, dst_rank_global)) {
 
-      if (comm_info.ngroups == 1 && handle->device_p2p_ce_count == 1 && count != 0 &&
+      if (comm_info.ngroups == 1 && handle->device_p2p_ce_count == 1 &&
           count % CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ == 0) {
         // For single group, single P2P CE (e.g. NVSwitch), synchronize NVSHMEM team every
         // CUDECOMP_NVSHMEM_INTRAGROUP_SYNC_FREQ transfers. This helps reduce CE contention due to accumulation of
         // jitter.
         for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
           CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[i]));
-          CHECK_CUDA(cudaStreamWaitEvent(aux_stream, grid_desc->events[0], 0));
+          CHECK_CUDA(cudaStreamWaitEvent(handle->streams[handle->device_p2p_ce_count], grid_desc->events[0], 0));
         }
 
-        nvshmemx_team_sync_on_stream(team, aux_stream);
+        nvshmemx_team_sync_on_stream(team, handle->streams[handle->device_p2p_ce_count]);
 
-        CHECK_CUDA(cudaEventRecord(grid_desc->events[0], aux_stream));
+        CHECK_CUDA(cudaEventRecord(grid_desc->events[0], handle->streams[handle->device_p2p_ce_count]));
         for (int i = 0; i < handle->device_p2p_ce_count; ++i) {
           CHECK_CUDA(cudaStreamWaitEvent(handle->streams[i], grid_desc->events[0], 0));
         }
@@ -181,7 +186,12 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
     CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[0], 0));
   }
 
-  nvshmemx_barrier_on_stream(team, stream);
+  if (need_quiet) { nvshmemx_quiet_on_stream(stream); }
+
+  // Using cudaStreamSynchronize + barrier instead of nvshmemx_team_sync_on_stream for lower latency
+  CHECK_CUDA(cudaStreamSynchronize(stream));
+  CHECK_MPI(MPI_Barrier(comm));
+  // nvshmemx_team_sync_on_stream(team, stream);
 }
 #endif
 
@@ -227,7 +237,7 @@ static void cudecompAlltoall(const cudecompHandle_t& handle, const cudecompGridD
 #endif
   }
   case CUDECOMP_TRANSPOSE_COMM_NCCL: {
-    auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
+    auto comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
     // For fully intra-group alltoall, use distinct NCCL local comm instead of global comm as it is faster.
     auto comm = (comm_info.ngroups == 1) ? *grid_desc->nccl_local_comm : *grid_desc->nccl_comm;
 
@@ -357,7 +367,7 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc
                           const std::vector<comm_count_t>& recv_offsets,
                           const std::vector<comm_count_t>& recv_offsets_nvshmem, cudecompCommAxis comm_axis,
                           const std::vector<int>& src_ranks, const std::vector<int>& dst_ranks, cudaStream_t stream,
-                          cudecompTransposePerformanceSample* current_sample = nullptr) {
+                          bool& synced, cudecompTransposePerformanceSample* current_sample = nullptr) {
 
   // If there are no transfers to complete, quick return
   if (send_counts.size() == 0 && recv_counts.size() == 0) { return; }
@@ -394,17 +404,14 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc
   case CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL: {
 #ifdef ENABLE_NVSHMEM
     if (nvshmem_ptr(send_buff, handle->rank) && nvshmem_ptr(recv_buff, handle->rank)) {
-      auto& comm_info = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info : grid_desc->col_comm_info;
+      auto comm =
+          (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.mpi_comm : grid_desc->col_comm_info.mpi_comm;
+      // auto team = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.nvshmem_team
+      //                                              : grid_desc->col_comm_info.nvshmem_team;
       auto pl_stream = handle->streams[0];
-      auto aux_stream = handle->streams[handle->device_p2p_ce_count];
       int self_rank = (comm_axis == CUDECOMP_COMM_ROW) ? grid_desc->row_comm_info.rank : grid_desc->col_comm_info.rank;
 
-      // Enforce sync dependency between transpose operations
-      CHECK_CUDA(cudaStreamWaitEvent(pl_stream, grid_desc->nvshmem_sync_event));
-
-      bool need_quiet = false;
-
-      // Inter-group transfers and self-copy (non-blocking)
+      bool barrier = false;
       for (int i = 0; i < src_ranks.size(); ++i) {
         int src_rank = src_ranks[i];
         int dst_rank = dst_ranks[i];
@@ -414,44 +421,39 @@ cudecompAlltoallPipelined(const cudecompHandle_t& handle, const cudecompGridDesc
           CHECK_CUDA(cudaMemcpyAsync(recv_buff + recv_offsets_nvshmem[self_rank], send_buff + send_offsets[self_rank],
                                      send_counts[self_rank] * sizeof(T), cudaMemcpyDeviceToDevice, stream));
         } else {
-          int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
-          if (nvshmem_ptr(recv_buff, dst_rank_global)) { continue; }
-
           CHECK_CUDA(cudaStreamWaitEvent(pl_stream, grid_desc->events[dst_rank], 0));
+          if (!synced) {
+            // Using cudaEventSynchronize + barrier instead of nvshmemx_team_sync_on_stream for lower latency
+            CHECK_CUDA(cudaEventSynchronize(grid_desc->nvshmem_sync_event));
+            CHECK_MPI(MPI_Barrier(comm));
+            // Only need to sync on the first remote operation of an alltoall sequence to ensure reads on other ranks
+            // from previous communication have completed.
+            synced = true;
+          }
 
-          nvshmemx_putmem_signal_nbi_on_stream(recv_buff + recv_offsets_nvshmem[dst_rank],
-                                               send_buff + send_offsets[dst_rank], send_counts[dst_rank] * sizeof(T),
-                                               &comm_info.nvshmem_signals[comm_info.rank], 1, NVSHMEM_SIGNAL_SET,
-                                               dst_rank_global, pl_stream);
+          int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
+          nvshmemx_putmem_nbi_on_stream(recv_buff + recv_offsets_nvshmem[dst_rank], send_buff + send_offsets[dst_rank],
+                                        send_counts[dst_rank] * sizeof(T), dst_rank_global, pl_stream);
 
-          need_quiet = true;
+          barrier = true;
         }
       }
 
-      // Intra-group transfers (blocking, scheduled after non-blocking inter-group transfers for concurrency)
-      for (int i = 0; i < src_ranks.size(); ++i) {
-        int src_rank = src_ranks[i];
-        int dst_rank = dst_ranks[i];
-
-        int dst_rank_global = getGlobalRank(handle, grid_desc, comm_axis, dst_rank);
-        if (!nvshmem_ptr(recv_buff, dst_rank_global) || src_rank == self_rank) { continue; }
-
-        CHECK_CUDA(cudaStreamWaitEvent(pl_stream, grid_desc->events[dst_rank], 0));
-
-        nvshmemx_putmem_signal_on_stream(recv_buff + recv_offsets_nvshmem[dst_rank], send_buff + send_offsets[dst_rank],
-                                         send_counts[dst_rank] * sizeof(T), &comm_info.nvshmem_signals[comm_info.rank],
-                                         1, NVSHMEM_SIGNAL_SET, dst_rank_global, pl_stream);
-      }
-
-      if (need_quiet) { nvshmemx_quiet_on_stream(pl_stream); }
-      for (int i = 0; i < src_ranks.size(); ++i) {
-        int src_rank = src_ranks[i];
-        int dst_rank = dst_ranks[i];
-        if (src_rank != self_rank) {
-          nvshmemx_signal_wait_until_on_stream(&comm_info.nvshmem_signals[src_rank], NVSHMEM_CMP_EQ, 1, pl_stream);
-          CHECK_CUDA(cudaEventRecord(grid_desc->events[dst_rank], pl_stream));
-          CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[dst_rank], 0));
-        }
+      if (barrier) {
+        nvshmemx_quiet_on_stream(pl_stream);
+        // Using cudaStreamSynchronize + barrier instead of nvshmemx_team_sync_on_stream for lower latency
+        CHECK_CUDA(cudaStreamSynchronize(pl_stream));
+        CHECK_MPI(MPI_Barrier(comm));
+
+        // nvshmemx_team_sync_on_stream(team, pl_stream);
+        // for (int i = 0; i < src_ranks.size(); ++i) {
+        //   int src_rank = src_ranks[i];
+        //   int dst_rank = dst_ranks[i];
+        //   if (src_rank != self_rank) {
+        //     CHECK_CUDA(cudaEventRecord(grid_desc->events[dst_rank], pl_stream));
+        //     CHECK_CUDA(cudaStreamWaitEvent(stream, grid_desc->events[dst_rank], 0));
+        //   }
+        // }
       }
       break;
     } else {
@@ -592,7 +594,8 @@ static void cudecompSendRecvPair(const cudecompHandle_t& handle, const cudecompG
   case CUDECOMP_HALO_COMM_NVSHMEM_BLOCKING: {
 #ifdef ENABLE_NVSHMEM
     if (nvshmem_ptr(send_buff, handle->rank) && nvshmem_ptr(recv_buff, handle->rank)) {
-      nvshmemx_barrier_all_on_stream(stream);
+      nvshmemx_quiet_on_stream(stream);
+      nvshmemx_sync_all_on_stream(stream);
       for (int i = 0; i < send_counts.size(); ++i) {
         if (peer_ranks[i] == handle->rank) {
           // Self-copy with cudaMemcpy
@@ -605,12 +608,14 @@ static void cudecompSendRecvPair(const cudecompHandle_t& handle, const cudecompG
         }
       }
       if (grid_desc->config.halo_comm_backend == CUDECOMP_HALO_COMM_NVSHMEM_BLOCKING) {
-        nvshmemx_barrier_all_on_stream(stream);
+        nvshmemx_quiet_on_stream(stream);
+        nvshmemx_sync_all_on_stream(stream);
       }
     }
 
     if (grid_desc->config.halo_comm_backend == CUDECOMP_HALO_COMM_NVSHMEM) {
-      nvshmemx_barrier_all_on_stream(stream);
+      nvshmemx_quiet_on_stream(stream);
+      nvshmemx_sync_all_on_stream(stream);
     };
     break;
   } else {
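
For readers following the synchronization change above, here is a minimal standalone sketch of the pattern the patch adopts: complete non-blocking NVSHMEM puts with nvshmemx_quiet_on_stream, then stand in for the stream-ordered team barrier with a host-side cudaStreamSynchronize plus MPI_Barrier. This is an illustration only; the CHECK_CUDA/CHECK_MPI macros, buffer names, and the single-neighbor put are assumptions for the sketch, not cuDecomp's actual code.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <mpi.h>
#include <nvshmem.h>
#include <nvshmemx.h>

// Simple error-checking macros (assumed for this sketch, not the library's own).
#define CHECK_CUDA(call)                                                                             \
  do {                                                                                               \
    cudaError_t err_ = (call);                                                                       \
    if (err_ != cudaSuccess) {                                                                       \
      fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(err_), __FILE__, __LINE__);     \
      exit(EXIT_FAILURE);                                                                            \
    }                                                                                                \
  } while (0)

#define CHECK_MPI(call)                                                                              \
  do {                                                                                               \
    int err_ = (call);                                                                               \
    if (err_ != MPI_SUCCESS) {                                                                       \
      fprintf(stderr, "MPI error %d at %s:%d\n", err_, __FILE__, __LINE__);                          \
      exit(EXIT_FAILURE);                                                                            \
    }                                                                                                \
  } while (0)

int main(int argc, char** argv) {
  CHECK_MPI(MPI_Init(&argc, &argv));

  // Bootstrap NVSHMEM on top of MPI_COMM_WORLD.
  MPI_Comm comm = MPI_COMM_WORLD;
  nvshmemx_init_attr_t attr;
  attr.mpi_comm = &comm;
  nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr);

  // One GPU per PE on the node.
  CHECK_CUDA(cudaSetDevice(nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE)));

  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreate(&stream));

  const size_t count = 1 << 20;
  float* send_buf = static_cast<float*>(nvshmem_malloc(count * sizeof(float)));
  float* recv_buf = static_cast<float*>(nvshmem_malloc(count * sizeof(float)));

  // Non-blocking put to the next PE, issued on the stream.
  int peer = (nvshmem_my_pe() + 1) % nvshmem_n_pes();
  nvshmemx_putmem_nbi_on_stream(recv_buf, send_buf, count * sizeof(float), peer, stream);

  // Completion pattern from the patch:
  //   1. quiet completes the locally issued puts in stream order,
  //   2. a host-side stream sync plus MPI barrier stands in for
  //      nvshmemx_barrier_on_stream / nvshmemx_team_sync_on_stream.
  nvshmemx_quiet_on_stream(stream);
  CHECK_CUDA(cudaStreamSynchronize(stream));
  CHECK_MPI(MPI_Barrier(comm));
  // After the barrier, every PE's incoming data in recv_buf is complete.

  nvshmem_free(send_buf);
  nvshmem_free(recv_buf);
  CHECK_CUDA(cudaStreamDestroy(stream));
  nvshmem_finalize();
  CHECK_MPI(MPI_Finalize());
  return 0;
}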