Skip to content

Commit 6533c69

Browse files
authored
[Cpp API Compatibility] Fix CUDAContext.h to align with Pytorch (#78584)
1 parent d3ea546 commit 6533c69

6 files changed

Lines changed: 211 additions & 140 deletions

File tree

paddle/phi/api/include/compat/ATen/cuda/CUDAContextLight.cpp renamed to paddle/phi/api/include/compat/ATen/cuda/CUDAContext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
2020

21-
#include <ATen/cuda/CUDAContextLight.h>
21+
#include <ATen/cuda/CUDAContext.h>
2222

2323
#include <c10/core/Allocator.h>
2424
#include <mutex>

paddle/phi/api/include/compat/ATen/cuda/CUDAContext.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@
2020

2121
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
2222
#include <ATen/cuda/CUDAContextLight.h>
23+
#include <ATen/cuda/Exceptions.h>
24+
#include <c10/cuda/CUDAStream.h>
2325
#endif

paddle/phi/api/include/compat/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
collect_srcs(api_srcs SRCS c10/core/Device.cpp)
22
collect_srcs(api_srcs SRCS c10/core/Stream.cpp)
33
collect_srcs(api_srcs SRCS c10/cuda/CUDAFunctions.cpp)
4+
collect_srcs(api_srcs SRCS c10/cuda/CUDAStream.cpp)
45
collect_srcs(api_srcs SRCS c10/util/typeid.cpp)
56
collect_srcs(api_srcs SRCS ATen/cuda/EmptyTensor.cpp)
6-
collect_srcs(api_srcs SRCS ATen/cuda/CUDAContextLight.cpp)
7+
collect_srcs(api_srcs SRCS ATen/cuda/CUDAContext.cpp)
78
collect_srcs(api_srcs SRCS ATen/cuda/CUDABlas.cpp)
89
collect_srcs(api_srcs SRCS ATen/core/TensorMethods.cpp)
910
collect_srcs(api_srcs SRCS ATen/AccumulateType.cpp)

paddle/phi/api/include/compat/c10/cuda/CUDAException.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
#include <exception>
1818
#include <string>
1919

20+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
21+
#include <cuda_runtime.h>
22+
#endif
23+
2024
class CompatException : public std::exception {
2125
private:
2226
std::string message = {};
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <c10/cuda/CUDAStream.h>
16+
17+
#include <atomic>
18+
#include <memory>
19+
#include <mutex>
20+
#include <vector>
21+
22+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
23+
#include "paddle/phi/backends/gpu/gpu_info.h"
24+
#endif
25+
26+
namespace c10::cuda {

namespace {

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

// Number of pre-created streams held in each per-device priority pool.
constexpr int kStreamsPerPool = 32;

// Guards the one-time discovery of the device count and pool allocation.
std::once_flag g_init_once;
// Visible GPU count; -1 until initGlobalState() has run.
c10::DeviceIndex g_num_gpus = -1;

// Per-device stream pools: two fixed-size pools (low / high priority) plus
// round-robin counters used when handing streams out.
struct DevicePools {
  std::vector<cudaStream_t> low_priority;
  std::vector<cudaStream_t> high_priority;
  std::atomic<uint32_t> lp_counter{0};  // next low-priority slot (round robin)
  std::atomic<uint32_t> hp_counter{0};  // next high-priority slot (round robin)
  std::once_flag init_flag;             // guards lazy stream creation for this device
};

// One DevicePools entry per device; sized by initGlobalState().
std::vector<std::unique_ptr<DevicePools>> g_pools;

// Per-thread "current stream" per device; nullptr means the default stream.
// Lazily sized by initTLSCurrentStreams().
thread_local std::vector<cudaStream_t> tls_current_streams;
thread_local bool tls_streams_initialized = false;
void initGlobalState() {
51+
std::call_once(g_init_once, []() {
52+
g_num_gpus =
53+
static_cast<c10::DeviceIndex>(phi::backends::gpu::GetGPUDeviceCount());
54+
g_pools.resize(g_num_gpus);
55+
for (auto& ptr : g_pools) {
56+
ptr = std::make_unique<DevicePools>();
57+
}
58+
});
59+
}
60+
61+
// Creates the kStreamsPerPool low- and high-priority non-blocking streams
// for one device. Called exactly once per device (via the pool's init_flag).
void initDevicePools(c10::DeviceIndex device_index) {
  // Make device_index current while the streams are created.
  phi::backends::gpu::GPUDeviceGuard guard(device_index);

  // Per the CUDA API, the first out-param is the least (numerically largest)
  // priority and the second the greatest (numerically smallest).
  int least_priority = 0;
  int greatest_priority = 0;
  C10_CUDA_CHECK(
      cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));

  DevicePools& pool = *g_pools[device_index];
  pool.low_priority.resize(kStreamsPerPool);
  pool.high_priority.resize(kStreamsPerPool);

  auto create_stream = [](cudaStream_t* slot, int priority) {
    C10_CUDA_CHECK(
        cudaStreamCreateWithPriority(slot, cudaStreamNonBlocking, priority));
  };
  for (int i = 0; i < kStreamsPerPool; ++i) {
    create_stream(&pool.low_priority[i], least_priority);
    create_stream(&pool.high_priority[i], greatest_priority);
  }
}
// Validates that device_index names an existing GPU; aborts via TORCH_CHECK
// otherwise. initGlobalState() must have run so g_num_gpus is populated.
inline void check_gpu(c10::DeviceIndex device_index) {
  const bool in_range = device_index >= 0 && device_index < g_num_gpus;
  TORCH_CHECK(in_range,
              "Device index value ",
              static_cast<int>(device_index),
              " is out of index range [0, ",
              static_cast<int>(g_num_gpus),
              ")");
}
inline void initTLSCurrentStreams() {
88+
if (!tls_streams_initialized) {
89+
tls_current_streams.resize(g_num_gpus, nullptr);
90+
tls_streams_initialized = true;
91+
}
92+
}
93+
94+
#endif // defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
95+
96+
} // namespace
97+
98+
// Hands out a pooled stream for device_index (-1 == current device) in
// round-robin order. A negative priority selects the high-priority pool;
// zero or positive selects the low-priority pool. The pool for the device
// is created lazily on first request.
CUDAStream getStreamFromPool(const int priority,
                             c10::DeviceIndex device_index) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  if (device_index == -1) {
    device_index =
        static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId());
  }
  check_gpu(device_index);

  DevicePools& pool = *g_pools[device_index];
  std::call_once(pool.init_flag, initDevicePools, device_index);

  const bool wants_high_priority = priority < 0;
  std::atomic<uint32_t>& counter =
      wants_high_priority ? pool.hp_counter : pool.lp_counter;
  const uint32_t slot = counter++ % kStreamsPerPool;
  cudaStream_t raw =
      wants_high_priority ? pool.high_priority[slot] : pool.low_priority[slot];
  return make_cuda_stream(raw, device_index);
#else
  TORCH_CHECK(false, "getStreamFromPool is not supported without CUDA/HIP");
  return getDefaultCUDAStream(device_index);
#endif
}
// Convenience overload: maps the boolean flag onto the integer-priority
// entry point (high priority -> -1, normal -> 0).
CUDAStream getStreamFromPool(const bool isHighPriority,
                             c10::DeviceIndex device_index) {
  const int priority = isHighPriority ? -1 : 0;
  return getStreamFromPool(priority, device_index);
}
// Wraps an externally created raw stream in a CUDAStream for device_index.
// The stream is never stored in the pools here, so this module never
// destroys it — presumably the caller keeps it alive (TODO confirm against
// make_cuda_stream). Note: unlike the other entry points, device_index == -1
// is NOT resolved to the current device before validation.
CUDAStream getStreamFromExternal(cudaStream_t ext_stream,
                                 c10::DeviceIndex device_index) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  check_gpu(device_index);
