#include <algorithm>
#include <array>
#include <atomic>
#include <cstdint>
#include <functional>
#include <mutex>
#include <ostream>
@@ -39,8 +40,36 @@ static constexpr int max_compile_time_stream_priorities = 4;
3940
4041namespace detail {
4142
// Pool geometry: 2^5 = 32 streams per (device, priority) pool. Keeping the
// size a power of two makes the round-robin modulo in get_idx() cheap.
constexpr int kStreamsPerPoolBits = 5;
constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;  // 32
constexpr int kMaxDevices = 64;
// Pooled streams do not implicitly synchronize with the legacy default stream.
constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
47+
// One std::once_flag per device, guarding one-time creation of that device's
// stream pools (see initDeviceStreamState). Function-local static so this
// header-only implementation has a single instance across translation units.
inline std::array<std::once_flag, kMaxDevices>& device_flags() {
  static std::array<std::once_flag, kMaxDevices> flags;
  return flags;
}
53+
// Per-(priority, device) round-robin tickets, consumed by get_idx() to pick
// the next pooled stream. Zero-initialized static storage; header-only safe.
inline std::array<std::array<std::atomic<uint32_t>, kMaxDevices>,
                  max_compile_time_stream_priorities>&
priority_counters() {
  static std::array<std::array<std::atomic<uint32_t>, kMaxDevices>,
                    max_compile_time_stream_priorities>
      counters;
  return counters;
}
62+
// The pooled cudaStream_t handles, indexed as [priority_idx][device][slot].
// Entries are populated once by initDeviceStreamState(); reads afterwards are
// lock-free because the arrays are never resized or reassigned.
inline std::array<
    std::array<std::array<cudaStream_t, kStreamsPerPool>, kMaxDevices>,
    max_compile_time_stream_priorities>&
streams() {
  static std::array<
      std::array<std::array<cudaStream_t, kStreamsPerPool>, kMaxDevices>,
      max_compile_time_stream_priorities>
      stream_arrays;
  return stream_arrays;
}
4473
4574inline int gpu_device_count () {
4675 static const int count = phi::backends::gpu::GetGPUDeviceCount ();
@@ -59,32 +88,31 @@ inline void check_device_index(int device_index) {
5988 " )" );
6089}
6190
// Create the pooled stream streams()[priority_idx][device_index][stream_idx]
// on `device_index` (guarded device switch, restored on return).
//
// The CUDA priority passed is -(priority_idx): priority_idx 0 maps to CUDA
// priority 0 (the default), and LARGER indices map to MORE NEGATIVE values,
// i.e. higher scheduling priority. (The previous comment claimed idx 0 was
// the most negative/highest priority, which inverted the actual mapping.)
// cudaStreamCreateWithPriority clamps out-of-range priorities to the
// device's supported range, so indices beyond the hardware's levels are safe.
inline void initSingleStream(int priority_idx,
                             int device_index,
                             int stream_idx) {
  phi::backends::gpu::GPUDeviceGuard guard(device_index);
  auto& stream = streams()[priority_idx][device_index][stream_idx];
  int pri = -(priority_idx);
  C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
}
75101
// Build every pooled stream for `device_index`: kStreamsPerPool streams at
// each of the max_compile_time_stream_priorities levels. Must run exactly
// once per device; callers guard it with the matching device_flags() flag.
inline void initDeviceStreamState(int device_index) {
  for (int p = 0; p < max_compile_time_stream_priorities; ++p) {
    for (int i = 0; i < kStreamsPerPool; ++i) {
      initSingleStream(p, device_index, i);
    }
  }
}
87110
// Round-robin slot selection: atomically take a ticket from *counter and
// wrap it into [0, kStreamsPerPool). uint32_t overflow is well-defined, so
// the sequence stays valid across wrap-around.
inline uint32_t get_idx(std::atomic<uint32_t>* counter) {
  const uint32_t ticket = counter->fetch_add(1);
  return ticket % kStreamsPerPool;
}
115+
88116struct TLSStreamState {
89117 cudaStream_t streams[kMaxDevices ]{};
90118 bool has_stream[kMaxDevices ]{};
@@ -95,6 +123,25 @@ inline TLSStreamState& get_tls() {
95123 return s;
96124}
97125
// One-time global initialization: probe the current device's stream priority
// range. The pool itself always uses the compile-time
// max_compile_time_stream_priorities levels, and CUDA clamps per-stream
// priorities to the supported range, so the probed values are not stored.
// NOTE(review): [leastPriority, greatestPriority] is queried and discarded —
// confirm whether a runtime priority-level count should be derived from it.
inline void initGlobalStreamState() {
  int leastPriority = 0, greatestPriority = 0;
  C10_CUDA_CHECK(
      cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
}
135+
// Idempotent entry-point initializer, called by every public stream API.
// Runs initGlobalStreamState() exactly once process-wide, then touches this
// thread's TLS stream table so the thread_local is constructed eagerly.
// Fix: the previous version bound `auto& tls = get_tls();` and never used
// it (dead local / -Wunused-variable); keep the call for its construction
// side effect but drop the binding.
inline void initCUDAStreamsOnce() {
  static std::once_flag init_flag;
  std::call_once(init_flag, initGlobalStreamState);

  // TLS current-stream entries stay lazy; constructing the table suffices.
  (void)get_tls();
}
144+
98145} // namespace detail
99146
100147class CUDAStream {
@@ -191,6 +238,7 @@ inline CUDAStream make_cuda_stream(cudaStream_t raw,
191238}
192239
193240inline CUDAStream getCurrentCUDAStream (c10::DeviceIndex device_index = -1 ) {
241+ detail::initCUDAStreamsOnce ();
194242 if (device_index == -1 ) {
195243 device_index = phi::backends::gpu::GetCurrentDeviceId ();
196244 }
@@ -210,37 +258,41 @@ inline CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index = -1) {
210258
inline CUDAStream getStreamFromPool(const int priority,
                                    c10::DeviceIndex device_index = -1) {
  detail::initCUDAStreamsOnce();
  if (device_index == -1) {
    device_index = phi::backends::gpu::GetCurrentDeviceId();
  }
  detail::check_device_index(device_index);

  // Create this device's stream pools exactly once across all threads.
  std::call_once(detail::device_flags()[device_index],
                 detail::initDeviceStreamState,
                 device_index);

  // More negative `priority` = higher scheduling priority. Map onto a pool
  // index as PyTorch does: pri_idx = clamp(-priority, 0, max - 1), so 0 (and
  // any positive value) selects the default pool and -1..-(max-1) select
  // increasingly urgent pools.
  const int pri_idx =
      std::clamp(-priority, 0, max_compile_time_stream_priorities - 1);

  // Round-robin within the chosen pool so repeated requests spread out.
  const auto idx =
      detail::get_idx(&detail::priority_counters()[pri_idx][device_index]);
  cudaStream_t raw = detail::streams()[pri_idx][device_index][idx];
  return make_cuda_stream(raw, device_index);
}
234282
/**
 * Get a new stream from the CUDA stream pool.
 *
 * Bool-based convenience overload matching PyTorch's entry point:
 * `true` requests the high-priority pool (priority -1), `false` the
 * default-priority pool (priority 0), so getStreamFromPool(true) and
 * getStreamFromPool(-1) behave consistently.
 */
inline CUDAStream getStreamFromPool(const bool isHighPriority = false,
                                    c10::DeviceIndex device_index = -1) {
  const int priority = isHighPriority ? -1 : 0;
  return getStreamFromPool(priority, device_index);
}
245297
246298inline CUDAStream getStreamFromExternal (cudaStream_t ext_stream,
@@ -252,20 +304,18 @@ inline CUDAStream getStreamFromExternal(cudaStream_t ext_stream,
/**
 * Set the current CUDA stream for the device of the given stream in the
 * calling thread. The change is recorded in thread-local state only.
 */
inline void setCurrentCUDAStream(CUDAStream stream) {
  const c10::DeviceIndex idx = stream.unwrap().device_index();
  detail::check_device_index(idx);
  detail::initCUDAStreamsOnce();

  auto& tls = detail::get_tls();
  tls.streams[idx] = stream.stream();
  tls.has_stream[idx] = true;
}
267316
268317inline CUDAStream getDefaultCUDAStream (c10::DeviceIndex device_index = -1 ) {
318+ detail::initCUDAStreamsOnce ();
269319 if (device_index == -1 ) {
270320 device_index = phi::backends::gpu::GetCurrentDeviceId ();
271321 }
0 commit comments