Skip to content

Commit c65ab5f

Browse files
committed
align getStreamFromPool with PyTorch
1 parent a68b7d4 commit c65ab5f

1 file changed

Lines changed: 92 additions & 42 deletions

File tree

paddle/phi/api/include/compat/c10/cuda/CUDAStream.h

Lines changed: 92 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <algorithm>
2222
#include <array>
2323
#include <atomic>
24+
#include <cstdint>
2425
#include <functional>
2526
#include <mutex>
2627
#include <ostream>
@@ -39,8 +40,36 @@ static constexpr int max_compile_time_stream_priorities = 4;
3940

4041
namespace detail {
4142

// Number of streams per pool, expressed as a power of two so the
// round-robin slot can be derived from a monotonically increasing counter.
constexpr int kStreamsPerPoolBits = 5;
constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;  // 32
// Upper bound on the number of CUDA devices tracked by the pools.
constexpr int kMaxDevices = 64;
// Pool streams are created non-blocking, i.e. they do not implicitly
// synchronize with the legacy default stream (per CUDA stream-flag docs).
constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
47+
// Global stream state - inline for header-only implementation.
// Per-device std::once_flag guarding one-time stream-pool initialization
// (consumed by the std::call_once in getStreamFromPool).
inline std::array<std::once_flag, kMaxDevices>& device_flags() {
  static std::array<std::once_flag, kMaxDevices> flags;
  return flags;
}
53+
// Round-robin counters, one per (priority, device) pair; get_idx() advances
// the counter to pick the next stream slot from the matching pool.
inline std::array<std::array<std::atomic<uint32_t>, kMaxDevices>,
                  max_compile_time_stream_priorities>&
priority_counters() {
  static std::array<std::array<std::atomic<uint32_t>, kMaxDevices>,
                    max_compile_time_stream_priorities>
      counters;
  return counters;
}
62+
// Backing storage for the pooled streams, indexed as
// [priority_idx][device_index][stream_idx]. Entries are populated lazily by
// initDeviceStreamState() the first time a device's pool is requested.
inline std::array<
    std::array<std::array<cudaStream_t, kStreamsPerPool>, kMaxDevices>,
    max_compile_time_stream_priorities>&
streams() {
  static std::array<
      std::array<std::array<cudaStream_t, kStreamsPerPool>, kMaxDevices>,
      max_compile_time_stream_priorities>
      stream_arrays;
  return stream_arrays;
}
4473

4574
inline int gpu_device_count() {
4675
static const int count = phi::backends::gpu::GetGPUDeviceCount();
@@ -59,32 +88,31 @@ inline void check_device_index(int device_index) {
5988
")");
6089
}
6190

62-
struct StreamPoolState {
63-
cudaStream_t low_priority[kStreamsPerPool]{};
64-
cudaStream_t high_priority[kStreamsPerPool]{};
65-
std::atomic<uint32_t> lp_counter{0};
66-
std::atomic<uint32_t> hp_counter{0};
67-
std::once_flag init_flag;
68-
};
69-
70-
inline StreamPoolState& get_pool(int device_index) {
71-
check_device_index(device_index);
72-
static StreamPoolState states[kMaxDevices];
73-
return states[device_index];
// Create one CUDA stream in pool slot [priority_idx][device_index][stream_idx].
// priority_idx maps to CUDA stream priority -priority_idx: idx 0 is the
// default priority (0); larger indices are more negative, i.e. HIGHER CUDA
// priority. (The previous comment claiming idx 0 is the highest priority was
// wrong.) CUDA clamps out-of-range priorities to the device's valid range.
inline void initSingleStream(int priority_idx,
                             int device_index,
                             int stream_idx) {
  // Ensure the stream is created on the intended device.
  phi::backends::gpu::GPUDeviceGuard guard(device_index);
  auto& stream = streams()[priority_idx][device_index][stream_idx];
  int pri = -(priority_idx);
  C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
}
75101

76-
inline void init_pool(int device_index, StreamPoolState* state) {
77-
phi::backends::gpu::GPUDeviceGuard guard(device_index);
78-
int lo_pri = 0, hi_pri = 0;
79-
C10_CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri));
102+
// Init stream pools for a device (called once per device)
103+
inline void initDeviceStreamState(int device_index) {
80104
for (int i = 0; i < kStreamsPerPool; ++i) {
81-
C10_CUDA_CHECK(cudaStreamCreateWithPriority(
82-
&state->low_priority[i], cudaStreamNonBlocking, lo_pri));
83-
C10_CUDA_CHECK(cudaStreamCreateWithPriority(
84-
&state->high_priority[i], cudaStreamNonBlocking, hi_pri));
105+
for (int p = 0; p < max_compile_time_stream_priorities; ++p) {
106+
initSingleStream(p, device_index, i);
107+
}
85108
}
86109
}
87110

111+
// Helper to get round-robin index
112+
inline uint32_t get_idx(std::atomic<uint32_t>* counter) {
113+
return counter->fetch_add(1) % kStreamsPerPool;
114+
}
115+
88116
struct TLSStreamState {
89117
cudaStream_t streams[kMaxDevices]{};
90118
bool has_stream[kMaxDevices]{};
@@ -95,6 +123,25 @@ inline TLSStreamState& get_tls() {
95123
return s;
96124
}
97125

// One-time global initialization hook, run via std::call_once from
// initCUDAStreamsOnce().
inline void initGlobalStreamState() {
  // Query the device's supported stream-priority range.
  int leastPriority = 0, greatestPriority = 0;
  C10_CUDA_CHECK(
      cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
  // NOTE(review): the queried range is currently discarded; pool sizing uses
  // the compile-time max_compile_time_stream_priorities rather than a
  // runtime-derived priority count -- confirm this simplification is intended.
}
135+
136+
inline void initCUDAStreamsOnce() {
137+
static std::once_flag init_flag;
138+
std::call_once(init_flag, initGlobalStreamState);
139+
140+
auto& tls = get_tls();
141+
// Initialize TLS current streams to default (null)
142+
// This is lazy - we don't need to pre-initialize all entries
143+
}
144+
98145
} // namespace detail
99146

100147
class CUDAStream {
@@ -191,6 +238,7 @@ inline CUDAStream make_cuda_stream(cudaStream_t raw,
191238
}
192239

193240
inline CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index = -1) {
241+
detail::initCUDAStreamsOnce();
194242
if (device_index == -1) {
195243
device_index = phi::backends::gpu::GetCurrentDeviceId();
196244
}
@@ -210,37 +258,41 @@ inline CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index = -1) {
210258

211259
/**
 * Get a stream from the round-robin pool of `device_index` (the current
 * device when -1).
 *
 * Mirrors PyTorch's mapping: pri_idx = clamp(-priority, 0, max - 1), so a
 * more negative `priority` selects a higher-priority pool and non-negative
 * values select the default pool.
 */
inline CUDAStream getStreamFromPool(const int priority,
                                    c10::DeviceIndex device_index = -1) {
  detail::initCUDAStreamsOnce();
  if (device_index == -1) {
    device_index = phi::backends::gpu::GetCurrentDeviceId();
  }
  detail::check_device_index(device_index);

  // Lazily build this device's stream pools exactly once.
  std::call_once(detail::device_flags()[device_index],
                 detail::initDeviceStreamState,
                 device_index);

  // Translate the requested priority into a pool index, then advance that
  // pool's round-robin counter to pick a slot.
  const int pool =
      std::clamp(-priority, 0, max_compile_time_stream_priorities - 1);
  auto* counter = &detail::priority_counters()[pool][device_index];
  const uint32_t slot = detail::get_idx(counter);

  return make_cuda_stream(detail::streams()[pool][device_index][slot],
                          device_index);
}
234282

235283
/**
236284
* Get a new stream from the CUDA stream pool.
237285
*
238-
* This overload matches PyTorch's bool-based entry point and preserves the
239-
* single-argument form `getStreamFromPool(true)` for high-priority requests.
286+
* This overload matches PyTorch's bool-based entry point.
240287
*/
241288
inline CUDAStream getStreamFromPool(const bool isHighPriority = false,
242289
c10::DeviceIndex device_index = -1) {
243-
return getStreamFromPool(isHighPriority ? -1 : 0, device_index);
290+
// High priority: -1 (highest priority)
291+
// Low priority: 0 (default priority)
292+
// Using -1 to match typical CUDA priority range and ensure
293+
// getStreamFromPool(true) and getStreamFromPool(-1) behave consistently
294+
int priority = isHighPriority ? -1 : 0;
295+
return getStreamFromPool(priority, device_index);
244296
}
245297

246298
inline CUDAStream getStreamFromExternal(cudaStream_t ext_stream,
@@ -252,20 +304,18 @@ inline CUDAStream getStreamFromExternal(cudaStream_t ext_stream,
252304
/**
253305
* Set the current CUDA stream for the device of the given stream in the
254306
* calling thread.
255-
*
256-
* Implements per-thread, per-device current stream semantics: the change is
257-
* local to the calling OS thread and does not affect any shared state such as
258-
* Paddle's GPUContext. Other threads continue to see their own current stream.
259307
*/
260308
inline void setCurrentCUDAStream(CUDAStream stream) {
261309
c10::DeviceIndex idx = stream.unwrap().device_index();
262310
detail::check_device_index(idx);
311+
detail::initCUDAStreamsOnce();
263312
auto& tls = detail::get_tls();
264313
tls.streams[idx] = stream.stream();
265314
tls.has_stream[idx] = true;
266315
}
267316

268317
inline CUDAStream getDefaultCUDAStream(c10::DeviceIndex device_index = -1) {
318+
detail::initCUDAStreamsOnce();
269319
if (device_index == -1) {
270320
device_index = phi::backends::gpu::GetCurrentDeviceId();
271321
}

0 commit comments

Comments
 (0)