#endif
  return make_cuda_stream(ext_stream, device_index);
}
// Returns the default (null) stream for device_index (-1 == current device),
// expressed as a c10::Stream with DEFAULT semantics.
CUDAStream getDefaultCUDAStream(c10::DeviceIndex device_index) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  if (device_index == -1) {
    device_index =
        static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId());
  }
  check_gpu(device_index);
#endif
  const c10::Device device(c10::DeviceType::CUDA, device_index);
  return CUDAStream(c10::Stream(c10::Stream::DEFAULT, device));
}
// Returns this thread's current stream for device_index (-1 == current
// device). Falls back to the default stream when no stream has been set via
// setCurrentCUDAStream on this thread.
CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  if (device_index == -1) {
    device_index =
        static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId());
  }
  check_gpu(device_index);
  initTLSCurrentStreams();
  // nullptr in the TLS table means "no override": use the default stream.
  if (cudaStream_t raw = tls_current_streams[device_index]) {
    return make_cuda_stream(raw, device_index);
  }
  return getDefaultCUDAStream(device_index);
#else
  return getDefaultCUDAStream(device_index);
#endif
}
// Records `stream` as this thread's current stream for the stream's own
// device. Only affects the calling thread (the table is thread_local);
// getCurrentCUDAStream on the same thread will return it afterwards.
void setCurrentCUDAStream(CUDAStream stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  const c10::DeviceIndex device_index = stream.unwrap().device_index();
  check_gpu(device_index);
  initTLSCurrentStreams();
  tls_current_streams[device_index] = stream.stream();
#else
  (void)stream;
#endif
}

}  // namespace c10::cuda

0 commit comments

Comments
 (0)