Skip to content

CUDA: ~5s startup latency from compiling all 31 kernels #41

@shinaoka

Description

@shinaoka

Problem

CudaContext::new() takes ~5 seconds because NVRTC compiles all 31 kernels from tropical_gemm.cu on every startup:

| Category         | Count | Details                                   |
|------------------|-------|-------------------------------------------|
| Basic GEMM       | 12    | 4 types (f32/f64/i32/i64) x 3 semirings   |
| GEMM with Argmax | 12    | same combinations                         |
| Backward pass    | 4     | f32/f64 x grad_A/grad_B                   |
| Batched GEMM     | 3     | f32 only x 3 semirings                    |

A typical PyTorch use case (f32 MaxPlus with backward) only needs ~5 of these 31 kernels.

Proposed Solutions

Option A: PTX disk caching

Cache the compiled PTX to disk. First launch pays the ~5s cost; subsequent launches load from cache in ~10ms.

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;

/// Compute the on-disk cache location for the PTX compiled from `source`.
///
/// The filename encodes both a hash of the kernel source (so editing the
/// kernels invalidates stale entries) and the target compute capability
/// (so PTX is never reused across GPU architectures).
fn ptx_cache_path(source: &str, major: i32, minor: i32) -> PathBuf {
    // Fingerprint the source text.
    let source_hash = {
        let mut hasher = DefaultHasher::new();
        source.hash(&mut hasher);
        hasher.finish()
    };

    // Cache key MUST include GPU architecture:
    // - sm_86 PTX loaded on sm_89 -> misses newer instructions (suboptimal)
    // - sm_89 PTX loaded on sm_86 -> uses unavailable instructions (crash)
    let filename = format!("{source_hash:016x}_sm_{major}{minor}.ptx");

    // Fall back to /tmp when the platform cache directory is unavailable.
    let base = dirs::cache_dir().unwrap_or_else(|| PathBuf::from("/tmp"));
    base.join("tropical-gemm").join(filename)
}

/// Build the context for `device`, loading PTX from the on-disk cache when
/// a matching entry exists and compiling (then caching) it otherwise.
///
/// First run on a given source + architecture pays the full NVRTC compile
/// cost (~5s); later runs read the cached PTX (~10ms).
pub fn from_device(device: Arc<CudaDevice>) -> Result<Self> {
    // Compute capability determines both the NVRTC target and the cache key.
    let major = device.attribute(COMPUTE_CAPABILITY_MAJOR)?;
    let minor = device.attribute(COMPUTE_CAPABILITY_MINOR)?;

    let cache_path = ptx_cache_path(KERNEL_SOURCE, major, minor);

    let ptx = if cache_path.exists() {
        // Load cached PTX (~10ms)
        let bytes = std::fs::read(&cache_path)?;
        cudarc::driver::Ptx::from_bytes(&bytes)
    } else {
        // Compile and cache (~5s, first time only)
        let arch = format!("--gpu-architecture=sm_{}{}", major, minor);
        let ptx = cudarc::nvrtc::compile_ptx_with_opts(
            KERNEL_SOURCE,
            CompileOptions {
                arch: Some(arch),
                use_fast_math: Some(true),
                ..Default::default()
            },
        )?;

        // Save to cache (best-effort, don't fail if cache dir is unwritable)
        if let Some(parent) = cache_path.parent() {
            let _ = std::fs::create_dir_all(parent);
        }
        let _ = std::fs::write(&cache_path, ptx.as_bytes());

        ptx
    };

    // Register every kernel from the (single) compiled module.
    device.load_ptx(ptx, "tropical_gemm", KERNEL_NAMES)?;
    // ...
}

Pros: Simple, no API change, benefits all users automatically.
Cons: Still compiles all 31 kernels on first run. Cache can become stale (mitigated by source hash in key).

Option B: Lazy compilation (compile on first use)

Only compile each kernel when it's first requested. Split the .cu source into per-kernel fragments and compile individually.

use std::collections::HashMap;
use std::sync::Mutex;

