-
Notifications
You must be signed in to change notification settings - Fork 6k
[Cpp API Compatibility] Fix CUDAContext.h to align with Pytorch
#78584
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
043933d
[Cpp API Compatibility] Fix `CUDAContext.h` to align with Pytorch
youge325 a68b7d4
fix
youge325 ffd2b86
simplify getStreamFromPool
youge325 6d71323
implement getStreamFromPool and getStreamFromExternal in CUDAStream.cpp
youge325 dafd7e8
fix
youge325 2e6b1b7
try to fix dcu
youge325 ecb94cf
fix dcu again
youge325 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,182 @@ | ||
| // Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include <c10/cuda/CUDAStream.h> | ||
|
|
||
| #include <atomic> | ||
| #include <memory> | ||
| #include <mutex> | ||
| #include <vector> | ||
|
|
||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
| #include "paddle/phi/backends/gpu/gpu_info.h" | ||
| #endif | ||
|
|
||
| namespace c10::cuda { | ||
|
|
||
| namespace { | ||
|
|
||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
|
|
||
| constexpr int kStreamsPerPool = 32; | ||
|
|
||
| std::once_flag g_init_once; | ||
| c10::DeviceIndex g_num_gpus = -1; | ||
|
|
||
| struct DevicePools { | ||
| std::vector<cudaStream_t> low_priority; | ||
| std::vector<cudaStream_t> high_priority; | ||
| std::atomic<uint32_t> lp_counter{0}; | ||
| std::atomic<uint32_t> hp_counter{0}; | ||
| std::once_flag init_flag; | ||
| }; | ||
|
|
||
| std::vector<std::unique_ptr<DevicePools>> g_pools; | ||
|
|
||
| thread_local std::vector<cudaStream_t> tls_current_streams; | ||
| thread_local bool tls_streams_initialized = false; | ||
|
|
||
| void initGlobalState() { | ||
| std::call_once(g_init_once, []() { | ||
| g_num_gpus = | ||
| static_cast<c10::DeviceIndex>(phi::backends::gpu::GetGPUDeviceCount()); | ||
| g_pools.resize(g_num_gpus); | ||
| for (auto& ptr : g_pools) { | ||
| ptr = std::make_unique<DevicePools>(); | ||
| } | ||
| }); | ||
| } | ||
|
|
||
| void initDevicePools(c10::DeviceIndex device_index) { | ||
| phi::backends::gpu::GPUDeviceGuard guard(device_index); | ||
| int lo_pri = 0, hi_pri = 0; | ||
| C10_CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri)); | ||
|
|
||
| auto& pool = *g_pools[device_index]; | ||
| pool.low_priority.resize(kStreamsPerPool); | ||
| pool.high_priority.resize(kStreamsPerPool); | ||
|
|
||
| for (int i = 0; i < kStreamsPerPool; ++i) { | ||
| C10_CUDA_CHECK(cudaStreamCreateWithPriority( | ||
| &pool.low_priority[i], cudaStreamNonBlocking, lo_pri)); | ||
| C10_CUDA_CHECK(cudaStreamCreateWithPriority( | ||
| &pool.high_priority[i], cudaStreamNonBlocking, hi_pri)); | ||
| } | ||
| } | ||
|
|
||
| inline void check_gpu(c10::DeviceIndex device_index) { | ||
| TORCH_CHECK(device_index >= 0 && device_index < g_num_gpus, | ||
| "Device index value ", | ||
| static_cast<int>(device_index), | ||
| " is out of index range [0, ", | ||
| static_cast<int>(g_num_gpus), | ||
| ")"); | ||
| } | ||
|
|
||
| inline void initTLSCurrentStreams() { | ||
| if (!tls_streams_initialized) { | ||
| tls_current_streams.resize(g_num_gpus, nullptr); | ||
| tls_streams_initialized = true; | ||
| } | ||
| } | ||
|
|
||
| #endif // defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
|
|
||
| } // namespace | ||
|
|
||
| CUDAStream getStreamFromPool(const int priority, | ||
| c10::DeviceIndex device_index) { | ||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
| initGlobalState(); | ||
| if (device_index == -1) { | ||
| device_index = | ||
| static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId()); | ||
| } | ||
| check_gpu(device_index); | ||
|
|
||
| std::call_once( | ||
| g_pools[device_index]->init_flag, initDevicePools, device_index); | ||
|
|
||
| const uint32_t idx = (priority < 0 ? g_pools[device_index]->hp_counter++ | ||
| : g_pools[device_index]->lp_counter++) % | ||
| kStreamsPerPool; | ||
| cudaStream_t raw = (priority < 0 ? g_pools[device_index]->high_priority[idx] | ||
| : g_pools[device_index]->low_priority[idx]); | ||
|
|
||
| return make_cuda_stream(raw, device_index); | ||
| #else | ||
| TORCH_CHECK(false, "getStreamFromPool is not supported without CUDA/HIP"); | ||
| return getDefaultCUDAStream(device_index); | ||
| #endif | ||
| } | ||
|
|
||
| CUDAStream getStreamFromPool(const bool isHighPriority, | ||
| c10::DeviceIndex device_index) { | ||
| return getStreamFromPool(isHighPriority ? -1 : 0, device_index); | ||
| } | ||
|
|
||
| CUDAStream getStreamFromExternal(cudaStream_t ext_stream, | ||
| c10::DeviceIndex device_index) { | ||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
| initGlobalState(); | ||
| check_gpu(device_index); | ||
| #endif | ||
| return make_cuda_stream(ext_stream, device_index); | ||
| } | ||
|
|
||
| CUDAStream getDefaultCUDAStream(c10::DeviceIndex device_index) { | ||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
| initGlobalState(); | ||
| if (device_index == -1) { | ||
| device_index = | ||
| static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId()); | ||
| } | ||
| check_gpu(device_index); | ||
| #endif | ||
| return CUDAStream(c10::Stream( | ||
| c10::Stream::DEFAULT, c10::Device(c10::DeviceType::CUDA, device_index))); | ||
| } | ||
|
|
||
| CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index) { | ||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
| initGlobalState(); | ||
| if (device_index == -1) { | ||
| device_index = | ||
| static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId()); | ||
| } | ||
| check_gpu(device_index); | ||
| initTLSCurrentStreams(); | ||
| cudaStream_t raw = tls_current_streams[device_index]; | ||
| if (raw == nullptr) { | ||
| return getDefaultCUDAStream(device_index); | ||
| } | ||
| return make_cuda_stream(raw, device_index); | ||
| #else | ||
| return getDefaultCUDAStream(device_index); | ||
| #endif | ||
| } | ||
|
|
||
| void setCurrentCUDAStream(CUDAStream stream) { | ||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
| initGlobalState(); | ||
| c10::DeviceIndex idx = stream.unwrap().device_index(); | ||
| check_gpu(idx); | ||
| initTLSCurrentStreams(); | ||
| tls_current_streams[idx] = stream.stream(); | ||
| #else | ||
| (void)stream; | ||
| #endif | ||
| } | ||
|
|
||
| } // namespace c10::cuda |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CUDAContext.hused to only includeCUDAContextLight.hwhenPADDLE_WITH_CUDA/PADDLE_WITH_HIPwas enabled. After this change the header unconditionally includesCUDAContextLight.handc10/cuda/CUDAStream.h, which pulls in CUDA-only APIs (e.g.,cudaError_t/cudaStream_t) and will break compilation in non-CUDA/HIP builds that still include this header transitively. Please restore a build-flag guard around the CUDA-specific includes (or provide CPU stubs matching the previous behavior) while still addingc10/cuda/CUDAStream.hfor CUDA/HIP builds.