From 4a589bf73080a4d7ac8e839ca4eaceb3acd87396 Mon Sep 17 00:00:00 2001 From: Casey Sanchez Date: Fri, 5 Jun 2026 13:23:22 -0400 Subject: [PATCH 1/4] feat(cufft): add cufft_raw bindings, cufft safe wrapper, and fft example --- Cargo.lock | 26 +++ Cargo.toml | 3 + crates/cufft/Cargo.toml | 11 ++ crates/cufft/src/error.rs | 81 ++++++++ crates/cufft/src/lib.rs | 14 ++ crates/cufft/src/plan.rs | 317 +++++++++++++++++++++++++++++++ crates/cufft_raw/Cargo.toml | 16 ++ crates/cufft_raw/build/main.rs | 69 +++++++ crates/cufft_raw/build/wrapper.h | 1 + crates/cufft_raw/src/lib.rs | 8 + examples/fft/Cargo.toml | 9 + examples/fft/src/main.rs | 82 ++++++++ 12 files changed, 637 insertions(+) create mode 100644 crates/cufft/Cargo.toml create mode 100644 crates/cufft/src/error.rs create mode 100644 crates/cufft/src/lib.rs create mode 100644 crates/cufft/src/plan.rs create mode 100644 crates/cufft_raw/Cargo.toml create mode 100644 crates/cufft_raw/build/main.rs create mode 100644 crates/cufft_raw/build/wrapper.h create mode 100644 crates/cufft_raw/src/lib.rs create mode 100644 examples/fft/Cargo.toml create mode 100644 examples/fft/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 1afa046a..4a64dec5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -854,6 +854,23 @@ dependencies = [ "cust_raw", ] +[[package]] +name = "cufft" +version = "0.1.0" +dependencies = [ + "cufft_raw", + "cust", +] + +[[package]] +name = "cufft_raw" +version = "0.1.0" +dependencies = [ + "bindgen", + "cust_core", + "cust_raw", +] + [[package]] name = "curl" version = "0.4.49" @@ -1152,6 +1169,15 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "fft" +version = "0.1.0" +dependencies = [ + "cufft", + "cufft_raw", + "cust", +] + [[package]] name = "filetime" version = "0.2.26" diff --git a/Cargo.toml b/Cargo.toml index 108b31a6..3fb5fade 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,8 @@ members = [ "crates/cuda_std_macros", "crates/cudnn", "crates/cudnn-sys", + "crates/cufft", + "crates/cufft_raw", "crates/cust", "crates/cust_core", "crates/cust_derive", @@ -49,6 +51,7 @@ members = [ "examples/i128_demo/kernels", "examples/sha2_crates_io", "examples/sha2_crates_io/kernels", + "examples/fft", "examples/vecadd", "examples/vecadd/kernels", diff --git a/crates/cufft/Cargo.toml b/crates/cufft/Cargo.toml new file mode 100644 index 00000000..e70a5f95 --- /dev/null +++ b/crates/cufft/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "cufft" +version = "0.1.0" +edition = "2024" +license = "MIT OR Apache-2.0" +repository = "https://github.com/Rust-GPU/rust-cuda" +readme = "../../README.md" + +[dependencies] +cufft_raw = { version = "0.1.0", path = "../cufft_raw" } +cust = { version = "0.3.2", path = "../cust" } diff --git a/crates/cufft/src/error.rs b/crates/cufft/src/error.rs new file mode 100644 index 00000000..0354e538 --- /dev/null +++ b/crates/cufft/src/error.rs @@ -0,0 +1,81 @@ +use std::error::Error; +use std::fmt::Display; + +/// Error type for cuFFT operations. +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum CufftError { + /// The plan handle is invalid. + InvalidPlan, + /// Memory allocation failed. + AllocFailed, + /// The transform type is invalid. + InvalidType, + /// An invalid value was provided. + InvalidValue, + /// An internal cuFFT error occurred. + InternalError, + /// The transform failed to execute. + ExecFailed, + /// The library failed to initialize. + SetupFailed, + /// The transform size is invalid. + InvalidSize, + /// The device is invalid. + InvalidDevice, + /// No workspace has been provided. + NoWorkspace, + /// This feature is not implemented. + NotImplemented, + /// This feature is not supported. + NotSupported, +} + +impl Display for CufftError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let msg = match self { + CufftError::InvalidPlan => "invalid plan handle", + CufftError::AllocFailed => "allocation failed", + CufftError::InvalidType => "invalid type", + CufftError::InvalidValue => "invalid value", + CufftError::InternalError => "internal error", + CufftError::ExecFailed => "exec failed", + CufftError::SetupFailed => "setup failed", + CufftError::InvalidSize => "invalid size", + CufftError::InvalidDevice => "invalid device", + CufftError::NoWorkspace => "no workspace", + CufftError::NotImplemented => "not implemented", + CufftError::NotSupported => "not supported", + }; + f.write_str(msg) + } +} + +impl Error for CufftError {} + +pub trait IntoResult { + fn into_result(self) -> Result<(), CufftError>; +} + +impl IntoResult for cufft_raw::cufftResult { + fn into_result(self) -> Result<(), CufftError> { + use cufft_raw::cufftResult::*; + Err(match self { + CUFFT_SUCCESS => return Ok(()), + CUFFT_INVALID_PLAN => CufftError::InvalidPlan, + CUFFT_ALLOC_FAILED => CufftError::AllocFailed, + CUFFT_INVALID_TYPE => CufftError::InvalidType, + CUFFT_INVALID_VALUE => CufftError::InvalidValue, + CUFFT_INTERNAL_ERROR => CufftError::InternalError, + CUFFT_EXEC_FAILED => CufftError::ExecFailed, + CUFFT_SETUP_FAILED => CufftError::SetupFailed, + CUFFT_INVALID_SIZE => CufftError::InvalidSize, + CUFFT_UNALIGNED_DATA => CufftError::InvalidValue, + CUFFT_INVALID_DEVICE => CufftError::InvalidDevice, + CUFFT_NO_WORKSPACE => CufftError::NoWorkspace, + CUFFT_NOT_IMPLEMENTED => CufftError::NotImplemented, + CUFFT_NOT_SUPPORTED => CufftError::NotSupported, + _ => CufftError::InternalError, + }) + } +} diff --git a/crates/cufft/src/lib.rs b/crates/cufft/src/lib.rs new file mode 100644 index 00000000..0cf3f4f7 --- /dev/null +++ b/crates/cufft/src/lib.rs @@ -0,0 +1,14 @@ +//! Rust wrapper for the [cuFFT library](https://docs.nvidia.com/cuda/cufft/). +//! +//! Create a `FftPlan` with one of the plan constructors. +//! Optionally attach a stream with `FftPlan::set_stream` +//! Execute the plan with one of the `exec_*` methods. +//! Plans are destroyed when they dropped. +//! +//! Raw bindgen bindings are available in `cufft_raw`. + +mod error; +mod plan; + +pub use error::{CufftError, IntoResult}; +pub use plan::{Direction, FftPlan, FftType}; diff --git a/crates/cufft/src/plan.rs b/crates/cufft/src/plan.rs new file mode 100644 index 00000000..f24a6916 --- /dev/null +++ b/crates/cufft/src/plan.rs @@ -0,0 +1,317 @@ +use std::mem::MaybeUninit; + +use cust::memory::GpuBuffer; + +use crate::{CufftError, IntoResult}; + +/// cuFFT transform type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum FftType { + /// Single-precision real-to-complex. + R2C, + /// Single-precision complex-to-real. + C2R, + /// Single-precision complex-to-complex. + C2C, + /// Double-precision real-to-complex. + D2Z, + /// Double-precision complex-to-real. + Z2D, + /// Double-precision complex-to-complex. + Z2Z, +} + +impl FftType { + fn into_raw(self) -> cufft_raw::cufftType { + use cufft_raw::cufftType::*; + match self { + FftType::R2C => CUFFT_R2C, + FftType::C2R => CUFFT_C2R, + FftType::C2C => CUFFT_C2C, + FftType::D2Z => CUFFT_D2Z, + FftType::Z2D => CUFFT_Z2D, + FftType::Z2Z => CUFFT_Z2Z, + } + } +} + +/// FFT direction. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Direction { + /// Forward DFT (CUFFT_FORWARD = -1). + Forward, + /// Inverse DFT (CUFFT_INVERSE = +1). + Inverse, +} + +impl Direction { + pub(crate) fn into_raw(self) -> i32 { + match self { + Direction::Forward => cufft_raw::CUFFT_FORWARD as i32, + Direction::Inverse => cufft_raw::CUFFT_INVERSE as i32, + } + } +} + +/// Wrapper for a `cufftHandle`. +#[derive(Debug)] +pub struct FftPlan { + pub(crate) raw: cufft_raw::cufftHandle, +} + +impl FftPlan { + /// Creates a 1-D FFT plan. + /// + /// # Reference + /// + /// [cufftPlan1d](https://docs.nvidia.com/cuda/cufft/index.html#cufftplan1d) + pub fn plan_1d(nx: i32, fft_type: FftType, batch: i32) -> Result { + let mut raw = MaybeUninit::uninit(); + + unsafe { + cufft_raw::cufftPlan1d(raw.as_mut_ptr(), nx, fft_type.into_raw(), batch) + .into_result()?; + + Ok(Self { + raw: raw.assume_init(), + }) + } + } + + /// Creates a 2-D FFT plan. + /// + /// # Reference + /// + /// [cufftPlan2d](https://docs.nvidia.com/cuda/cufft/index.html#cufftplan2d) + pub fn plan_2d(nx: i32, ny: i32, fft_type: FftType) -> Result { + let mut raw = MaybeUninit::uninit(); + + unsafe { + cufft_raw::cufftPlan2d(raw.as_mut_ptr(), nx, ny, fft_type.into_raw()).into_result()?; + + Ok(Self { + raw: raw.assume_init(), + }) + } + } + + /// Creates a 3-D FFT plan. + /// + /// # Reference + /// + /// [cufftPlan3d](https://docs.nvidia.com/cuda/cufft/index.html#cufftplan3d) + pub fn plan_3d(nx: i32, ny: i32, nz: i32, fft_type: FftType) -> Result { + let mut raw = MaybeUninit::uninit(); + + unsafe { + cufft_raw::cufftPlan3d(raw.as_mut_ptr(), nx, ny, nz, fft_type.into_raw()) + .into_result()?; + + Ok(Self { + raw: raw.assume_init(), + }) + } + } + + /// Creates a batched N-D FFT plan with full stride/distance control. + /// + /// # Reference + /// + /// [cufftPlanMany](https://docs.nvidia.com/cuda/cufft/index.html#cufftplanmany) + #[allow(clippy::too_many_arguments)] + pub fn plan_many( + rank: i32, + n: &[i32], + inembed: Option<&[i32]>, + istride: i32, + idist: i32, + onembed: Option<&[i32]>, + ostride: i32, + odist: i32, + fft_type: FftType, + batch: i32, + ) -> Result { + let mut raw = MaybeUninit::uninit(); + + let inembed_ptr = inembed.map_or(std::ptr::null_mut(), |s| s.as_ptr().cast_mut()); + let onembed_ptr = onembed.map_or(std::ptr::null_mut(), |s| s.as_ptr().cast_mut()); + + unsafe { + cufft_raw::cufftPlanMany( + raw.as_mut_ptr(), + rank, + n.as_ptr().cast_mut(), + inembed_ptr, + istride, + idist, + onembed_ptr, + ostride, + odist, + fft_type.into_raw(), + batch, + ) + .into_result()?; + + Ok(Self { + raw: raw.assume_init(), + }) + } + } + + /// Set the CUDA stream for the plan. + /// + /// # Reference + /// + /// [cufftSetStream](https://docs.nvidia.com/cuda/cufft/index.html#cufftsetstream) + pub fn set_stream(&mut self, stream: &cust::stream::Stream) -> Result<(), CufftError> { + unsafe { + cufft_raw::cufftSetStream( + self.raw, + stream.as_inner() as *mut _ as cufft_raw::cudaStream_t, + ) + .into_result() + } + } + + /// Returns the raw `cufftHandle`. + pub fn as_raw(&self) -> cufft_raw::cufftHandle { + self.raw + } + + /// Executes a single-precision C2C FFT. + /// + /// # Reference + /// + /// [cufftExecC2C](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecc2c-and-cufftexecz2z) + pub fn exec_c2c( + &self, + idata: &impl GpuBuffer, + odata: &mut impl GpuBuffer, + direction: Direction, + ) -> Result<(), CufftError> { + unsafe { + cufft_raw::cufftExecC2C( + self.raw, + idata.as_device_ptr().as_mut_ptr(), + odata.as_device_ptr().as_mut_ptr(), + direction.into_raw(), + ) + .into_result() + } + } + + /// Executes a single-precision R2C FFT. + /// + /// # Reference + /// + /// [cufftExecR2C](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecr2c-and-cufftexecd2z) + pub fn exec_r2c( + &self, + idata: &impl GpuBuffer, + odata: &mut impl GpuBuffer, + ) -> Result<(), CufftError> { + unsafe { + cufft_raw::cufftExecR2C( + self.raw, + idata.as_device_ptr().as_mut_ptr(), + odata.as_device_ptr().as_mut_ptr(), + ) + .into_result() + } + } + + /// Executes a single-precision C2R inverse FFT. + /// + /// # Reference + /// + /// [cufftExecC2R](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecc2r-and-cufftexecz2d) + pub fn exec_c2r( + &self, + idata: &impl GpuBuffer, + odata: &mut impl GpuBuffer, + ) -> Result<(), CufftError> { + unsafe { + cufft_raw::cufftExecC2R( + self.raw, + idata.as_device_ptr().as_mut_ptr(), + odata.as_device_ptr().as_mut_ptr(), + ) + .into_result() + } + } + + /// Executes a double-precision Z2Z FFT. + /// + /// # Reference + /// + /// [cufftExecZ2Z](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecc2c-and-cufftexecz2z) + pub fn exec_z2z( + &self, + idata: &impl GpuBuffer, + odata: &mut impl GpuBuffer, + direction: Direction, + ) -> Result<(), CufftError> { + unsafe { + cufft_raw::cufftExecZ2Z( + self.raw, + idata.as_device_ptr().as_mut_ptr(), + odata.as_device_ptr().as_mut_ptr(), + direction.into_raw(), + ) + .into_result() + } + } + + /// Executes a double-precision D2Z FFT. + /// + /// # Reference + /// + /// [cufftExecD2Z](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecr2c-and-cufftexecd2z) + pub fn exec_d2z( + &self, + idata: &impl GpuBuffer, + odata: &mut impl GpuBuffer, + ) -> Result<(), CufftError> { + unsafe { + cufft_raw::cufftExecD2Z( + self.raw, + idata.as_device_ptr().as_mut_ptr(), + odata.as_device_ptr().as_mut_ptr(), + ) + .into_result() + } + } + + /// Executes a double-precision Z2D inverse FFT. + /// + /// # Reference + /// + /// [cufftExecZ2D](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecc2r-and-cufftexecz2d) + pub fn exec_z2d( + &self, + idata: &impl GpuBuffer, + odata: &mut impl GpuBuffer, + ) -> Result<(), CufftError> { + unsafe { + cufft_raw::cufftExecZ2D( + self.raw, + idata.as_device_ptr().as_mut_ptr(), + odata.as_device_ptr().as_mut_ptr(), + ) + .into_result() + } + } +} + +impl Drop for FftPlan { + /// Destroys the plan. + /// + /// # Reference + /// + /// [cufftDestroy)(https://docs.nvidia.com/cuda/cufft/index.html#cufftdestroy) + fn drop(&mut self) { + unsafe { + let _ = cufft_raw::cufftDestroy(self.raw); + } + } +} diff --git a/crates/cufft_raw/Cargo.toml b/crates/cufft_raw/Cargo.toml new file mode 100644 index 00000000..4f1ef74b --- /dev/null +++ b/crates/cufft_raw/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "cufft_raw" +version = "0.1.0" +edition = "2024" +license = "MIT OR Apache-2.0" +repository = "https://github.com/Rust-GPU/rust-cuda" +readme = "../../README.md" +links = "cufft" +build = "build/main.rs" + +[dependencies] +cust_raw = { version = "0.11.3", path = "../cust_raw", default-features = false, features = ["driver"] } +cust_core = { version = "0.1.1", path = "../cust_core" } + +[build-dependencies] +bindgen = "0.71.1" diff --git a/crates/cufft_raw/build/main.rs b/crates/cufft_raw/build/main.rs new file mode 100644 index 00000000..a3c2d1a9 --- /dev/null +++ b/crates/cufft_raw/build/main.rs @@ -0,0 +1,69 @@ +use std::env; +use std::path; + +fn main() { + let cuda_include_paths = env::var_os("DEP_CUDA_INCLUDES") + .map(|s| env::split_paths(s.as_os_str()).collect::>()) + .expect("DEP_CUDA_INCLUDES not set; ensure cust_raw is a dependency"); + + let cuda_root = env::var("DEP_CUDA_ROOT") + .map(path::PathBuf::from) + .expect("DEP_CUDA_ROOT not set; ensure cust_raw is a dependency"); + + println!("cargo::rerun-if-changed=build"); + + for dir in [ + cuda_root.join("lib64"), + cuda_root.join("lib"), + cuda_root.join("targets").join("x86_64-linux").join("lib"), + ] { + if dir.is_dir() { + println!("cargo::rustc-link-search=native={}", dir.display()); + } + } + + println!("cargo::rustc-link-lib=dylib=cufft"); + + create_cufft_bindings(&cuda_include_paths); +} + +fn create_cufft_bindings(cuda_include_paths: &[path::PathBuf]) { + println!("cargo::rerun-if-changed=build/wrapper.h"); + + let outdir = path::PathBuf::from(env::var("OUT_DIR").unwrap()); + let bindgen_path = outdir.join("cufft_raw.rs"); + + let bindings = bindgen::Builder::default() + .header("build/wrapper.h") + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .clang_args( + cuda_include_paths + .iter() + .map(|p| format!("-I{}", p.display())), + ) + .allowlist_function("^cufft.*") + .allowlist_type("^cufft.*") + .allowlist_var("^CUFFT.*") + // cuComplex/cuDoubleComplex are typedef'd as cufftComplex/cufftDoubleComplex + .allowlist_type("^cu.*Complex.*") + .allowlist_type("^float2$") + .allowlist_type("^double2$") + .default_enum_style(bindgen::EnumVariation::Rust { + non_exhaustive: false, + }) + .derive_default(true) + .derive_eq(true) + .derive_hash(true) + .derive_ord(true) + .size_t_is_usize(true) + .layout_tests(true) + .must_use_type("cufftResult") + .wrap_unsafe_ops(true) + .generate_comments(false) + .generate() + .expect("Unable to generate cuFFT bindings."); + + bindings + .write_to_file(&bindgen_path) + .expect("Cannot write cuFFT bindgen output to file."); +} diff --git a/crates/cufft_raw/build/wrapper.h b/crates/cufft_raw/build/wrapper.h new file mode 100644 index 00000000..38db8822 --- /dev/null +++ b/crates/cufft_raw/build/wrapper.h @@ -0,0 +1 @@ +#include "cufft.h" diff --git a/crates/cufft_raw/src/lib.rs b/crates/cufft_raw/src/lib.rs new file mode 100644 index 00000000..052c7cfe --- /dev/null +++ b/crates/cufft_raw/src/lib.rs @@ -0,0 +1,8 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +include!(concat!(env!("OUT_DIR"), "/cufft_raw.rs")); + +unsafe impl cust_core::DeviceCopy for float2 {} +unsafe impl cust_core::DeviceCopy for double2 {} diff --git a/examples/fft/Cargo.toml b/examples/fft/Cargo.toml new file mode 100644 index 00000000..27a4397a --- /dev/null +++ b/examples/fft/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "fft" +version = "0.1.0" +edition = "2024" + +[dependencies] +cufft = { path = "../../crates/cufft" } +cufft_raw = { path = "../../crates/cufft_raw" } +cust = { path = "../../crates/cust" } diff --git a/examples/fft/src/main.rs b/examples/fft/src/main.rs new file mode 100644 index 00000000..57edbe5b --- /dev/null +++ b/examples/fft/src/main.rs @@ -0,0 +1,82 @@ +use std::f32::consts::PI; + +use cufft::{Direction, FftPlan, FftType}; +use cufft_raw::float2; +use cust::memory::{CopyDestination, DeviceBuffer}; + +const FFT_SIZE: usize = 1024; +const FREQUENCY_BIN: usize = 13; + +fn main() -> Result<(), Box> { + let _ctx = cust::quick_init()?; + + // Build a complex sinusoid `x[n] = exp(j * 2 * pi * FREQUENCY_BIN * n / FFT_SIZE)` + + let host_signal: Vec = (0..FFT_SIZE) + .map(|n| { + let phase = 2.0 * PI * FREQUENCY_BIN as f32 * n as f32 / FFT_SIZE as f32; + + float2 { + x: phase.cos(), + y: phase.sin(), + } + }) + .collect(); + + let device_signal = DeviceBuffer::from_slice(&host_signal)?; + let mut device_spectrum = unsafe { DeviceBuffer::::uninitialized(FFT_SIZE)? }; + + let plan = FftPlan::plan_1d(FFT_SIZE as i32, FftType::C2C, 1)?; + + // Forward FFT. + + plan.exec_c2c(&device_signal, &mut device_spectrum, Direction::Forward)?; + + let mut host_spectrum = vec![float2 { x: 0.0, y: 0.0 }; FFT_SIZE]; + device_spectrum.copy_to(&mut host_spectrum)?; + + // The energy should be concentrated at FREQUENCY_BIN. + + let peak = host_spectrum + .iter() + .enumerate() + .max_by(|(_, lhs), (_, rhs)| { + let norm = |vector: &float2| vector.x * vector.x + vector.y * vector.y; + + norm(lhs).partial_cmp(&norm(rhs)).unwrap() + }) + .map(|(bin, _)| bin) + .unwrap(); + + println!("input frequency bin : {FREQUENCY_BIN}"); + println!("peak frequency bin : {peak}"); + + assert_eq!(peak, FREQUENCY_BIN, "Input frequency bin does not equal peak frequency bin."); + + // Inverse FFT then normalize by FFT_SIZE to recover the original signal. + + let mut device_recovered = unsafe { DeviceBuffer::::uninitialized(FFT_SIZE)? }; + + plan.exec_c2c(&device_spectrum, &mut device_recovered, Direction::Inverse)?; + + let mut host_recovered = vec![float2 { x: 0.0, y: 0.0 }; FFT_SIZE]; + + device_recovered.copy_to(&mut host_recovered)?; + + let max_err = host_signal + .iter() + .zip(host_recovered.iter()) + .map(|(lhs, rhs)| { + let dx = rhs.x / FFT_SIZE as f32 - lhs.x; + let dy = rhs.y / FFT_SIZE as f32 - lhs.y; + + (dx * dx + dy * dy).sqrt() + }) + .fold(0.0f32, f32::max); + + println!("Round-trip max error: {max_err}"); + + assert!(max_err < 1e-5, "Round-trip error too large: {max_err}"); + + Ok(()) +} From 5078c7fc2e4c61e97f5f00f3ae296f521b87965a Mon Sep 17 00:00:00 2001 From: Casey Sanchez Date: Fri, 5 Jun 2026 13:46:21 -0400 Subject: [PATCH 2/4] fix(cufft): apply clippy fix to plan.rs --- crates/cufft/src/plan.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/cufft/src/plan.rs b/crates/cufft/src/plan.rs index f24a6916..e92926a2 100644 --- a/crates/cufft/src/plan.rs +++ b/crates/cufft/src/plan.rs @@ -47,7 +47,7 @@ pub enum Direction { impl Direction { pub(crate) fn into_raw(self) -> i32 { match self { - Direction::Forward => cufft_raw::CUFFT_FORWARD as i32, + Direction::Forward => cufft_raw::CUFFT_FORWARD, Direction::Inverse => cufft_raw::CUFFT_INVERSE as i32, } } From 77292edd9306f63fcdd67b80f906dd7a5ddb778d Mon Sep 17 00:00:00 2001 From: Casey Sanchez Date: Fri, 5 Jun 2026 16:31:23 -0400 Subject: [PATCH 3/4] feat(cufft): improve API with type-state pattern of the FftType --- crates/cufft/src/lib.rs | 11 +-- crates/cufft/src/plan.rs | 145 ++++++++++++++++++++++++++------------- examples/fft/src/main.rs | 13 ++-- 3 files changed, 111 insertions(+), 58 deletions(-) diff --git a/crates/cufft/src/lib.rs b/crates/cufft/src/lib.rs index 0cf3f4f7..1accf865 100644 --- a/crates/cufft/src/lib.rs +++ b/crates/cufft/src/lib.rs @@ -1,9 +1,10 @@ //! Rust wrapper for the [cuFFT library](https://docs.nvidia.com/cuda/cufft/). //! -//! Create a `FftPlan` with one of the plan constructors. -//! Optionally attach a stream with `FftPlan::set_stream` -//! Execute the plan with one of the `exec_*` methods. -//! Plans are destroyed when they dropped. +//! Create a `FftPlan` with one of the plan constructors, +//! where `T` implements the `FftType` trait (`C2C`, `R2C`, `C2R`, `Z2Z`, `D2Z`, `Z2D`). +//! Optionally attach a stream with `FftPlan::set_stream`. +//! Execute the plan with `FftPlan::exec`. +//! Plans are destroyed when dropped. //! //! Raw bindgen bindings are available in `cufft_raw`. @@ -11,4 +12,4 @@ mod error; mod plan; pub use error::{CufftError, IntoResult}; -pub use plan::{Direction, FftPlan, FftType}; +pub use plan::{C2C, C2R, D2Z, Direction, FftPlan, FftType, R2C, Z2D, Z2Z}; diff --git a/crates/cufft/src/plan.rs b/crates/cufft/src/plan.rs index e92926a2..6584f8d1 100644 --- a/crates/cufft/src/plan.rs +++ b/crates/cufft/src/plan.rs @@ -1,37 +1,72 @@ +use std::marker::PhantomData; use std::mem::MaybeUninit; use cust::memory::GpuBuffer; use crate::{CufftError, IntoResult}; -/// cuFFT transform type. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum FftType { - /// Single-precision real-to-complex. - R2C, - /// Single-precision complex-to-real. - C2R, - /// Single-precision complex-to-complex. - C2C, - /// Double-precision real-to-complex. - D2Z, - /// Double-precision complex-to-real. - Z2D, - /// Double-precision complex-to-complex. - Z2Z, +mod sealed { + pub trait Sealed {} } -impl FftType { - fn into_raw(self) -> cufft_raw::cufftType { - use cufft_raw::cufftType::*; - match self { - FftType::R2C => CUFFT_R2C, - FftType::C2R => CUFFT_C2R, - FftType::C2C => CUFFT_C2C, - FftType::D2Z => CUFFT_D2Z, - FftType::Z2D => CUFFT_Z2D, - FftType::Z2Z => CUFFT_Z2Z, - } +pub trait FftType: sealed::Sealed { + #[doc(hidden)] + fn fft_type() -> cufft_raw::cufftType; +} + +/// Marker type for single-precision complex-to-complex transforms. +pub struct C2C; +/// Marker type for single-precision real-to-complex transforms. +pub struct R2C; +/// Marker type for single-precision complex-to-real transforms. +pub struct C2R; +/// Marker type for double-precision real-to-complex transforms. +pub struct D2Z; +/// Marker type for double-precision complex-to-real transforms. +pub struct Z2D; +/// Marker type for double-precision complex-to-complex transforms. +pub struct Z2Z; + +impl sealed::Sealed for C2C {} +impl sealed::Sealed for R2C {} +impl sealed::Sealed for C2R {} +impl sealed::Sealed for D2Z {} +impl sealed::Sealed for Z2D {} +impl sealed::Sealed for Z2Z {} + +impl FftType for C2C { + fn fft_type() -> cufft_raw::cufftType { + cufft_raw::cufftType::CUFFT_C2C + } +} + +impl FftType for R2C { + fn fft_type() -> cufft_raw::cufftType { + cufft_raw::cufftType::CUFFT_R2C + } +} + +impl FftType for C2R { + fn fft_type() -> cufft_raw::cufftType { + cufft_raw::cufftType::CUFFT_C2R + } +} + +impl FftType for D2Z { + fn fft_type() -> cufft_raw::cufftType { + cufft_raw::cufftType::CUFFT_D2Z + } +} + +impl FftType for Z2D { + fn fft_type() -> cufft_raw::cufftType { + cufft_raw::cufftType::CUFFT_Z2D + } +} + +impl FftType for Z2Z { + fn fft_type() -> cufft_raw::cufftType { + cufft_raw::cufftType::CUFFT_Z2Z } } @@ -55,25 +90,26 @@ impl Direction { /// Wrapper for a `cufftHandle`. #[derive(Debug)] -pub struct FftPlan { - pub(crate) raw: cufft_raw::cufftHandle, +pub struct FftPlan { + raw: cufft_raw::cufftHandle, + _marker: PhantomData, } -impl FftPlan { +impl FftPlan { /// Creates a 1-D FFT plan. /// /// # Reference /// /// [cufftPlan1d](https://docs.nvidia.com/cuda/cufft/index.html#cufftplan1d) - pub fn plan_1d(nx: i32, fft_type: FftType, batch: i32) -> Result { + pub fn plan_1d(nx: i32, batch: i32) -> Result { let mut raw = MaybeUninit::uninit(); unsafe { - cufft_raw::cufftPlan1d(raw.as_mut_ptr(), nx, fft_type.into_raw(), batch) - .into_result()?; + cufft_raw::cufftPlan1d(raw.as_mut_ptr(), nx, T::fft_type(), batch).into_result()?; Ok(Self { raw: raw.assume_init(), + _marker: PhantomData, }) } } @@ -83,14 +119,15 @@ impl FftPlan { /// # Reference /// /// [cufftPlan2d](https://docs.nvidia.com/cuda/cufft/index.html#cufftplan2d) - pub fn plan_2d(nx: i32, ny: i32, fft_type: FftType) -> Result { + pub fn plan_2d(nx: i32, ny: i32) -> Result { let mut raw = MaybeUninit::uninit(); unsafe { - cufft_raw::cufftPlan2d(raw.as_mut_ptr(), nx, ny, fft_type.into_raw()).into_result()?; + cufft_raw::cufftPlan2d(raw.as_mut_ptr(), nx, ny, T::fft_type()).into_result()?; Ok(Self { raw: raw.assume_init(), + _marker: PhantomData, }) } } @@ -100,15 +137,15 @@ impl FftPlan { /// # Reference /// /// [cufftPlan3d](https://docs.nvidia.com/cuda/cufft/index.html#cufftplan3d) - pub fn plan_3d(nx: i32, ny: i32, nz: i32, fft_type: FftType) -> Result { + pub fn plan_3d(nx: i32, ny: i32, nz: i32) -> Result { let mut raw = MaybeUninit::uninit(); unsafe { - cufft_raw::cufftPlan3d(raw.as_mut_ptr(), nx, ny, nz, fft_type.into_raw()) - .into_result()?; + cufft_raw::cufftPlan3d(raw.as_mut_ptr(), nx, ny, nz, T::fft_type()).into_result()?; Ok(Self { raw: raw.assume_init(), + _marker: PhantomData, }) } } @@ -128,7 +165,6 @@ impl FftPlan { onembed: Option<&[i32]>, ostride: i32, odist: i32, - fft_type: FftType, batch: i32, ) -> Result { let mut raw = MaybeUninit::uninit(); @@ -147,18 +183,19 @@ impl FftPlan { onembed_ptr, ostride, odist, - fft_type.into_raw(), + T::fft_type(), batch, ) .into_result()?; Ok(Self { raw: raw.assume_init(), + _marker: PhantomData, }) } } - /// Set the CUDA stream for the plan. + /// Sets the CUDA stream for the plan. /// /// # Reference /// @@ -177,13 +214,15 @@ impl FftPlan { pub fn as_raw(&self) -> cufft_raw::cufftHandle { self.raw } +} +impl FftPlan { /// Executes a single-precision C2C FFT. /// /// # Reference /// /// [cufftExecC2C](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecc2c-and-cufftexecz2z) - pub fn exec_c2c( + pub fn exec( &self, idata: &impl GpuBuffer, odata: &mut impl GpuBuffer, @@ -199,13 +238,15 @@ impl FftPlan { .into_result() } } +} +impl FftPlan { /// Executes a single-precision R2C FFT. /// /// # Reference /// /// [cufftExecR2C](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecr2c-and-cufftexecd2z) - pub fn exec_r2c( + pub fn exec( &self, idata: &impl GpuBuffer, odata: &mut impl GpuBuffer, @@ -219,13 +260,15 @@ impl FftPlan { .into_result() } } +} +impl FftPlan { /// Executes a single-precision C2R inverse FFT. /// /// # Reference /// /// [cufftExecC2R](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecc2r-and-cufftexecz2d) - pub fn exec_c2r( + pub fn exec( &self, idata: &impl GpuBuffer, odata: &mut impl GpuBuffer, @@ -239,13 +282,15 @@ impl FftPlan { .into_result() } } +} +impl FftPlan { /// Executes a double-precision Z2Z FFT. /// /// # Reference /// /// [cufftExecZ2Z](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecc2c-and-cufftexecz2z) - pub fn exec_z2z( + pub fn exec( &self, idata: &impl GpuBuffer, odata: &mut impl GpuBuffer, @@ -261,13 +306,15 @@ impl FftPlan { .into_result() } } +} +impl FftPlan { /// Executes a double-precision D2Z FFT. /// /// # Reference /// /// [cufftExecD2Z](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecr2c-and-cufftexecd2z) - pub fn exec_d2z( + pub fn exec( &self, idata: &impl GpuBuffer, odata: &mut impl GpuBuffer, @@ -281,13 +328,15 @@ impl FftPlan { .into_result() } } +} +impl FftPlan { /// Executes a double-precision Z2D inverse FFT. /// /// # Reference /// /// [cufftExecZ2D](https://docs.nvidia.com/cuda/cufft/index.html#cufftexecc2r-and-cufftexecz2d) - pub fn exec_z2d( + pub fn exec( &self, idata: &impl GpuBuffer, odata: &mut impl GpuBuffer, @@ -303,12 +352,12 @@ impl FftPlan { } } -impl Drop for FftPlan { +impl Drop for FftPlan { /// Destroys the plan. /// /// # Reference /// - /// [cufftDestroy)(https://docs.nvidia.com/cuda/cufft/index.html#cufftdestroy) + /// [cufftDestroy](https://docs.nvidia.com/cuda/cufft/index.html#cufftdestroy) fn drop(&mut self) { unsafe { let _ = cufft_raw::cufftDestroy(self.raw); diff --git a/examples/fft/src/main.rs b/examples/fft/src/main.rs index 57edbe5b..6903dae3 100644 --- a/examples/fft/src/main.rs +++ b/examples/fft/src/main.rs @@ -1,6 +1,6 @@ use std::f32::consts::PI; -use cufft::{Direction, FftPlan, FftType}; +use cufft::{C2C, Direction, FftPlan}; use cufft_raw::float2; use cust::memory::{CopyDestination, DeviceBuffer}; @@ -26,11 +26,11 @@ fn main() -> Result<(), Box> { let device_signal = DeviceBuffer::from_slice(&host_signal)?; let mut device_spectrum = unsafe { DeviceBuffer::::uninitialized(FFT_SIZE)? }; - let plan = FftPlan::plan_1d(FFT_SIZE as i32, FftType::C2C, 1)?; + let plan = FftPlan::::plan_1d(FFT_SIZE as i32, 1)?; // Forward FFT. - plan.exec_c2c(&device_signal, &mut device_spectrum, Direction::Forward)?; + plan.exec(&device_signal, &mut device_spectrum, Direction::Forward)?; let mut host_spectrum = vec![float2 { x: 0.0, y: 0.0 }; FFT_SIZE]; device_spectrum.copy_to(&mut host_spectrum)?; @@ -51,13 +51,16 @@ fn main() -> Result<(), Box> { println!("input frequency bin : {FREQUENCY_BIN}"); println!("peak frequency bin : {peak}"); - assert_eq!(peak, FREQUENCY_BIN, "Input frequency bin does not equal peak frequency bin."); + assert_eq!( + peak, FREQUENCY_BIN, + "Input frequency bin does not equal peak frequency bin." + ); // Inverse FFT then normalize by FFT_SIZE to recover the original signal. let mut device_recovered = unsafe { DeviceBuffer::::uninitialized(FFT_SIZE)? }; - plan.exec_c2c(&device_spectrum, &mut device_recovered, Direction::Inverse)?; + plan.exec(&device_spectrum, &mut device_recovered, Direction::Inverse)?; let mut host_recovered = vec![float2 { x: 0.0, y: 0.0 }; FFT_SIZE]; From 7b3898f9da2e2b0c687f713dc8e5ab285086fcee Mon Sep 17 00:00:00 2001 From: Casey Sanchez Date: Sat, 6 Jun 2026 08:28:12 -0400 Subject: [PATCH 4/4] fix(cufft): remove redundant rustc-link-search instruction and more closely align with cudnn-sys build script --- crates/cufft_raw/build/main.rs | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/crates/cufft_raw/build/main.rs b/crates/cufft_raw/build/main.rs index a3c2d1a9..a563da82 100644 --- a/crates/cufft_raw/build/main.rs +++ b/crates/cufft_raw/build/main.rs @@ -4,27 +4,12 @@ use std::path; fn main() { let cuda_include_paths = env::var_os("DEP_CUDA_INCLUDES") .map(|s| env::split_paths(s.as_os_str()).collect::>()) - .expect("DEP_CUDA_INCLUDES not set; ensure cust_raw is a dependency"); - - let cuda_root = env::var("DEP_CUDA_ROOT") - .map(path::PathBuf::from) - .expect("DEP_CUDA_ROOT not set; ensure cust_raw is a dependency"); + .expect("Cannot find transitive metadata 'cuda_include' from cust_raw package."); println!("cargo::rerun-if-changed=build"); - for dir in [ - cuda_root.join("lib64"), - cuda_root.join("lib"), - cuda_root.join("targets").join("x86_64-linux").join("lib"), - ] { - if dir.is_dir() { - println!("cargo::rustc-link-search=native={}", dir.display()); - } - } - - println!("cargo::rustc-link-lib=dylib=cufft"); - create_cufft_bindings(&cuda_include_paths); + println!("cargo::rustc-link-lib=dylib=cufft"); } fn create_cufft_bindings(cuda_include_paths: &[path::PathBuf]) {