## Problem

`CudaContext::new()` takes ~5s because NVRTC compiles all 31 kernels from `tropical_gemm.cu` on every startup:
| Category | Count | Details |
|---|---|---|
| Basic GEMM | 12 | 4 types (f32/f64/i32/i64) x 3 semirings |
| GEMM with Argmax | 12 | same combinations |
| Backward pass | 4 | f32/f64 x grad_A/grad_B |
| Batched GEMM | 3 | f32 only x 3 semirings |
A typical PyTorch use case (f32 MaxPlus with backward) only needs ~5 of these 31 kernels.
## Proposed Solutions

### Option A: PTX disk caching

Cache the compiled PTX to disk. The first launch pays the ~5s compile cost; subsequent launches load from the cache in ~10ms.
```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;

fn ptx_cache_path(source: &str, major: i32, minor: i32) -> PathBuf {
    let mut hasher = DefaultHasher::new();
    source.hash(&mut hasher);
    let source_hash = hasher.finish();
    // Cache key MUST include GPU architecture:
    // - sm_86 PTX loaded on sm_89 -> misses newer instructions (suboptimal)
    // - sm_89 PTX loaded on sm_86 -> uses unavailable instructions (crash)
    let filename = format!("{:016x}_sm_{}{}.ptx", source_hash, major, minor);
    dirs::cache_dir()
        .unwrap_or_else(|| PathBuf::from("/tmp"))
        .join("tropical-gemm")
        .join(filename)
}

pub fn from_device(device: Arc<CudaDevice>) -> Result<Self> {
    let major = device.attribute(COMPUTE_CAPABILITY_MAJOR)?;
    let minor = device.attribute(COMPUTE_CAPABILITY_MINOR)?;
    let cache_path = ptx_cache_path(KERNEL_SOURCE, major, minor);

    let ptx = if cache_path.exists() {
        // Load cached PTX (~10ms)
        let bytes = std::fs::read(&cache_path)?;
        cudarc::driver::Ptx::from_bytes(&bytes)
    } else {
        // Compile and cache (~5s, first time only)
        let arch = format!("--gpu-architecture=sm_{}{}", major, minor);
        let ptx = cudarc::nvrtc::compile_ptx_with_opts(
            KERNEL_SOURCE,
            CompileOptions {
                arch: Some(arch),
                use_fast_math: Some(true),
                ..Default::default()
            },
        )?;
        // Save to cache (best-effort, don't fail if cache dir is unwritable)
        if let Some(parent) = cache_path.parent() {
            let _ = std::fs::create_dir_all(parent);
        }
        let _ = std::fs::write(&cache_path, ptx.as_bytes());
        ptx
    };

    device.load_ptx(ptx, "tropical_gemm", KERNEL_NAMES)?;
    // ...
}
```
Pros: Simple, no API change, benefits all users automatically.
Cons: Still compiles all 31 kernels on first run. Cache can become stale (mitigated by source hash in key).
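Because the source hash is part of the cache key, a stale entry is simply never looked up: editing the kernel source (or moving to a different GPU architecture) produces a different filename. A minimal sketch of the key scheme, using a hypothetical `cache_key` helper that mirrors `ptx_cache_path`'s filename logic:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Hypothetical helper mirroring ptx_cache_path's key scheme:
// hash of the kernel source, plus the SM architecture.
fn cache_key(source: &str, major: i32, minor: i32) -> String {
    let mut hasher = DefaultHasher::new();
    source.hash(&mut hasher);
    // Any change to the source or to the target architecture
    // yields a different filename, so stale PTX is never loaded.
    format!("{:016x}_sm_{}{}.ptx", hasher.finish(), major, minor)
}
```

Invalidation is thus implicit: old cache files are not reused, only orphaned (a periodic cleanup of the cache directory could reclaim them).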
### Option B: Lazy compilation (compile on first use)

Only compile each kernel when it is first requested. Split the `.cu` source into per-kernel fragments and compile each fragment individually.
```rust
use std::collections::HashMap;
use std::sync::Mutex;

/// Per-kernel source fragments for lazy compilation.
/// Each fragment includes the shared preamble plus one kernel instantiation.
fn kernel_source(name: &str) -> String {
    // Shared preamble: constants, utility functions, macros
    let preamble = &KERNEL_SOURCE[..PREAMBLE_END_OFFSET];
    // Kernel-specific instantiation
    let instantiation = match name {
        "tropical_maxplus_f32_nn" =>
            "TROPICAL_GEMM_F32(tropical_maxplus_f32_nn, NEG_INF_F32, fmaxf, +)",
        "tropical_minplus_f32_nn" =>
            "TROPICAL_GEMM_F32(tropical_minplus_f32_nn, INF_F32, fminf, +)",
        // ... other kernels
        _ => panic!("Unknown kernel: {}", name),
    };
    format!("{}\n{}\n", preamble, instantiation)
}

pub struct CudaContext {
    device: Arc<CudaDevice>,
    kernels: Mutex<HashMap<&'static str, CudaFunction>>,
    arch: String,
}

impl CudaContext {
    pub fn new_on_device(device_id: usize) -> Result<Self> {
        let device = CudaDevice::new(device_id)?;
        let major = device.attribute(COMPUTE_CAPABILITY_MAJOR)?;
        let minor = device.attribute(COMPUTE_CAPABILITY_MINOR)?;
        let arch = format!("--gpu-architecture=sm_{}{}", major, minor);
        Ok(Self {
            device,
            kernels: Mutex::new(HashMap::new()),
            arch,
        })
    }

    /// Get a kernel, compiling it on first access (~200ms per kernel).
    pub fn get_kernel(&self, name: &'static str) -> Result<CudaFunction> {
        let mut kernels = self.kernels.lock().unwrap();
        if let Some(func) = kernels.get(name) {
            return Ok(func.clone());
        }
        // Compile just this one kernel
        let source = kernel_source(name);
        let ptx = cudarc::nvrtc::compile_ptx_with_opts(
            &source,
            CompileOptions {
                arch: Some(self.arch.clone()),
                use_fast_math: Some(true),
                ..Default::default()
            },
        )?;
        // Each kernel gets its own module to avoid name collisions
        let module_name = format!("tropical_{}", name);
        self.device.load_ptx(ptx, &module_name, &[name])?;
        let func = self
            .device
            .get_func(&module_name, name)
            .ok_or_else(|| CudaError::KernelNotFound(name.to_string()))?;
        kernels.insert(name, func.clone());
        Ok(func)
    }
}
```
Pros: Near-zero startup. Compilation cost (~200ms per kernel) is paid only for kernels actually used.
Cons: The first call to each kernel pays that compilation latency. Needs a `Mutex` (or a per-kernel `OnceLock`). More complex code.
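The `OnceLock` alternative mentioned above would replace the shared `Mutex<HashMap<...>>` with one cell per kernel: initialization runs at most once, and subsequent reads are lock-free. A minimal sketch, with a placeholder `Kernel` type standing in for `CudaFunction` (the real type requires a GPU context):

```rust
use std::sync::OnceLock;

// Placeholder for CudaFunction; the real handle comes from load_ptx/get_func.
#[derive(Clone)]
struct Kernel(String);

// One static OnceLock per kernel. get_or_init runs the closure exactly
// once, even under concurrent access; later calls just read the cell.
static MAXPLUS_F32: OnceLock<Kernel> = OnceLock::new();

fn get_maxplus_f32() -> Kernel {
    MAXPLUS_F32
        .get_or_init(|| {
            // In the real code this would call compile_ptx_with_opts and
            // load_ptx; here we construct a placeholder handle instead.
            Kernel("tropical_maxplus_f32_nn".to_string())
        })
        .clone()
}
```

The trade-off: a static per kernel means 31 statics (or a macro to generate them), whereas the `Mutex<HashMap>` version keeps one code path for all kernels.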
## Recommendation
Option A (caching) is simpler and should be implemented first. Option B (lazy) can be layered on top later if startup time remains a concern (e.g., for short-lived scripts that only use one semiring).
Both options can be combined: lazy compile on first use + cache each compiled kernel to disk.
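The combined lookup order would be: in-memory map, then disk cache, then compile. A rough sketch of that flow, with in-process `HashMap`s standing in for the real filesystem cache and NVRTC call (both hypothetical stand-ins, so the logic runs without a GPU):

```rust
use std::collections::HashMap;
use std::sync::Mutex;

// Sketch of the combined scheme: memory hit -> disk hit -> compile.
struct LazyCache {
    memory: Mutex<HashMap<String, Vec<u8>>>, // loaded this process
    disk: Mutex<HashMap<String, Vec<u8>>>,   // stands in for the fs cache
}

impl LazyCache {
    fn new() -> Self {
        Self {
            memory: Mutex::new(HashMap::new()),
            disk: Mutex::new(HashMap::new()),
        }
    }

    fn get(&self, name: &str, compile: impl Fn() -> Vec<u8>) -> Vec<u8> {
        if let Some(ptx) = self.memory.lock().unwrap().get(name) {
            return ptx.clone(); // hot path: already loaded this process
        }
        // Clone out of the lock before possibly re-locking below.
        let cached = self.disk.lock().unwrap().get(name).cloned();
        let ptx = match cached {
            Some(bytes) => bytes, // cold start, warm disk (~10ms)
            None => {
                let fresh = compile(); // first ever use (~200ms)
                self.disk
                    .lock()
                    .unwrap()
                    .insert(name.to_string(), fresh.clone());
                fresh
            }
        };
        self.memory.lock().unwrap().insert(name.to_string(), ptx.clone());
        ptx
    }
}
```

With this layering, a long-lived training process pays NVRTC at most once per kernel ever, and a short-lived script pays only for the handful of kernels it touches.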