Skip to content

Commit c7450dd

Browse files
author
gasoonjia
committed
Add GPU-side Gumbel-max sampling for CUDA graph compatibility
1 parent ee75c2e commit c7450dd

8 files changed

Lines changed: 395 additions & 42 deletions

File tree

backends/cuda/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
107107
runtime/shims/cuda_guard.cpp
108108
)
109109

110-
# Only build int4mm shim when CUDA language/toolchain is available.
110+
# Only build CUDA shims when CUDA language/toolchain is available.
111111
if(CMAKE_CUDA_COMPILER)
112-
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu)
112+
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu runtime/shims/rand.cu)
113113
endif()
114114

115115
add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources})
@@ -160,7 +160,7 @@ else()
160160
aoti_cuda_shims
161161
PRIVATE cuda_platform
162162
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
163-
CUDA::cudart ${CMAKE_DL_LIBS}
163+
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
164164
)
165165
endif()
166166

backends/cuda/cuda_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool:
145145
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
146146
return {
147147
"at::_ops::_weight_int4pack_mm::call": None,
148+
"aoti_torch_cuda_randint_low_out": None,
148149
}
149150

150151
@classmethod
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cuda/runtime/shims/rand.h>
10+
11+
#include <executorch/backends/aoti/slim/cuda/guard.h>
12+
#include <executorch/backends/aoti/slim/factory/empty.h>
13+
#include <executorch/backends/aoti/slim/util/size_util.h>
14+
#include <executorch/runtime/platform/assert.h>
15+
#include <executorch/runtime/platform/log.h>
16+
17+
#include <cuda_runtime.h>
18+
#include <curand_kernel.h>
19+
20+
#include <cstdint>
21+
#include <ctime>
22+
#include <vector>
23+
24+
namespace executorch::backends::cuda {
25+
26+
namespace c10 = executorch::backends::aoti::slim::c10;
27+
using c10::Device;
28+
using c10::DeviceIndex;
29+
using c10::DeviceType;
30+
using c10::ScalarType;
31+
using executorch::backends::aoti::slim::empty_strided;
32+
using executorch::backends::aoti::slim::IntArrayRef;
33+
using executorch::backends::aoti::slim::makeArrayRef;
34+
35+
namespace {
36+
37+
// ---- GPU-resident RNG state ----
// Seed and counter live in device memory allocated during the first call
// (warmup phase, before CUDA graph capture). The counter is atomically
// advanced by each kernel invocation on-device, so it automatically
// produces different random sequences on every CUDA graph replay.

struct RngState {
  unsigned long long seed;
  unsigned long long counter;
};

static RngState* d_rng = nullptr;
static bool g_rng_init_done = false;

// Lazily allocates and seeds the device-resident RNG state on `stream`.
// Must be called during warmup (before CUDA graph capture), since the
// allocation and H2D copy must not be recorded into a captured graph.
// NOTE(review): not thread-safe — assumes the shim entry points are called
// from a single thread; add a mutex/std::call_once if that is not the case.
void ensure_rng_init(cudaStream_t stream) {
  if (g_rng_init_done) {
    return;
  }
  // Check every CUDA call: the original ignored all return codes, so an
  // allocation failure would leave d_rng invalid and crash later kernels.
  cudaError_t err = cudaMallocAsync(&d_rng, sizeof(RngState), stream);
  if (err != cudaSuccess) {
    ET_LOG(
        Error,
        "ensure_rng_init: cudaMallocAsync failed: %s",
        cudaGetErrorString(err));
    return;
  }
  RngState h;
  // Fresh seed per process; the on-device counter starts at zero.
  h.seed = static_cast<unsigned long long>(time(nullptr));
  h.counter = 0;
  err = cudaMemcpyAsync(
      d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream);
  if (err != cudaSuccess) {
    ET_LOG(
        Error,
        "ensure_rng_init: cudaMemcpyAsync failed: %s",
        cudaGetErrorString(err));
    return;
  }
  // Synchronize to ensure the copy completes before we return
  // (the host-side RngState `h` is on the stack).
  err = cudaStreamSynchronize(stream);
  if (err != cudaSuccess) {
    ET_LOG(
        Error,
        "ensure_rng_init: cudaStreamSynchronize failed: %s",
        cudaGetErrorString(err));
    return;
  }
  g_rng_init_done = true;
}
67+
68+
// Fills `out` with int64 values uniform in [low, low + range) using the
// Philox counter-based generator. Seed and base offset come from the
// device-resident state: each thread atomically claims a fresh offset from
// rng->counter, so a captured launch yields new values on every CUDA graph
// replay (only the rng POINTER — not a counter value — is baked into the
// graph). 1-D launch; the `idx < numel` guard handles the grid tail.
// Each (subsequence = idx, offset) pair is unique, so no two draws ever
// reuse a Philox state.
__global__ void philox_randint_graph_kernel(
    int64_t* __restrict__ out,
    int64_t numel,
    int64_t low,
    int64_t range,
    RngState* __restrict__ rng) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    // One atomicAdd per thread to claim a unique counter offset.
    unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
    curandStatePhilox4_32_10_t state;
    curand_init(rng->seed, idx, my_offset, &state);
    // curand_uniform_double returns values in (0, 1], so val * range can
    // reach `range` exactly; clamp that edge case down to range - 1.
    double val = curand_uniform_double(&state);
    int64_t ival = static_cast<int64_t>(val * range);
    out[idx] = low + (ival >= range ? range - 1 : ival);
  }
}
96+
97+
// Philox-based uniform float32 generator (graph-safe version): writes one
// uniform sample per element of `out`, drawing a fresh counter offset from
// the device-resident RNG state so every CUDA graph replay of a captured
// launch produces new values. 1-D launch; tail guarded by the bounds check.
__global__ void philox_rand_float_graph_kernel(
    float* __restrict__ out,
    int64_t numel,
    RngState* __restrict__ rng) {
  const int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid >= numel) {
    return;
  }
  // Claim a unique offset so (subsequence = tid, offset) never repeats.
  const unsigned long long offset = atomicAdd(&rng->counter, 1ULL);
  curandStatePhilox4_32_10_t philox;
  curand_init(rng->seed, tid, offset, &philox);
  out[tid] = curand_uniform(&philox);
}
110+
111+
// Philox-based uniform bfloat16 generator (graph-safe version). Draws a
// float32 uniform sample, rounds it to bfloat16 (round-to-nearest-even on
// the top 16 bits of the float32 pattern), and stores the raw 16-bit
// representation into `out`.
__global__ void philox_rand_bf16_graph_kernel(
    uint16_t* __restrict__ out,
    int64_t numel,
    RngState* __restrict__ rng) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
    curandStatePhilox4_32_10_t state;
    curand_init(rng->seed, idx, my_offset, &state);
    float val = curand_uniform(&state);
    // Bit-cast via the CUDA intrinsic: the original used memcpy but this
    // translation unit never includes <cstring>.
    uint32_t bits = __float_as_uint(val);
    // Round-to-nearest-even: add 0x7FFF plus the would-be bfloat16 LSB,
    // then truncate to the top 16 bits.
    uint32_t lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb;
    out[idx] = static_cast<uint16_t>(bits >> 16);
  }
}
129+
130+
} // anonymous namespace
131+
132+
extern "C" {
133+
134+
AOTITorchError aoti_torch_cuda_rand(
    const int64_t* size,
    int64_t size_len_,
    int32_t* dtype,
    int32_t* layout,
    int32_t* device,
    int32_t device_index_,
    int32_t* pin_memory,
    SlimTensor** ret0) {
  // AOTI shim for aten::rand.default on CUDA: allocates a contiguous tensor
  // of the given shape on device `device_index_` and fills it with uniform
  // random values via the graph-safe Philox kernels above. `layout`,
  // `device`, and `pin_memory` are accepted for ABI compatibility but unused.
  (void)layout;
  (void)device;
  (void)pin_memory;

  ET_CHECK_OR_RETURN_ERROR(
      ret0 != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_rand: ret0 is null");

  // Default to float32 if dtype not specified.
  ScalarType scalar_type = ScalarType::Float;
  if (dtype != nullptr) {
    scalar_type = static_cast<ScalarType>(*dtype);
  }

  // Compute contiguous (row-major) strides and total element count.
  std::vector<int64_t> strides(size_len_);
  int64_t numel = 1;
  for (int64_t i = size_len_ - 1; i >= 0; i--) {
    strides[i] = numel;
    numel *= size[i];
  }

  // Allocate the output tensor on the requested device.
  IntArrayRef sizes_ref(size, static_cast<size_t>(size_len_));
  *ret0 = new SlimTensor(empty_strided(
      sizes_ref,
      makeArrayRef(strides),
      scalar_type,
      Device(DeviceType::CUDA, static_cast<DeviceIndex>(device_index_))));

  if (numel == 0) {
    return Error::Ok;
  }

  // Use the stream for the device the tensor lives on. The original
  // hard-coded device 0, which launches on the wrong stream when
  // device_index_ != 0 on multi-GPU systems.
  auto stream_result = getCurrentCUDAStream(device_index_);
  ET_CHECK_OR_RETURN_ERROR(
      stream_result.ok(),
      Internal,
      "aoti_torch_cuda_rand: failed to get CUDA stream");
  cudaStream_t stream = stream_result.get();

  // Allocate/seed the device-resident RNG state on first use (warmup phase,
  // before any CUDA graph capture).
  ensure_rng_init(stream);

  constexpr int kThreads = 256;
  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);

  if (scalar_type == ScalarType::Float) {
    philox_rand_float_graph_kernel<<<blocks, kThreads, 0, stream>>>(
        static_cast<float*>((*ret0)->data_ptr()), numel, d_rng);
  } else if (scalar_type == ScalarType::BFloat16) {
    philox_rand_bf16_graph_kernel<<<blocks, kThreads, 0, stream>>>(
        static_cast<uint16_t*>((*ret0)->data_ptr()), numel, d_rng);
  } else {
    ET_LOG(
        Error,
        "aoti_torch_cuda_rand: unsupported dtype %d",
        static_cast<int>(scalar_type));
    return Error::NotSupported;
  }

  // Kernel launches do not return errors directly; surface launch-config
  // failures here instead of at some later synchronizing call.
  cudaError_t launch_err = cudaGetLastError();
  ET_CHECK_OR_RETURN_ERROR(
      launch_err == cudaSuccess,
      Internal,
      "aoti_torch_cuda_rand: kernel launch failed: %s",
      cudaGetErrorString(launch_err));

  return Error::Ok;
}
207+
208+
AOTITorchError aoti_torch_cuda_randint_low_out(
    SlimTensor* out,
    int64_t low,
    int64_t high,
    const int64_t* size,
    int64_t size_len_) {
  // AOTI shim for aten::randint.low_out on CUDA: fills `out` with int64
  // values in [low, high) using the graph-safe Philox kernel. The data
  // pointer is cast to int64_t*, so `out` is assumed to be an int64 tensor
  // (this is how Inductor's Philox RNG generates seeds).
  ET_CHECK_OR_RETURN_ERROR(
      out != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_randint_low_out: out tensor is null");

  ET_CHECK_OR_RETURN_ERROR(
      high > low,
      InvalidArgument,
      "aoti_torch_cuda_randint_low_out: requires high > low");

  // Total element count from the explicit shape.
  int64_t numel = 1;
  for (int64_t i = 0; i < size_len_; i++) {
    numel *= size[i];
  }
  if (numel == 0) {
    return Error::Ok;
  }

  // NOTE(review): the stream is taken from device 0 — this signature carries
  // no device index; if multi-device support is needed, derive the device
  // from `out` instead.
  auto stream_result = getCurrentCUDAStream(0);
  ET_CHECK_OR_RETURN_ERROR(
      stream_result.ok(),
      Internal,
      "aoti_torch_cuda_randint_low_out: failed to get CUDA stream");
  cudaStream_t stream = stream_result.get();

  // Allocate/seed the device-resident RNG state on first use (warmup phase,
  // before any CUDA graph capture).
  ensure_rng_init(stream);

  int64_t range = high - low;
  int64_t* out_data = static_cast<int64_t*>(out->data_ptr());

  constexpr int kThreads = 256;
  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
  philox_randint_graph_kernel<<<blocks, kThreads, 0, stream>>>(
      out_data, numel, low, range, d_rng);

  // Kernel launches do not return errors directly; check for launch-config
  // failures here rather than at the next synchronizing call.
  cudaError_t launch_err = cudaGetLastError();
  ET_CHECK_OR_RETURN_ERROR(
      launch_err == cudaSuccess,
      Internal,
      "aoti_torch_cuda_randint_low_out: kernel launch failed: %s",
      cudaGetErrorString(launch_err));

  return Error::Ok;
}
252+
253+
} // extern "C"
254+
255+
} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/rand.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
13+
#include <executorch/backends/aoti/export.h>
14+
#include <executorch/backends/aoti/slim/core/slim_tensor.h>
15+
#include <executorch/backends/aoti/slim/core/slim_tensor_view_incl.h>
16+
#include <executorch/runtime/core/error.h>
17+
18+
namespace executorch::backends::cuda {
19+
20+
using executorch::runtime::Error;
21+
using AOTITorchError = Error;
22+
23+
using SlimTensor = executorch::backends::aoti::slim::SlimTensor;
24+
25+
extern "C" {

/**
 * Creates a tensor filled with uniform random values in [0, 1).
 *
 * AOTI shim for aten::rand.default on CUDA. Backed by cuRAND's Philox
 * counter-based generator with GPU-resident state: the counter is advanced
 * atomically on-device by each kernel invocation, so the op remains valid
 * under CUDA graph capture and yields fresh values on every replay.
 *
 * Supported output dtypes: float32 and bfloat16.
 * NOTE(review): curand_uniform draws from (0, 1]; strictly this differs
 * from aten::rand's [0, 1) at the interval endpoints — confirm acceptable.
 *
 * @param size          Pointer to the output shape (size_len_ entries).
 * @param size_len_     Number of dimensions in `size`.
 * @param dtype         Optional dtype code; float32 when null.
 * @param layout        Ignored.
 * @param device        Ignored.
 * @param device_index_ CUDA device ordinal for the output tensor.
 * @param pin_memory    Ignored.
 * @param ret0          Receives the newly allocated output tensor.
 */
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_rand(
    const int64_t* size,
    int64_t size_len_,
    int32_t* dtype,
    int32_t* layout,
    int32_t* device,
    int32_t device_index_,
    int32_t* pin_memory,
    SlimTensor** ret0);

/**
 * Fills a pre-allocated int64 tensor with random integers in [low, high).
 *
 * AOTI shim for aten::randint.low_out on CUDA, used by Inductor's Philox
 * RNG to generate random seeds. Each thread atomically claims a unique
 * offset from a GPU-resident counter, so the op is safe to capture and
 * replay inside a CUDA graph.
 *
 * @param out       Pre-allocated output tensor (treated as int64 data).
 * @param low       Inclusive lower bound.
 * @param high      Exclusive upper bound; must satisfy high > low.
 * @param size      Pointer to the output shape (size_len_ entries).
 * @param size_len_ Number of dimensions in `size`.
 */
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(
    SlimTensor* out,
    int64_t low,
    int64_t high,
    const int64_t* size,
    int64_t size_len_);

} // extern "C"
64+
65+
} // namespace executorch::backends::cuda

examples/models/qwen3_5_moe/export.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,11 @@ def export_and_lower(model, config, args):
416416
print("Exporting decode method...")
417417
decode_tokens = torch.tensor([[0]], dtype=torch.long)
418418
decode_pos = torch.tensor([0], dtype=torch.long)
419+
decode_temperature = torch.tensor([1.0], dtype=torch.float32)
419420
with torch.no_grad():
420421
decode_ep = export(
421422
model,
422-
(decode_tokens, decode_pos),
423+
(decode_tokens, decode_pos, decode_temperature),
423424
strict=True,
424425
)
425426
print("Decode export successful!")
@@ -428,15 +429,17 @@ def export_and_lower(model, config, args):
428429
print("Exporting prefill method...")
429430
prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
430431
prefill_pos = torch.tensor([0, 1], dtype=torch.long)
432+
prefill_temperature = torch.tensor([1.0], dtype=torch.float32)
431433
seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
432434
prefill_dynamic_shapes = (
433435
{1: seq_dim}, # tokens
434436
{0: seq_dim}, # input_pos
437+
None, # temperature (static scalar)
435438
)
436439
with torch.no_grad():
437440
prefill_ep = export(
438441
model,
439-
(prefill_tokens, prefill_pos),
442+
(prefill_tokens, prefill_pos, prefill_temperature),
440443
dynamic_shapes=prefill_dynamic_shapes,
441444
strict=True,
442445
)

0 commit comments

Comments
 (0)