/// Per-kernel source fragments for lazy compilation.
/// Each fragment includes shared preamble + one kernel instantiation.
/// Build the standalone CUDA source for a single named kernel.
///
/// The result is the shared preamble (constants, utility functions, macros)
/// followed by exactly one kernel macro instantiation, so NVRTC only has to
/// compile the kernel that was actually requested.
///
/// # Panics
/// Panics if `name` is not a known kernel.
fn kernel_source(name: &str) -> String {
    // Map the kernel name to its macro instantiation line.
    let instantiation = match name {
        "tropical_maxplus_f32_nn" =>
            "TROPICAL_GEMM_F32(tropical_maxplus_f32_nn, NEG_INF_F32, fmaxf, +)",
        "tropical_minplus_f32_nn" =>
            "TROPICAL_GEMM_F32(tropical_minplus_f32_nn, INF_F32, fminf, +)",
        // ... other kernels
        unknown => panic!("Unknown kernel: {}", unknown),
    };

    // Shared preamble sits at the front of the full kernel source.
    let preamble = &KERNEL_SOURCE[..PREAMBLE_END_OFFSET];
    format!("{}\n{}\n", preamble, instantiation)
}

/// CUDA context that compiles kernels lazily, on first use.
pub struct CudaContext {
    // Device handle all per-kernel modules are loaded onto.
    device: Arc<CudaDevice>,
    // Cache of already-compiled kernels, keyed by kernel name.
    // The Mutex serializes first-use compilation across threads.
    kernels: Mutex<HashMap<&'static str, CudaFunction>>,
    // NVRTC `--gpu-architecture=sm_XY` flag for this device.
    arch: String,
}

impl CudaContext {
    /// Create a context bound to `device_id` without compiling anything.
    ///
    /// All NVRTC work is deferred to [`CudaContext::get_kernel`], so
    /// construction itself is near-instant.
    pub fn new_on_device(device_id: usize) -> Result<Self> {
        let device = CudaDevice::new(device_id)?;

        // Query the compute capability once; the NVRTC arch flag is
        // remembered for every later per-kernel compile.
        let cc_major = device.attribute(COMPUTE_CAPABILITY_MAJOR)?;
        let cc_minor = device.attribute(COMPUTE_CAPABILITY_MINOR)?;

        Ok(Self {
            device,
            kernels: Mutex::new(HashMap::new()),
            arch: format!("--gpu-architecture=sm_{}{}", cc_major, cc_minor),
        })
    }

    /// Get a kernel, compiling it on first access (~200ms per kernel).
    pub fn get_kernel(&self, name: &'static str) -> Result<CudaFunction> {
        let mut cache = self.kernels.lock().unwrap();

        // Fast path: compiled earlier in this process.
        if let Some(func) = cache.get(name) {
            return Ok(func.clone());
        }

        // Slow path: compile just this one kernel. Note the lock stays
        // held for the duration, serializing concurrent first uses.
        let fragment = kernel_source(name);
        let opts = CompileOptions {
            arch: Some(self.arch.clone()),
            use_fast_math: Some(true),
            ..Default::default()
        };
        let ptx = cudarc::nvrtc::compile_ptx_with_opts(&fragment, opts)?;

        // Each kernel gets its own module to avoid name collisions
        let module = format!("tropical_{}", name);
        self.device.load_ptx(ptx, &module, &[name])?;

        let func = self
            .device
            .get_func(&module, name)
            .ok_or_else(|| CudaError::KernelNotFound(name.to_string()))?;

        cache.insert(name, func.clone());
        Ok(func)
    }
}

Pros: Near-zero startup (~0ms). Only pays compilation cost for kernels actually used (~200ms each).
Cons: First kernel call has latency. Needs Mutex (or could use OnceLock per kernel). More complex code.

Recommendation

Option A (caching) is simpler and should be implemented first. Option B (lazy) can be layered on top later if startup time remains a concern (e.g., for short-lived scripts that only use one semiring).

Both options can be combined: lazy compile on first use + cache each compiled kernel to disk.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions