diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index b3e82e2ec51..a5c5571d22e 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -107,9 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp runtime/shims/cuda_guard.cpp ) -# Only build int4mm shim when CUDA language/toolchain is available. +# Only build CUDA shims when CUDA language/toolchain is available. if(CMAKE_CUDA_COMPILER) list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu + runtime/shims/rand.cu runtime/shims/sort.cu ) endif() @@ -162,7 +163,7 @@ else() aoti_cuda_shims PRIVATE cuda_platform PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive - CUDA::cudart ${CMAKE_DL_LIBS} + CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS} ) endif() diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 061b0d6a29a..5c9196acfcf 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool: def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { "at::_ops::_weight_int4pack_mm::call": None, + "aoti_torch_cuda_randint_low_out": None, "at::_ops::sort_stable::call": None, } diff --git a/backends/cuda/runtime/shims/rand.cu b/backends/cuda/runtime/shims/rand.cu new file mode 100644 index 00000000000..c36b9ffb7bc --- /dev/null +++ b/backends/cuda/runtime/shims/rand.cu @@ -0,0 +1,255 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace executorch::backends::cuda { + +namespace c10 = executorch::backends::aoti::slim::c10; +using c10::Device; +using c10::DeviceIndex; +using c10::DeviceType; +using c10::ScalarType; +using executorch::backends::aoti::slim::empty_strided; +using executorch::backends::aoti::slim::IntArrayRef; +using executorch::backends::aoti::slim::makeArrayRef; + +namespace { + +// ---- GPU-resident RNG state ---- +// Seed and counter live in device memory allocated during the first call +// (warmup phase, before CUDA graph capture). The counter is atomically +// advanced by each kernel invocation on-device, so it automatically +// produces different random sequences on every CUDA graph replay. + +struct RngState { + unsigned long long seed; + unsigned long long counter; +}; + +static RngState* d_rng = nullptr; +static bool g_rng_init_done = false; + +// Initialize RNG state on the given stream. +// Must be called during warmup (before graph capture). +void ensure_rng_init(cudaStream_t stream) { + if (!g_rng_init_done) { + cudaMallocAsync(&d_rng, sizeof(RngState), stream); + RngState h; + h.seed = static_cast(time(nullptr)); + h.counter = 0; + cudaMemcpyAsync( + d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream); + // Synchronize to ensure the copy completes before we return + // (the host-side RngState `h` is on the stack). + cudaStreamSynchronize(stream); + g_rng_init_done = true; + } +} + +// Philox-based randint kernel that reads seed from device-resident state +// and atomically advances the counter. The counter pointer survives CUDA +// graph replay, so each replay produces different values. +__global__ void philox_randint_graph_kernel( + int64_t* __restrict__ out, + int64_t numel, + int64_t low, + int64_t range, + RngState* __restrict__ rng) { + // Each thread reads the seed and computes its unique offset. + // The "base offset" is read from rng->counter. We can't atomicAdd per + // thread, so we use a two-pass approach: first a single-thread kernel + // advances the counter, then the main kernel uses the old value. + // But that requires two kernel launches... + // + // Simpler: since numel=1 for randint seed generation, just one thread. + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < numel) { + // Each invocation atomically grabs `numel` slots from the counter. + // For numel=1, this is just one atomicAdd. + unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL); + curandStatePhilox4_32_10_t state; + curand_init(rng->seed, idx, my_offset, &state); + double val = curand_uniform_double(&state); + int64_t ival = static_cast(val * range); + out[idx] = low + (ival >= range ? range - 1 : ival); + } +} + +// Philox-based uniform float32 generator (graph-safe version). +__global__ void philox_rand_float_graph_kernel( + float* __restrict__ out, + int64_t numel, + RngState* __restrict__ rng) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < numel) { + unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL); + curandStatePhilox4_32_10_t state; + curand_init(rng->seed, idx, my_offset, &state); + out[idx] = curand_uniform(&state); + } +} + +// Philox-based uniform bfloat16 generator (graph-safe version). +__global__ void philox_rand_bf16_graph_kernel( + uint16_t* __restrict__ out, + int64_t numel, + RngState* __restrict__ rng) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < numel) { + unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL); + curandStatePhilox4_32_10_t state; + curand_init(rng->seed, idx, my_offset, &state); + float val = curand_uniform(&state); + uint32_t bits; + memcpy(&bits, &val, sizeof(uint32_t)); + uint32_t lsb = (bits >> 16) & 1; + bits += 0x7FFFu + lsb; + out[idx] = static_cast(bits >> 16); + } +} + +} // anonymous namespace + +extern "C" { + +AOTITorchError aoti_torch_cuda_rand( + const int64_t* size, + int64_t size_len_, + int32_t* dtype, + int32_t* layout, + int32_t* device, + int32_t device_index_, + int32_t* pin_memory, + SlimTensor** ret0) { + (void)layout; + (void)device; + (void)pin_memory; + + ET_CHECK_OR_RETURN_ERROR( + ret0 != nullptr, + InvalidArgument, + "aoti_torch_cuda_rand: ret0 is null"); + + // Default to float32 if dtype not specified. + ScalarType scalar_type = ScalarType::Float; + if (dtype != nullptr) { + scalar_type = static_cast(*dtype); + } + + // Compute contiguous strides and total elements. + std::vector strides(size_len_); + int64_t numel = 1; + for (int64_t i = size_len_ - 1; i >= 0; i--) { + strides[i] = numel; + numel *= size[i]; + } + + // Allocate output tensor. + IntArrayRef sizes_ref(size, static_cast(size_len_)); + *ret0 = new SlimTensor(empty_strided( + sizes_ref, + makeArrayRef(strides), + scalar_type, + Device(DeviceType::CUDA, static_cast(device_index_)))); + + if (numel == 0) { + return Error::Ok; + } + + // Get the current CUDA stream. + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_OR_RETURN_ERROR( + stream_result.ok(), + Internal, + "aoti_torch_cuda_rand: failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); + + ensure_rng_init(stream); + + constexpr int kThreads = 256; + int blocks = static_cast((numel + kThreads - 1) / kThreads); + + if (scalar_type == ScalarType::Float) { + philox_rand_float_graph_kernel<<>>( + static_cast((*ret0)->data_ptr()), numel, d_rng); + } else if (scalar_type == ScalarType::BFloat16) { + philox_rand_bf16_graph_kernel<<>>( + static_cast((*ret0)->data_ptr()), numel, d_rng); + } else { + ET_LOG( + Error, + "aoti_torch_cuda_rand: unsupported dtype %d", + static_cast(scalar_type)); + return Error::NotSupported; + } + + return Error::Ok; +} + +AOTITorchError aoti_torch_cuda_randint_low_out( + SlimTensor* out, + int64_t low, + int64_t high, + const int64_t* size, + int64_t size_len_) { + ET_CHECK_OR_RETURN_ERROR( + out != nullptr, + InvalidArgument, + "aoti_torch_cuda_randint_low_out: out tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + high > low, + InvalidArgument, + "aoti_torch_cuda_randint_low_out: requires high > low"); + + int64_t numel = 1; + for (int64_t i = 0; i < size_len_; i++) { + numel *= size[i]; + } + if (numel == 0) { + return Error::Ok; + } + + // Get the current CUDA stream. + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_OR_RETURN_ERROR( + stream_result.ok(), + Internal, + "aoti_torch_cuda_randint_low_out: failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); + + ensure_rng_init(stream); + + int64_t range = high - low; + int64_t* out_data = static_cast(out->data_ptr()); + + constexpr int kThreads = 256; + int blocks = static_cast((numel + kThreads - 1) / kThreads); + philox_randint_graph_kernel<<>>( + out_data, numel, low, range, d_rng); + + return Error::Ok; +} + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/rand.h b/backends/cuda/runtime/shims/rand.h new file mode 100644 index 00000000000..e0c63be75a1 --- /dev/null +++ b/backends/cuda/runtime/shims/rand.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::runtime::Error; +using AOTITorchError = Error; + +using SlimTensor = executorch::backends::aoti::slim::SlimTensor; + +extern "C" { + +/** + * Generates a tensor filled with uniform random values in [0, 1). + * + * Implements the AOTI shim for aten::rand.default on CUDA. Uses cuRAND + * Philox counter-based RNG with GPU-resident state. The counter is + * atomically advanced by each kernel invocation on-device, making it + * fully compatible with CUDA graph capture and replay — each replay + * produces different values because the counter increments on the GPU. + * + * Supports float32 and bfloat16 output dtypes. + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_rand( + const int64_t* size, + int64_t size_len_, + int32_t* dtype, + int32_t* layout, + int32_t* device, + int32_t device_index_, + int32_t* pin_memory, + SlimTensor** ret0); + +/** + * Fills a pre-allocated int64 tensor with random integers in [low, high). + * + * Implements the AOTI shim for aten::randint.low_out on CUDA. Used by + * Inductor's Philox RNG to generate random seeds. Each thread atomically + * advances a GPU-resident counter for unique offsets, making this fully + * compatible with CUDA graph capture and replay. + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out( + SlimTensor* out, + int64_t low, + int64_t high, + const int64_t* size, + int64_t size_len_); + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index d4839cb5e42..c8ee95d4965 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -636,10 +636,11 @@ def _export_cuda(model, config, args): print("Exporting decode method...") decode_tokens = torch.tensor([[0]], dtype=torch.long) decode_pos = torch.tensor([0], dtype=torch.long) + decode_temperature = torch.tensor([1.0], dtype=torch.float32) with torch.no_grad(): decode_ep = export( model, - (decode_tokens, decode_pos), + (decode_tokens, decode_pos, decode_temperature), strict=True, ) print("Decode export successful!") @@ -651,18 +652,19 @@ def _export_cuda(model, config, args): # that reject longer prompts at runtime. _set_batched_moe(model, True) print("Exporting prefill method...") - example_prefill_len = config.max_seq_len - 1 - prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long) - prefill_pos = torch.arange(example_prefill_len, dtype=torch.long) + prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long) + prefill_pos = torch.tensor([0, 1], dtype=torch.long) + prefill_temperature = torch.tensor([1.0], dtype=torch.float32) seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1) prefill_dynamic_shapes = ( {1: seq_dim}, # tokens {0: seq_dim}, # input_pos + None, # temperature (static scalar) ) with torch.no_grad(): prefill_ep = export( model, - (prefill_tokens, prefill_pos), + (prefill_tokens, prefill_pos, prefill_temperature), dynamic_shapes=prefill_dynamic_shapes, strict=True, ) diff --git a/examples/models/qwen3_5_moe/inference.py b/examples/models/qwen3_5_moe/inference.py index c824f6a6444..6f9aa933535 100644 --- a/examples/models/qwen3_5_moe/inference.py +++ b/examples/models/qwen3_5_moe/inference.py @@ -77,46 +77,45 @@ def generate( Prefills one token at a time (the recurrent path; chunked FLA via @triton_op is used for T>1 prefill in the exported PTE). + + The model performs Gumbel-max sampling on-device: forward() returns + a sampled token ID [B, 1] instead of logits [B, T, V]. """ if eos_token_ids is None: eos_token_ids = set() input_ids = tokenizer.encode(prompt).ids + # Temperature tensor (use small epsilon for greedy to avoid div-by-zero) + temp_val = max(temperature, 1e-6) + temp_tensor = torch.tensor([temp_val], dtype=torch.float32, device="cuda") + # Prefill: one token at a time with torch.no_grad(): for i, tok_id in enumerate(input_ids): tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda") pos = torch.tensor([i], dtype=torch.long, device="cuda") - logits = model(tok, pos) + sampled = model(tok, pos, temp_tensor) - # Sample first generated token - next_token = _sample(logits[:, -1, :], temperature) - generated = [next_token.item()] + # First generated token (model returns [B, 1] float token ID) + next_token_id = int(sampled.item()) + generated = [next_token_id] # Decode: one token at a time seq_len = len(input_ids) with torch.no_grad(): for i in range(max_new_tokens - 1): - pos = torch.tensor([seq_len + i], device="cuda") - logits = model(next_token.unsqueeze(0), pos) - next_token = _sample(logits[:, -1, :], temperature) - tok_id = next_token.item() - generated.append(tok_id) - if tok_id in eos_token_ids: + tok = torch.tensor([[next_token_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([seq_len + i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, temp_tensor) + next_token_id = int(sampled.item()) + generated.append(next_token_id) + if next_token_id in eos_token_ids: break return tokenizer.decode(generated) -def _sample(logits, temperature): - """Sample from logits with temperature.""" - if temperature <= 0: - return logits.argmax(dim=-1) - probs = torch.softmax(logits / temperature, dim=-1) - return torch.multinomial(probs, num_samples=1).squeeze(-1) - - def main(): parser = argparse.ArgumentParser( description="Run inference on prequantized Qwen3.5 MoE" diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index e7cd83dddc2..3e6f8032b5a 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -9,8 +9,6 @@ #include #include -#include -#include #include #include #include @@ -41,6 +39,25 @@ using ::executorch::runtime::EValue; using SizesType = executorch::aten::SizesType; +// Read a sampled token from the model output tensor [B, 1]. +// The model performs Gumbel-max sampling on-device and returns a single +// float token ID. This function copies it from GPU and casts to uint64. +static uint64_t read_token(const executorch::aten::Tensor& output) { + const void* ptr = output.const_data_ptr(); + + cudaPointerAttributes attrs; + bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + + float val; + if (on_device) { + cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); + } else { + memcpy(&val, ptr, sizeof(float)); + } + return static_cast(val); +} + int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -64,7 +81,7 @@ int main(int argc, char** argv) { return 1; } - // Create Module with share_memory_arenas=true so prefill and forward + // Create Module with share_memory_arenas=true so prefill and decode // share mutable buffers (KV cache, conv_state, recurrent_state). std::vector data_files; if (!FLAGS_data_path.empty()) { @@ -122,14 +139,23 @@ int main(int argc, char** argv) { printf("Prompt tokens: %ld\n", num_prompt_tokens); // --------------------------------------------------------------- - // Prefill or decode-only + // Temperature tensor (shared between prefill and decode) // --------------------------------------------------------------- auto S = [](int64_t v) -> SizesType { return static_cast(v); }; + // Use a very small temperature for greedy to avoid division by zero + // while keeping the Gumbel noise negligible relative to logit differences. + float temp_val = + FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); + auto temp_tensor = + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); + + // --------------------------------------------------------------- + // Prefill + // --------------------------------------------------------------- uint64_t cur_token = 0; auto prefill_start = std::chrono::steady_clock::now(); - // Chunked prefill std::vector pos_data(num_prompt_tokens); for (int64_t i = 0; i < num_prompt_tokens; i++) { pos_data[i] = i; @@ -147,6 +173,7 @@ int main(int argc, char** argv) { std::vector prefill_inputs; prefill_inputs.push_back(tokens_tensor); prefill_inputs.push_back(pos_tensor); + prefill_inputs.push_back(temp_tensor); auto prefill_result = module->execute("prefill", prefill_inputs); if (prefill_result.error() != Error::Ok) { @@ -155,10 +182,7 @@ int main(int argc, char** argv) { } auto& prefill_outputs = prefill_result.get(); - auto logits_tensor = prefill_outputs[0].toTensor(); - auto logits_ptr = - std::make_shared(std::move(logits_tensor)); - cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature); + cur_token = read_token(prefill_outputs[0].toTensor()); auto prefill_end = std::chrono::steady_clock::now(); double prefill_ms = @@ -178,7 +202,6 @@ int main(int argc, char** argv) { // --------------------------------------------------------------- // Decode — generate tokens one at a time // --------------------------------------------------------------- - llm::Stats stats; int64_t pos = num_prompt_tokens; uint64_t prev_token; @@ -198,6 +221,7 @@ int main(int argc, char** argv) { std::vector decode_inputs; decode_inputs.push_back(EValue(decode_tokens)); decode_inputs.push_back(EValue(decode_pos)); + decode_inputs.push_back(EValue(temp_tensor)); auto decode_result = module->execute("decode", decode_inputs); if (decode_result.error() != Error::Ok) { @@ -206,14 +230,8 @@ int main(int argc, char** argv) { } auto& decode_outputs = decode_result.get(); - auto step_logits = decode_outputs[0].toTensor(); - auto step_logits_ptr = - std::make_shared(std::move(step_logits)); - prev_token = cur_token; - stats.on_sampling_begin(); - cur_token = llm::logits_to_token(*step_logits_ptr, FLAGS_temperature); - stats.on_sampling_end(); + cur_token = read_token(decode_outputs[0].toTensor()); pos++; diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index 261008b43e2..d0c1078893d 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -621,13 +621,24 @@ def __init__(self, config): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def forward( - self, tokens: torch.LongTensor, input_pos: torch.LongTensor + self, + tokens: torch.LongTensor, + input_pos: torch.LongTensor, + temperature: torch.Tensor, ) -> torch.Tensor: x = self.embed_tokens(tokens) for layer in self.layers: x = layer(x, input_pos) x = self.norm(x) - return self.lm_head(x) + # Only compute logits for the last token position — avoids + # materializing the full [B, T, V] tensor during prefill. + logits = self.lm_head(x[:, -1, :]).float() # [B, V] float32 + # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) + # Equivalent to sampling from softmax(logits/T) but fully on-device. + logits = logits / temperature.clamp(min=1e-6) + noise = torch.rand_like(logits) + gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + return (logits + gumbel).argmax(dim=-1, keepdim=True).float() # [B, 1] @staticmethod def from_hf_checkpoint(model_dir, max_seq_len=4096):