Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
runtime/shims/cuda_guard.cpp
)

# Only build int4mm shim when CUDA language/toolchain is available.
# Only build CUDA shims when CUDA language/toolchain is available.
if(CMAKE_CUDA_COMPILER)
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu)
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
runtime/shims/rand.cu
)
endif()

add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources})
Expand Down Expand Up @@ -160,7 +162,7 @@ else()
aoti_cuda_shims
PRIVATE cuda_platform
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
CUDA::cudart ${CMAKE_DL_LIBS}
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
)
endif()

Expand Down
1 change: 1 addition & 0 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool:
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
    """Return the fallback kernels the CUDA backend supports delegating to.

    Keys are kernel identifiers (ATen op names or AOTI shim symbol names)
    that the AOTI-compiled graph may call into at runtime; the values are
    unused placeholders (always ``None``).
    """
    return {
        "at::_ops::_weight_int4pack_mm::call": None,
        "aoti_torch_cuda_randint_low_out": None,
    }

@classmethod
Expand Down
255 changes: 255 additions & 0 deletions backends/cuda/runtime/shims/rand.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cuda/runtime/shims/rand.h>

#include <executorch/backends/aoti/slim/cuda/guard.h>
#include <executorch/backends/aoti/slim/factory/empty.h>
#include <executorch/backends/aoti/slim/util/size_util.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <cstdint>
#include <ctime>
#include <vector>

namespace executorch::backends::cuda {

namespace c10 = executorch::backends::aoti::slim::c10;
using c10::Device;
using c10::DeviceIndex;
using c10::DeviceType;
using c10::ScalarType;
using executorch::backends::aoti::slim::empty_strided;
using executorch::backends::aoti::slim::IntArrayRef;
using executorch::backends::aoti::slim::makeArrayRef;

namespace {

// ---- GPU-resident RNG state ----
// Seed and counter live in device memory allocated during the first call
// (warmup phase, before CUDA graph capture). The counter is atomically
// advanced by each kernel invocation on-device, so it automatically
// produces different random sequences on every CUDA graph replay.

// Philox RNG state shared by all generator kernels in this file.
struct RngState {
  unsigned long long seed; // Philox seed; set once from wall-clock time.
  unsigned long long counter; // Next free Philox offset; advanced on-device.
};

// Lazily allocated device pointer to the single global RngState.
static RngState* d_rng = nullptr;
// Host-side latch so initialization runs exactly once.
static bool g_rng_init_done = false;

// Initialize RNG state on the given stream.
// Must be called during warmup (before graph capture).
// NOTE(review): not thread-safe; assumes single-threaded warmup — confirm
// against callers before exposing to concurrent execution.
void ensure_rng_init(cudaStream_t stream) {
  if (g_rng_init_done) {
    return;
  }

  // Allocate device-resident state via the stream-ordered allocator.
  cudaError_t err = cudaMallocAsync(&d_rng, sizeof(RngState), stream);
  if (err != cudaSuccess) {
    ET_LOG(
        Error,
        "ensure_rng_init: cudaMallocAsync failed: %s",
        cudaGetErrorString(err));
    d_rng = nullptr;
    return; // Leave g_rng_init_done false so a later call can retry.
  }

  RngState h;
  // Wall-clock seed: fresh randomness per process, not reproducible runs.
  h.seed = static_cast<unsigned long long>(time(nullptr));
  h.counter = 0;
  err = cudaMemcpyAsync(
      d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream);
  if (err == cudaSuccess) {
    // Synchronize to ensure the copy completes before we return
    // (the host-side RngState `h` is on the stack).
    err = cudaStreamSynchronize(stream);
  }
  if (err != cudaSuccess) {
    ET_LOG(
        Error,
        "ensure_rng_init: failed to upload RNG state: %s",
        cudaGetErrorString(err));
    cudaFreeAsync(d_rng, stream);
    d_rng = nullptr;
    return;
  }

  // Only mark initialization done once the state is fully populated.
  g_rng_init_done = true;
}

// Philox-based randint kernel that reads seed from device-resident state
// and atomically advances the counter. The counter pointer survives CUDA
// graph replay, so each replay produces different values.
//
// Each thread atomically claims one unique offset from the shared counter
// and seeds its own Philox state with (seed, idx, offset): distinct `idx`
// subsequences keep threads within a launch independent, and the advancing
// counter keeps successive launches (and graph replays) independent.
// Per-thread atomicAdd serializes under contention, which is acceptable for
// the tiny numel this is used with (Inductor seed generation).
__global__ void philox_randint_graph_kernel(
    int64_t* __restrict__ out,
    int64_t numel,
    int64_t low,
    int64_t range,
    RngState* __restrict__ rng) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    // Atomically grab one slot from the counter for this thread.
    unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
    curandStatePhilox4_32_10_t state;
    curand_init(rng->seed, idx, my_offset, &state);
    // curand_uniform_double returns a value in (0, 1], so val * range can
    // reach exactly `range`; the clamp below keeps the result < high.
    double val = curand_uniform_double(&state);
    int64_t ival = static_cast<int64_t>(val * range);
    out[idx] = low + (ival >= range ? range - 1 : ival);
  }
}

