From c7450ddfd9d68b5bfcc2b0953756b4f4de55761d Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 13 Apr 2026 11:08:56 -0700 Subject: [PATCH 1/2] Add GPU-side Gumbel-max sampling for CUDA graph compatibility --- backends/cuda/CMakeLists.txt | 6 +- backends/cuda/cuda_backend.py | 1 + backends/cuda/runtime/shims/rand.cu | 255 +++++++++++++++++++++++ backends/cuda/runtime/shims/rand.h | 65 ++++++ examples/models/qwen3_5_moe/export.py | 7 +- examples/models/qwen3_5_moe/inference.py | 35 ++-- examples/models/qwen3_5_moe/main.cpp | 53 +++-- examples/models/qwen3_5_moe/model.py | 15 +- 8 files changed, 395 insertions(+), 42 deletions(-) create mode 100644 backends/cuda/runtime/shims/rand.cu create mode 100644 backends/cuda/runtime/shims/rand.h diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 2befd78b41b..36186d0c6fe 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -107,9 +107,9 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp runtime/shims/cuda_guard.cpp ) -# Only build int4mm shim when CUDA language/toolchain is available. +# Only build CUDA shims when CUDA language/toolchain is available. 
if(CMAKE_CUDA_COMPILER) - list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu) + list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu runtime/shims/rand.cu) endif() add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources}) @@ -160,7 +160,7 @@ else() aoti_cuda_shims PRIVATE cuda_platform PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive - CUDA::cudart ${CMAKE_DL_LIBS} + CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS} ) endif() diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 661b4f2b960..e55eefd2d24 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool: def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { "at::_ops::_weight_int4pack_mm::call": None, + "aoti_torch_cuda_randint_low_out": None, } @classmethod diff --git a/backends/cuda/runtime/shims/rand.cu b/backends/cuda/runtime/shims/rand.cu new file mode 100644 index 00000000000..c36b9ffb7bc --- /dev/null +++ b/backends/cuda/runtime/shims/rand.cu @@ -0,0 +1,255 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace executorch::backends::cuda { + +namespace c10 = executorch::backends::aoti::slim::c10; +using c10::Device; +using c10::DeviceIndex; +using c10::DeviceType; +using c10::ScalarType; +using executorch::backends::aoti::slim::empty_strided; +using executorch::backends::aoti::slim::IntArrayRef; +using executorch::backends::aoti::slim::makeArrayRef; + +namespace { + +// ---- GPU-resident RNG state ---- +// Seed and counter live in device memory allocated during the first call +// (warmup phase, before CUDA graph capture). 
The counter is atomically
+// advanced by each kernel invocation on-device, so it automatically
+// produces different random sequences on every CUDA graph replay.
+
+struct RngState {
+  unsigned long long seed;
+  unsigned long long counter;
+};
+
+static RngState* d_rng = nullptr;
+static bool g_rng_init_done = false;
+
+// Initialize RNG state on the given stream.
+// Must be called during warmup (before graph capture).
+void ensure_rng_init(cudaStream_t stream) {
+  if (!g_rng_init_done) {
+    cudaMallocAsync(&d_rng, sizeof(RngState), stream);
+    RngState h;
+    h.seed = static_cast<unsigned long long>(time(nullptr));
+    h.counter = 0;
+    cudaMemcpyAsync(
+        d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream);
+    // Synchronize to ensure the copy completes before we return
+    // (the host-side RngState `h` is on the stack).
+    cudaStreamSynchronize(stream);
+    g_rng_init_done = true;
+  }
+}
+
+// Philox-based randint kernel that reads seed from device-resident state
+// and atomically advances the counter. The counter pointer survives CUDA
+// graph replay, so each replay produces different values.
+__global__ void philox_randint_graph_kernel(
+    int64_t* __restrict__ out,
+    int64_t numel,
+    int64_t low,
+    int64_t range,
+    RngState* __restrict__ rng) {
+  // Each thread reads the seed and computes its unique offset.
+  // The "base offset" is read from rng->counter. We can't atomicAdd per
+  // thread, so we use a two-pass approach: first a single-thread kernel
+  // advances the counter, then the main kernel uses the old value.
+  // But that requires two kernel launches...
+  //
+  // Simpler: since numel=1 for randint seed generation, just one thread.
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (idx < numel) {
+    // Each invocation atomically grabs `numel` slots from the counter.
+    // For numel=1, this is just one atomicAdd.
+    unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
+    curandStatePhilox4_32_10_t state;
+    curand_init(rng->seed, idx, my_offset, &state);
+    double val = curand_uniform_double(&state);
+    int64_t ival = static_cast<int64_t>(val * range);
+    out[idx] = low + (ival >= range ? range - 1 : ival);
+  }
+}
+
+// Philox-based uniform float32 generator (graph-safe version).
+__global__ void philox_rand_float_graph_kernel(
+    float* __restrict__ out,
+    int64_t numel,
+    RngState* __restrict__ rng) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (idx < numel) {
+    unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
+    curandStatePhilox4_32_10_t state;
+    curand_init(rng->seed, idx, my_offset, &state);
+    out[idx] = curand_uniform(&state);
+  }
+}
+
+// Philox-based uniform bfloat16 generator (graph-safe version).
+__global__ void philox_rand_bf16_graph_kernel(
+    uint16_t* __restrict__ out,
+    int64_t numel,
+    RngState* __restrict__ rng) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (idx < numel) {
+    unsigned long long my_offset = atomicAdd(&rng->counter, 1ULL);
+    curandStatePhilox4_32_10_t state;
+    curand_init(rng->seed, idx, my_offset, &state);
+    float val = curand_uniform(&state);
+    uint32_t bits;
+    memcpy(&bits, &val, sizeof(uint32_t));
+    uint32_t lsb = (bits >> 16) & 1;
+    bits += 0x7FFFu + lsb;
+    out[idx] = static_cast<uint16_t>(bits >> 16);
+  }
+}
+
+} // anonymous namespace
+
+extern "C" {
+
+AOTITorchError aoti_torch_cuda_rand(
+    const int64_t* size,
+    int64_t size_len_,
+    int32_t* dtype,
+    int32_t* layout,
+    int32_t* device,
+    int32_t device_index_,
+    int32_t* pin_memory,
+    SlimTensor** ret0) {
+  (void)layout;
+  (void)device;
+  (void)pin_memory;
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret0 != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda_rand: ret0 is null");
+
+  // Default to float32 if dtype not specified.
+  ScalarType scalar_type = ScalarType::Float;
+  if (dtype != nullptr) {
+    scalar_type = static_cast<ScalarType>(*dtype);
+  }
+
+  // Compute contiguous strides and total elements.
+  std::vector<int64_t> strides(size_len_);
+  int64_t numel = 1;
+  for (int64_t i = size_len_ - 1; i >= 0; i--) {
+    strides[i] = numel;
+    numel *= size[i];
+  }
+
+  // Allocate output tensor.
+  IntArrayRef sizes_ref(size, static_cast<size_t>(size_len_));
+  *ret0 = new SlimTensor(empty_strided(
+      sizes_ref,
+      makeArrayRef(strides),
+      scalar_type,
+      Device(DeviceType::CUDA, static_cast<DeviceIndex>(device_index_))));
+
+  if (numel == 0) {
+    return Error::Ok;
+  }
+
+  // Get the current CUDA stream.
+  auto stream_result = getCurrentCUDAStream(0);
+  ET_CHECK_OR_RETURN_ERROR(
+      stream_result.ok(),
+      Internal,
+      "aoti_torch_cuda_rand: failed to get CUDA stream");
+  cudaStream_t stream = stream_result.get();
+
+  ensure_rng_init(stream);
+
+  constexpr int kThreads = 256;
+  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
+
+  if (scalar_type == ScalarType::Float) {
+    philox_rand_float_graph_kernel<<<blocks, kThreads, 0, stream>>>(
+        static_cast<float*>((*ret0)->data_ptr()), numel, d_rng);
+  } else if (scalar_type == ScalarType::BFloat16) {
+    philox_rand_bf16_graph_kernel<<<blocks, kThreads, 0, stream>>>(
+        static_cast<uint16_t*>((*ret0)->data_ptr()), numel, d_rng);
+  } else {
+    ET_LOG(
+        Error,
+        "aoti_torch_cuda_rand: unsupported dtype %d",
+        static_cast<int>(scalar_type));
+    return Error::NotSupported;
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_cuda_randint_low_out(
+    SlimTensor* out,
+    int64_t low,
+    int64_t high,
+    const int64_t* size,
+    int64_t size_len_) {
+  ET_CHECK_OR_RETURN_ERROR(
+      out != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda_randint_low_out: out tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      high > low,
+      InvalidArgument,
+      "aoti_torch_cuda_randint_low_out: requires high > low");
+
+  int64_t numel = 1;
+  for (int64_t i = 0; i < size_len_; i++) {
+    numel *= size[i];
+  }
+  if (numel == 0) {
+    return Error::Ok;
+  }
+
+  // Get the current CUDA stream.
+  auto stream_result = getCurrentCUDAStream(0);
+  ET_CHECK_OR_RETURN_ERROR(
+      stream_result.ok(),
+      Internal,
+      "aoti_torch_cuda_randint_low_out: failed to get CUDA stream");
+  cudaStream_t stream = stream_result.get();
+
+  ensure_rng_init(stream);
+
+  int64_t range = high - low;
+  int64_t* out_data = static_cast<int64_t*>(out->data_ptr());
+
+  constexpr int kThreads = 256;
+  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
+  philox_randint_graph_kernel<<<blocks, kThreads, 0, stream>>>(
+      out_data, numel, low, range, d_rng);
+
+  return Error::Ok;
+}
+
+} // extern "C"
+
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/rand.h b/backends/cuda/runtime/shims/rand.h
new file mode 100644
index 00000000000..e0c63be75a1
--- /dev/null
+++ b/backends/cuda/runtime/shims/rand.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace executorch::backends::cuda {
+
+using executorch::runtime::Error;
+using AOTITorchError = Error;
+
+using SlimTensor = executorch::backends::aoti::slim::SlimTensor;
+
+extern "C" {
+
+/**
+ * Generates a tensor filled with uniform random values in [0, 1).
+ *
+ * Implements the AOTI shim for aten::rand.default on CUDA. Uses cuRAND
+ * Philox counter-based RNG with GPU-resident state. The counter is
+ * atomically advanced by each kernel invocation on-device, making it
+ * fully compatible with CUDA graph capture and replay — each replay
+ * produces different values because the counter increments on the GPU.
+ *
+ * Supports float32 and bfloat16 output dtypes.
+ */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_rand( + const int64_t* size, + int64_t size_len_, + int32_t* dtype, + int32_t* layout, + int32_t* device, + int32_t device_index_, + int32_t* pin_memory, + SlimTensor** ret0); + +/** + * Fills a pre-allocated int64 tensor with random integers in [low, high). + * + * Implements the AOTI shim for aten::randint.low_out on CUDA. Used by + * Inductor's Philox RNG to generate random seeds. Each thread atomically + * advances a GPU-resident counter for unique offsets, making this fully + * compatible with CUDA graph capture and replay. + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out( + SlimTensor* out, + int64_t low, + int64_t high, + const int64_t* size, + int64_t size_len_); + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index 19a720a2e79..e7b265b575f 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -416,10 +416,11 @@ def export_and_lower(model, config, args): print("Exporting decode method...") decode_tokens = torch.tensor([[0]], dtype=torch.long) decode_pos = torch.tensor([0], dtype=torch.long) + decode_temperature = torch.tensor([1.0], dtype=torch.float32) with torch.no_grad(): decode_ep = export( model, - (decode_tokens, decode_pos), + (decode_tokens, decode_pos, decode_temperature), strict=True, ) print("Decode export successful!") @@ -428,15 +429,17 @@ def export_and_lower(model, config, args): print("Exporting prefill method...") prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long) prefill_pos = torch.tensor([0, 1], dtype=torch.long) + prefill_temperature = torch.tensor([1.0], dtype=torch.float32) seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1) prefill_dynamic_shapes = ( {1: seq_dim}, # tokens {0: seq_dim}, # input_pos + None, # temperature (static scalar) ) with torch.no_grad(): prefill_ep = export( model, - 
(prefill_tokens, prefill_pos), + (prefill_tokens, prefill_pos, prefill_temperature), dynamic_shapes=prefill_dynamic_shapes, strict=True, ) diff --git a/examples/models/qwen3_5_moe/inference.py b/examples/models/qwen3_5_moe/inference.py index c824f6a6444..6f9aa933535 100644 --- a/examples/models/qwen3_5_moe/inference.py +++ b/examples/models/qwen3_5_moe/inference.py @@ -77,46 +77,45 @@ def generate( Prefills one token at a time (the recurrent path; chunked FLA via @triton_op is used for T>1 prefill in the exported PTE). + + The model performs Gumbel-max sampling on-device: forward() returns + a sampled token ID [B, 1] instead of logits [B, T, V]. """ if eos_token_ids is None: eos_token_ids = set() input_ids = tokenizer.encode(prompt).ids + # Temperature tensor (use small epsilon for greedy to avoid div-by-zero) + temp_val = max(temperature, 1e-6) + temp_tensor = torch.tensor([temp_val], dtype=torch.float32, device="cuda") + # Prefill: one token at a time with torch.no_grad(): for i, tok_id in enumerate(input_ids): tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda") pos = torch.tensor([i], dtype=torch.long, device="cuda") - logits = model(tok, pos) + sampled = model(tok, pos, temp_tensor) - # Sample first generated token - next_token = _sample(logits[:, -1, :], temperature) - generated = [next_token.item()] + # First generated token (model returns [B, 1] float token ID) + next_token_id = int(sampled.item()) + generated = [next_token_id] # Decode: one token at a time seq_len = len(input_ids) with torch.no_grad(): for i in range(max_new_tokens - 1): - pos = torch.tensor([seq_len + i], device="cuda") - logits = model(next_token.unsqueeze(0), pos) - next_token = _sample(logits[:, -1, :], temperature) - tok_id = next_token.item() - generated.append(tok_id) - if tok_id in eos_token_ids: + tok = torch.tensor([[next_token_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([seq_len + i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, 
temp_tensor) + next_token_id = int(sampled.item()) + generated.append(next_token_id) + if next_token_id in eos_token_ids: break return tokenizer.decode(generated) -def _sample(logits, temperature): - """Sample from logits with temperature.""" - if temperature <= 0: - return logits.argmax(dim=-1) - probs = torch.softmax(logits / temperature, dim=-1) - return torch.multinomial(probs, num_samples=1).squeeze(-1) - - def main(): parser = argparse.ArgumentParser( description="Run inference on prequantized Qwen3.5 MoE" diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 9603cfe6c78..ab41d5cf6e3 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -9,8 +9,6 @@ #include #include -#include -#include #include #include #include @@ -41,6 +39,25 @@ using ::executorch::runtime::EValue; using SizesType = executorch::aten::SizesType; +// Read a sampled token from the model output tensor [B, 1]. +// The model performs Gumbel-max sampling on-device and returns a single +// float token ID. This function copies it from GPU and casts to uint64. +static uint64_t read_token(const executorch::aten::Tensor& output) { + const void* ptr = output.const_data_ptr(); + + cudaPointerAttributes attrs; + bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + + float val; + if (on_device) { + cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); + } else { + memcpy(&val, ptr, sizeof(float)); + } + return static_cast(val); +} + int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -64,7 +81,7 @@ int main(int argc, char** argv) { return 1; } - // Create Module with share_memory_arenas=true so prefill and forward + // Create Module with share_memory_arenas=true so prefill and decode // share mutable buffers (KV cache, conv_state, recurrent_state). 
std::vector data_files; if (!FLAGS_data_path.empty()) { @@ -134,14 +151,24 @@ int main(int argc, char** argv) { printf("Prompt tokens: %ld\n", num_prompt_tokens); // --------------------------------------------------------------- - // Prefill or decode-only + // Temperature tensor (shared between prefill and decode) // --------------------------------------------------------------- auto S = [](int64_t v) -> SizesType { return static_cast(v); }; + // Use a very small temperature for greedy to avoid division by zero + // while keeping the Gumbel noise negligible relative to logit differences. + float temp_val = FLAGS_temperature <= 0.0 + ? 1e-6f + : static_cast(FLAGS_temperature); + auto temp_tensor = from_blob( + &temp_val, {1}, executorch::aten::ScalarType::Float); + + // --------------------------------------------------------------- + // Prefill + // --------------------------------------------------------------- uint64_t cur_token = 0; auto prefill_start = std::chrono::steady_clock::now(); - // Chunked prefill std::vector pos_data(num_prompt_tokens); for (int64_t i = 0; i < num_prompt_tokens; i++) { pos_data[i] = i; @@ -159,6 +186,7 @@ int main(int argc, char** argv) { std::vector prefill_inputs; prefill_inputs.push_back(tokens_tensor); prefill_inputs.push_back(pos_tensor); + prefill_inputs.push_back(temp_tensor); auto prefill_result = module->execute(prefill_method, prefill_inputs); if (prefill_result.error() != Error::Ok) { @@ -167,10 +195,7 @@ int main(int argc, char** argv) { } auto& prefill_outputs = prefill_result.get(); - auto logits_tensor = prefill_outputs[0].toTensor(); - auto logits_ptr = - std::make_shared(std::move(logits_tensor)); - cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature); + cur_token = read_token(prefill_outputs[0].toTensor()); auto prefill_end = std::chrono::steady_clock::now(); double prefill_ms = @@ -195,7 +220,6 @@ int main(int argc, char** argv) { // --------------------------------------------------------------- // 
Decode — generate tokens one at a time // --------------------------------------------------------------- - llm::Stats stats; int64_t pos = num_prompt_tokens; uint64_t prev_token; @@ -215,6 +239,7 @@ int main(int argc, char** argv) { std::vector decode_inputs; decode_inputs.push_back(EValue(decode_tokens)); decode_inputs.push_back(EValue(decode_pos)); + decode_inputs.push_back(EValue(temp_tensor)); auto decode_result = module->execute("decode", decode_inputs); if (decode_result.error() != Error::Ok) { @@ -223,14 +248,8 @@ int main(int argc, char** argv) { } auto& decode_outputs = decode_result.get(); - auto step_logits = decode_outputs[0].toTensor(); - auto step_logits_ptr = - std::make_shared(std::move(step_logits)); - prev_token = cur_token; - stats.on_sampling_begin(); - cur_token = llm::logits_to_token(*step_logits_ptr, FLAGS_temperature); - stats.on_sampling_end(); + cur_token = read_token(decode_outputs[0].toTensor()); pos++; diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index 751915fb123..6cc3527bd6c 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -594,13 +594,24 @@ def __init__(self, config): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def forward( - self, tokens: torch.LongTensor, input_pos: torch.LongTensor + self, + tokens: torch.LongTensor, + input_pos: torch.LongTensor, + temperature: torch.Tensor, ) -> torch.Tensor: x = self.embed_tokens(tokens) for layer in self.layers: x = layer(x, input_pos) x = self.norm(x) - return self.lm_head(x) + # Only compute logits for the last token position — avoids + # materializing the full [B, T, V] tensor during prefill. + logits = self.lm_head(x[:, -1, :]).float() # [B, V] float32 + # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) + # Equivalent to sampling from softmax(logits/T) but fully on-device. 
+ logits = logits / temperature.clamp(min=1e-6) + noise = torch.rand_like(logits) + gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + return (logits + gumbel).argmax(dim=-1, keepdim=True).float() # [B, 1] @staticmethod def from_hf_checkpoint(model_dir, max_seq_len=4096): From 028894ef8e450350cdb4e7c2619102c5223bd0d2 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 13 Apr 2026 14:26:07 -0700 Subject: [PATCH 2/2] lintrunner --- backends/cuda/CMakeLists.txt | 4 ++- backends/cuda/runtime/cuda_backend.cpp | 32 ++++++++++++-------- backends/cuda/runtime/cuda_delegate_handle.h | 2 +- examples/models/qwen3_5_moe/export.py | 2 +- examples/models/qwen3_5_moe/main.cpp | 9 +++--- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 36186d0c6fe..78daf8d2010 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -109,7 +109,9 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp # Only build CUDA shims when CUDA language/toolchain is available. 
if(CMAKE_CUDA_COMPILER) - list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu runtime/shims/rand.cu) + list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu + runtime/shims/rand.cu + ) endif() add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources}) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 927f418f8f0..2f31439ab08 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -80,8 +80,7 @@ namespace { constexpr char kSkipCopyOutputToCpuForMethod[] = "skip_copy_output_to_cpu_for_method"; constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream"; -constexpr char kEnableCudaGraphForMethod[] = - "enable_cuda_graph_for_method"; +constexpr char kEnableCudaGraphForMethod[] = "enable_cuda_graph_for_method"; constexpr int kCudaGraphWarmupSteps = 3; } // anonymous namespace @@ -410,7 +409,9 @@ class ET_EXPERIMENTAL CudaBackend final cudaDeviceSynchronize(); buffer_res->Free(); } else { - ET_LOG(Info, "weights_blob '%s' not found or update fn is null", + ET_LOG( + Info, + "weights_blob '%s' not found or update fn is null", weights_blob_key.c_str()); } @@ -649,13 +650,17 @@ class ET_EXPERIMENTAL CudaBackend final void* static_ptr = nullptr; cudaError_t merr = cudaMalloc(&static_ptr, nbytes); ET_CHECK_OR_RETURN_ERROR( - merr == cudaSuccess, Internal, + merr == cudaSuccess, + Internal, "cudaMalloc for static input %zu failed: %s", - i, cudaGetErrorString(merr)); + i, + cudaGetErrorString(merr)); cudaMemcpy( - static_ptr, cpu_tensor->const_data_ptr(), - nbytes, cudaMemcpyHostToDevice); + static_ptr, + cpu_tensor->const_data_ptr(), + nbytes, + cudaMemcpyHostToDevice); handle->static_input_ptrs.push_back(static_ptr); handle->static_input_sizes.push_back(sizes_vec); @@ -669,7 +674,8 @@ class ET_EXPERIMENTAL CudaBackend final slim::makeArrayRef(sizes_vec), slim::makeArrayRef(strides_vec), static_cast(cpu_tensor->scalar_type()), - DEFAULT_CUDA_DEVICE, 0)); + 
DEFAULT_CUDA_DEVICE, + 0)); continue; } @@ -755,8 +761,8 @@ class ET_EXPERIMENTAL CudaBackend final "CUDA graph: beginning stream capture for '%s'", handle->method_name.c_str()); - cudaError_t cerr = cudaStreamBeginCapture( - cuda_stream, cudaStreamCaptureModeRelaxed); + cudaError_t cerr = + cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeRelaxed); ET_CHECK_OR_RETURN_ERROR( cerr == cudaSuccess, Internal, @@ -791,8 +797,7 @@ class ET_EXPERIMENTAL CudaBackend final if (is_capture_step) { // End capture → instantiate graph - cudaError_t gerr = - cudaStreamEndCapture(cuda_stream, &handle->cuda_graph); + cudaError_t gerr = cudaStreamEndCapture(cuda_stream, &handle->cuda_graph); ET_CHECK_OR_RETURN_ERROR( gerr == cudaSuccess, Internal, @@ -800,7 +805,8 @@ class ET_EXPERIMENTAL CudaBackend final cudaGetErrorString(gerr)); gerr = cudaGraphInstantiate( - &handle->cuda_graph_exec, handle->cuda_graph, + &handle->cuda_graph_exec, + handle->cuda_graph, cudaGraphInstantiateFlagAutoFreeOnLaunch); ET_CHECK_OR_RETURN_ERROR( gerr == cudaSuccess, diff --git a/backends/cuda/runtime/cuda_delegate_handle.h b/backends/cuda/runtime/cuda_delegate_handle.h index 2d37e6cebcb..33a8a51a1a1 100644 --- a/backends/cuda/runtime/cuda_delegate_handle.h +++ b/backends/cuda/runtime/cuda_delegate_handle.h @@ -73,7 +73,7 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle { // These hold the tensor metadata; the underlying data pointers are fixed // addresses that CUDA graph replay will write to / read from. // SlimTensor pointers — owned by this handle. 
-  std::vector<void*> static_input_ptrs;  // raw GPU data pointers for inputs
+  std::vector<void*> static_input_ptrs; // raw GPU data pointers for inputs
   std::vector<void*> static_output_ptrs; // raw GPU data pointers for outputs
   std::vector<std::vector<int64_t>> static_input_sizes;
   std::vector<std::vector<int64_t>> static_input_strides;
diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py
index e7b265b575f..20dce405c24 100644
--- a/examples/models/qwen3_5_moe/export.py
+++ b/examples/models/qwen3_5_moe/export.py
@@ -434,7 +434,7 @@ def export_and_lower(model, config, args):
     prefill_dynamic_shapes = (
         {1: seq_dim},  # tokens
         {0: seq_dim},  # input_pos
-        None, # temperature (static scalar)
+        None,  # temperature (static scalar)
     )
     with torch.no_grad():
         prefill_ep = export(
diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp
index ab41d5cf6e3..cd8206e14e2 100644
--- a/examples/models/qwen3_5_moe/main.cpp
+++ b/examples/models/qwen3_5_moe/main.cpp
@@ -157,11 +157,10 @@ int main(int argc, char** argv) {
   // Use a very small temperature for greedy to avoid division by zero
   // while keeping the Gumbel noise negligible relative to logit differences.
-  float temp_val = FLAGS_temperature <= 0.0
-      ? 1e-6f
-      : static_cast<float>(FLAGS_temperature);
-  auto temp_tensor = from_blob(
-      &temp_val, {1}, executorch::aten::ScalarType::Float);
+  float temp_val =
+      FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
+  auto temp_tensor =
+      from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
 
   // ---------------------------------------------------------------
   // Prefill