// Philox-based uniform float32 generator (graph-safe version).
// Each thread claims a unique slot of the device-resident counter, so a
// captured CUDA graph produces fresh values on every replay. Values follow
// curand_uniform semantics (uniform in (0, 1]).
__global__ void philox_rand_float_graph_kernel(
    float* __restrict__ out,
    int64_t numel,
    RngState* __restrict__ rng) {
  const int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid >= numel) {
    return;
  }
  // Reserve a unique Philox offset for this thread.
  const unsigned long long offset = atomicAdd(&rng->counter, 1ULL);
  curandStatePhilox4_32_10_t philox;
  curand_init(rng->seed, tid, offset, &philox);
  out[tid] = curand_uniform(&philox);
}

// Philox-based uniform bfloat16 generator (graph-safe version).
// Draws a float32 sample per element and converts it to bfloat16 storage
// bits, writing the raw uint16 representation into `out`.
__global__ void philox_rand_bf16_graph_kernel(
    uint16_t* __restrict__ out,
    int64_t numel,
    RngState* __restrict__ rng) {
  const int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tid >= numel) {
    return;
  }
  // Reserve a unique Philox offset for this thread.
  const unsigned long long offset = atomicAdd(&rng->counter, 1ULL);
  curandStatePhilox4_32_10_t philox;
  curand_init(rng->seed, tid, offset, &philox);
  const float sample = curand_uniform(&philox);
  // float32 -> bfloat16 via round-to-nearest-even: add 0x7FFF plus the
  // tie-breaking bit to the raw float bits, then keep the high 16 bits.
  uint32_t bits;
  memcpy(&bits, &sample, sizeof(uint32_t));
  const uint32_t tie_bit = (bits >> 16) & 1;
  bits += 0x7FFFu + tie_bit;
  out[tid] = static_cast<uint16_t>(bits >> 16);
}

} // anonymous namespace

extern "C" {

// Shim for aten::rand on CUDA: allocates a contiguous tensor of shape
// `size` on device `device_index_` and fills it with uniform random values
// via the Philox kernels above. Supported dtypes: float32 (default when
// `dtype` is null) and bfloat16. `layout`, `device`, and `pin_memory` are
// accepted for ABI compatibility but ignored (output is always a dense
// CUDA tensor). On success `*ret0` owns the new tensor.
AOTITorchError aoti_torch_cuda_rand(
    const int64_t* size,
    int64_t size_len_,
    int32_t* dtype,
    int32_t* layout,
    int32_t* device,
    int32_t device_index_,
    int32_t* pin_memory,
    SlimTensor** ret0) {
  (void)layout;
  (void)device;
  (void)pin_memory;

  ET_CHECK_OR_RETURN_ERROR(
      ret0 != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_rand: ret0 is null");
  ET_CHECK_OR_RETURN_ERROR(
      size != nullptr || size_len_ == 0,
      InvalidArgument,
      "aoti_torch_cuda_rand: size is null");

  // Default to float32 if dtype not specified.
  ScalarType scalar_type = ScalarType::Float;
  if (dtype != nullptr) {
    scalar_type = static_cast<ScalarType>(*dtype);
  }

  // Validate the dtype BEFORE allocating the output so the error path does
  // not leak a freshly allocated tensor.
  if (scalar_type != ScalarType::Float &&
      scalar_type != ScalarType::BFloat16) {
    ET_LOG(
        Error,
        "aoti_torch_cuda_rand: unsupported dtype %d",
        static_cast<int>(scalar_type));
    return Error::NotSupported;
  }

  // Compute contiguous (row-major) strides and total elements.
  std::vector<int64_t> strides(size_len_);
  int64_t numel = 1;
  for (int64_t i = size_len_ - 1; i >= 0; i--) {
    strides[i] = numel;
    numel *= size[i];
  }

  // Allocate output tensor on the requested device.
  IntArrayRef sizes_ref(size, static_cast<size_t>(size_len_));
  *ret0 = new SlimTensor(empty_strided(
      sizes_ref,
      makeArrayRef(strides),
      scalar_type,
      Device(DeviceType::CUDA, static_cast<DeviceIndex>(device_index_))));

  if (numel == 0) {
    return Error::Ok;
  }

  // Use the current stream of the TARGET device (not device 0) so the fill
  // is ordered with other work on the device the tensor lives on.
  auto stream_result = getCurrentCUDAStream(device_index_);
  ET_CHECK_OR_RETURN_ERROR(
      stream_result.ok(),
      Internal,
      "aoti_torch_cuda_rand: failed to get CUDA stream");
  cudaStream_t stream = stream_result.get();

  ensure_rng_init(stream);
  // Guard against a failed RNG-state allocation during warmup.
  ET_CHECK_OR_RETURN_ERROR(
      d_rng != nullptr,
      Internal,
      "aoti_torch_cuda_rand: RNG state initialization failed");

  constexpr int kThreads = 256;
  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);

  if (scalar_type == ScalarType::Float) {
    philox_rand_float_graph_kernel<<<blocks, kThreads, 0, stream>>>(
        static_cast<float*>((*ret0)->data_ptr()), numel, d_rng);
  } else {
    philox_rand_bf16_graph_kernel<<<blocks, kThreads, 0, stream>>>(
        static_cast<uint16_t*>((*ret0)->data_ptr()), numel, d_rng);
  }

  // Kernel launches report configuration errors asynchronously; surface
  // them here instead of letting a later call fail mysteriously.
  cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    ET_LOG(
        Error,
        "aoti_torch_cuda_rand: kernel launch failed: %s",
        cudaGetErrorString(launch_err));
    return Error::Internal;
  }

  return Error::Ok;
}

// Shim for aten::randint.low_out on CUDA: fills the pre-allocated int64
// tensor `out` with values uniformly drawn from [low, high). `size` is used
// only to compute the element count; `out` must already have matching
// storage. Used by Inductor's Philox RNG to generate random seeds.
AOTITorchError aoti_torch_cuda_randint_low_out(
    SlimTensor* out,
    int64_t low,
    int64_t high,
    const int64_t* size,
    int64_t size_len_) {
  ET_CHECK_OR_RETURN_ERROR(
      out != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_randint_low_out: out tensor is null");

  ET_CHECK_OR_RETURN_ERROR(
      high > low,
      InvalidArgument,
      "aoti_torch_cuda_randint_low_out: requires high > low");

  ET_CHECK_OR_RETURN_ERROR(
      size != nullptr || size_len_ == 0,
      InvalidArgument,
      "aoti_torch_cuda_randint_low_out: size is null");

  // Total number of elements described by `size`.
  int64_t numel = 1;
  for (int64_t i = 0; i < size_len_; i++) {
    numel *= size[i];
  }
  if (numel == 0) {
    return Error::Ok;
  }

  // Get the current CUDA stream.
  // NOTE(review): hardcodes device 0; for multi-GPU the stream should be
  // derived from `out`'s device index — confirm against callers.
  auto stream_result = getCurrentCUDAStream(0);
  ET_CHECK_OR_RETURN_ERROR(
      stream_result.ok(),
      Internal,
      "aoti_torch_cuda_randint_low_out: failed to get CUDA stream");
  cudaStream_t stream = stream_result.get();

  ensure_rng_init(stream);
  // Guard against a failed RNG-state allocation during warmup.
  ET_CHECK_OR_RETURN_ERROR(
      d_rng != nullptr,
      Internal,
      "aoti_torch_cuda_randint_low_out: RNG state initialization failed");

  int64_t range = high - low;
  int64_t* out_data = static_cast<int64_t*>(out->data_ptr());

  constexpr int kThreads = 256;
  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
  philox_randint_graph_kernel<<<blocks, kThreads, 0, stream>>>(
      out_data, numel, low, range, d_rng);

  // Kernel launches report configuration errors asynchronously; surface
  // them here instead of letting a later call fail mysteriously.
  cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    ET_LOG(
        Error,
        "aoti_torch_cuda_randint_low_out: kernel launch failed: %s",
        cudaGetErrorString(launch_err));
    return Error::Internal;
  }

  return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
65 changes: 65 additions & 0 deletions backends/cuda/runtime/shims/rand.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

#include <executorch/backends/aoti/export.h>
#include <executorch/backends/aoti/slim/core/slim_tensor.h>
#include <executorch/backends/aoti/slim/core/slim_tensor_view_incl.h>
#include <executorch/runtime/core/error.h>

namespace executorch::backends::cuda {

using executorch::runtime::Error;
// AOTI shims report status through the ExecuTorch Error enum.
using AOTITorchError = Error;

using SlimTensor = executorch::backends::aoti::slim::SlimTensor;

extern "C" {

/**
 * Generates a tensor filled with uniform random values in [0, 1).
 *
 * Implements the AOTI shim for aten::rand.default on CUDA. Uses cuRAND
 * Philox counter-based RNG with GPU-resident state. The counter is
 * atomically advanced by each kernel invocation on-device, making it
 * fully compatible with CUDA graph capture and replay — each replay
 * produces different values because the counter increments on the GPU.
 *
 * Supports float32 and bfloat16 output dtypes.
 *
 * @param size          Dimension sizes of the output tensor.
 * @param size_len_     Number of dimensions in `size`.
 * @param dtype         Optional dtype code; float32 is used when null.
 * @param layout        Unused; the output is always a dense tensor.
 * @param device        Unused; the device type is always CUDA.
 * @param device_index_ CUDA device index the output is allocated on.
 * @param pin_memory    Unused.
 * @param ret0          Receives the newly allocated output tensor.
 */
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_rand(
    const int64_t* size,
    int64_t size_len_,
    int32_t* dtype,
    int32_t* layout,
    int32_t* device,
    int32_t device_index_,
    int32_t* pin_memory,
    SlimTensor** ret0);

/**
 * Fills a pre-allocated int64 tensor with random integers in [low, high).
 *
 * Implements the AOTI shim for aten::randint.low_out on CUDA. Used by
 * Inductor's Philox RNG to generate random seeds. Each thread atomically
 * advances a GPU-resident counter for unique offsets, making this fully
 * compatible with CUDA graph capture and replay.
 *
 * @param out       Pre-allocated int64 CUDA output tensor.
 * @param low       Inclusive lower bound.
 * @param high      Exclusive upper bound; must satisfy high > low.
 * @param size      Dimension sizes; used to compute the element count.
 * @param size_len_ Number of dimensions in `size`.
 */
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(
    SlimTensor* out,
    int64_t low,
    int64_t high,
    const int64_t* size,
    int64_t size_len_);

} // extern "C"

} // namespace executorch::backends::cuda
12 changes: 7 additions & 5 deletions examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,10 +629,11 @@ def _export_cuda(model, config, args):
print("Exporting decode method...")
decode_tokens = torch.tensor([[0]], dtype=torch.long)
decode_pos = torch.tensor([0], dtype=torch.long)
decode_temperature = torch.tensor([1.0], dtype=torch.float32)
with torch.no_grad():
decode_ep = export(
model,
(decode_tokens, decode_pos),
(decode_tokens, decode_pos, decode_temperature),
strict=True,
)
print("Decode export successful!")
Expand All @@ -643,18 +644,19 @@ def _export_cuda(model, config, args):
# lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes
# that reject longer prompts at runtime.
print("Exporting prefill method...")
example_prefill_len = config.max_seq_len - 1
prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long)
prefill_pos = torch.arange(example_prefill_len, dtype=torch.long)
prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
prefill_pos = torch.tensor([0, 1], dtype=torch.long)
prefill_temperature = torch.tensor([1.0], dtype=torch.float32)
seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
prefill_dynamic_shapes = (
{1: seq_dim}, # tokens
{0: seq_dim}, # input_pos
None, # temperature (static scalar)
)
with torch.no_grad():
prefill_ep = export(
model,
(prefill_tokens, prefill_pos),
(prefill_tokens, prefill_pos, prefill_temperature),
dynamic_shapes=prefill_dynamic_shapes,
strict=True,
)
Expand Down
Loading
Loading