From febe6f68a15a5d51050de35ecf8ee20a70e383b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 1 Dec 2025 15:39:45 +0000 Subject: [PATCH 01/23] Add v3 that is the same as v2 still --- build2cmake/src/config/common.rs | 24 ++ build2cmake/src/config/mod.rs | 33 +-- build2cmake/src/config/v1.rs | 2 +- build2cmake/src/config/v2.rs | 135 +---------- build2cmake/src/config/v3.rs | 388 +++++++++++++++++++++++++++++++ 5 files changed, 438 insertions(+), 144 deletions(-) create mode 100644 build2cmake/src/config/common.rs create mode 100644 build2cmake/src/config/v3.rs diff --git a/build2cmake/src/config/common.rs b/build2cmake/src/config/common.rs new file mode 100644 index 00000000..7b0eeced --- /dev/null +++ b/build2cmake/src/config/common.rs @@ -0,0 +1,24 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +#[non_exhaustive] +#[serde(rename_all = "lowercase")] +pub enum Dependency { + #[serde(rename = "cutlass_2_10")] + Cutlass2_10, + #[serde(rename = "cutlass_3_5")] + Cutlass3_5, + #[serde(rename = "cutlass_3_6")] + Cutlass3_6, + #[serde(rename = "cutlass_3_8")] + Cutlass3_8, + #[serde(rename = "cutlass_3_9")] + Cutlass3_9, + #[serde(rename = "cutlass_4_0")] + Cutlass4_0, + #[serde(rename = "cutlass_sycl")] + CutlassSycl, + #[serde(rename = "metal-cpp")] + MetalCpp, + Torch, +} diff --git a/build2cmake/src/config/mod.rs b/build2cmake/src/config/mod.rs index ceb6e823..f9dc234d 100644 --- a/build2cmake/src/config/mod.rs +++ b/build2cmake/src/config/mod.rs @@ -1,16 +1,22 @@ use eyre::Result; use serde::Deserialize; +use serde_value::Value; pub mod v1; +mod common; + mod v2; -use serde_value::Value; -pub use v2::{Backend, Build, Dependency, General, Kernel, Torch}; + +mod v3; +pub use common::Dependency; +pub use v3::{Backend, Build, General, Kernel, Torch}; #[derive(Debug)] pub enum BuildCompat { V1(v1::Build), - V2(Build), + V2(v2::Build), + V3(Build), } impl<'de> 
Deserialize<'de> for BuildCompat { @@ -20,14 +26,11 @@ impl<'de> Deserialize<'de> for BuildCompat { { let value = Value::deserialize(deserializer)?; - match v1::Build::deserialize(value.clone()) { - Ok(v1_build) => Ok(BuildCompat::V1(v1_build)), - Err(_) => { - let v2_build: Build = - Build::deserialize(value).map_err(serde::de::Error::custom)?; - Ok(BuildCompat::V2(v2_build)) - } - } + v1::Build::deserialize(value.clone()) + .map(BuildCompat::V1) + .or_else(|_| v2::Build::deserialize(value.clone()).map(BuildCompat::V2)) + .or_else(|_| Build::deserialize(value.clone()).map(BuildCompat::V3)) + .map_err(serde::de::Error::custom) } } @@ -36,8 +39,12 @@ impl TryFrom for Build { fn try_from(compat: BuildCompat) -> Result { match compat { - BuildCompat::V1(v1_build) => v1_build.try_into(), - BuildCompat::V2(v2_build) => Ok(v2_build), + BuildCompat::V1(v1_build) => { + let v2_build: v2::Build = v1_build.try_into()?; + v2_build.try_into() + } + BuildCompat::V2(v2_build) => v2_build.try_into(), + BuildCompat::V3(v3_build) => Ok(v3_build), } } } diff --git a/build2cmake/src/config/v1.rs b/build2cmake/src/config/v1.rs index 8b095bb7..5582d7ff 100644 --- a/build2cmake/src/config/v1.rs +++ b/build2cmake/src/config/v1.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, fmt::Display, path::PathBuf}; use serde::Deserialize; -use super::v2::Dependency; +use super::common::Dependency; #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index 03cba382..b003bc33 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -1,17 +1,14 @@ -use std::{ - collections::{BTreeSet, HashMap}, - fmt::Display, - path::PathBuf, - str::FromStr, -}; +use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr}; use eyre::{bail, Result}; -use itertools::Itertools; use serde::{Deserialize, Serialize}; use crate::version::Version; -use super::v1::{self, Language}; +use super::{ + 
common::Dependency, + v1::{self, Language}, +}; #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] @@ -23,25 +20,6 @@ pub struct Build { pub kernels: HashMap, } -impl Build { - pub fn has_kernel_with_backend(&self, backend: &Backend) -> bool { - self.backends().contains(backend) - } - - pub fn backends(&self) -> BTreeSet { - self.kernels - .values() - .map(|kernel| match kernel { - Kernel::Cpu { .. } => Backend::Cpu, - Kernel::Cuda { .. } => Backend::Cuda, - Kernel::Metal { .. } => Backend::Metal, - Kernel::Rocm { .. } => Backend::Rocm, - Kernel::Xpu { .. } => Backend::Xpu, - }) - .collect() - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub struct General { @@ -58,13 +36,6 @@ pub struct General { pub python_depends: Option>, } -impl General { - /// Name of the kernel as a Python extension. - pub fn python_name(&self) -> String { - self.name.replace("-", "_") - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub struct Hub { @@ -100,27 +71,6 @@ pub struct Torch { pub src: Vec, } -impl Torch { - pub fn data_globs(&self) -> Option> { - match self.pyext.as_ref() { - Some(exts) => { - let globs = exts - .iter() - .filter(|&ext| ext != "py" && ext != "pyi") - .map(|ext| format!("\"**/*.{ext}\"")) - .collect_vec(); - if globs.is_empty() { - None - } else { - Some(globs) - } - } - - None => None, - } - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")] pub enum Kernel { @@ -167,58 +117,6 @@ pub enum Kernel { }, } -impl Kernel { - pub fn cxx_flags(&self) -> Option<&[String]> { - match self { - Kernel::Cpu { cxx_flags, .. } - | Kernel::Cuda { cxx_flags, .. } - | Kernel::Metal { cxx_flags, .. } - | Kernel::Rocm { cxx_flags, .. } - | Kernel::Xpu { cxx_flags, .. 
} => cxx_flags.as_deref(), - } - } - - pub fn include(&self) -> Option<&[String]> { - match self { - Kernel::Cpu { include, .. } - | Kernel::Cuda { include, .. } - | Kernel::Metal { include, .. } - | Kernel::Rocm { include, .. } - | Kernel::Xpu { include, .. } => include.as_deref(), - } - } - - pub fn backend(&self) -> Backend { - match self { - Kernel::Cpu { .. } => Backend::Cpu, - Kernel::Cuda { .. } => Backend::Cuda, - Kernel::Metal { .. } => Backend::Metal, - Kernel::Rocm { .. } => Backend::Rocm, - Kernel::Xpu { .. } => Backend::Xpu, - } - } - - pub fn depends(&self) -> &[Dependency] { - match self { - Kernel::Cpu { depends, .. } - | Kernel::Cuda { depends, .. } - | Kernel::Metal { depends, .. } - | Kernel::Rocm { depends, .. } - | Kernel::Xpu { depends, .. } => depends, - } - } - - pub fn src(&self) -> &[String] { - match self { - Kernel::Cpu { src, .. } - | Kernel::Cuda { src, .. } - | Kernel::Metal { src, .. } - | Kernel::Rocm { src, .. } - | Kernel::Xpu { src, .. } => src, - } - } -} - #[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub enum Backend { @@ -256,29 +154,6 @@ impl FromStr for Backend { } } -#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -#[non_exhaustive] -#[serde(rename_all = "lowercase")] -pub enum Dependency { - #[serde(rename = "cutlass_2_10")] - Cutlass2_10, - #[serde(rename = "cutlass_3_5")] - Cutlass3_5, - #[serde(rename = "cutlass_3_6")] - Cutlass3_6, - #[serde(rename = "cutlass_3_8")] - Cutlass3_8, - #[serde(rename = "cutlass_3_9")] - Cutlass3_9, - #[serde(rename = "cutlass_4_0")] - Cutlass4_0, - #[serde(rename = "cutlass_sycl")] - CutlassSycl, - #[serde(rename = "metal-cpp")] - MetalCpp, - Torch, -} - impl TryFrom for Build { type Error = eyre::Error; diff --git a/build2cmake/src/config/v3.rs b/build2cmake/src/config/v3.rs new file mode 100644 index 00000000..90eadcc4 --- /dev/null +++ 
b/build2cmake/src/config/v3.rs @@ -0,0 +1,388 @@ +use std::{ + collections::{BTreeSet, HashMap}, + fmt::Display, + path::PathBuf, + str::FromStr, +}; + +use eyre::Result; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use crate::version::Version; + +use super::{common::Dependency, v2}; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct Build { + pub general: General, + pub torch: Option, + + #[serde(rename = "kernel", default)] + pub kernels: HashMap, +} + +impl Build { + pub fn has_kernel_with_backend(&self, backend: &Backend) -> bool { + self.backends().contains(backend) + } + + pub fn backends(&self) -> BTreeSet { + self.kernels + .values() + .map(|kernel| match kernel { + Kernel::Cpu { .. } => Backend::Cpu, + Kernel::Cuda { .. } => Backend::Cuda, + Kernel::Metal { .. } => Backend::Metal, + Kernel::Rocm { .. } => Backend::Rocm, + Kernel::Xpu { .. } => Backend::Xpu, + }) + .collect() + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct General { + pub name: String, + #[serde(default)] + pub universal: bool, + + pub cuda_maxver: Option, + + pub cuda_minver: Option, + + pub hub: Option, + + pub python_depends: Option>, +} + +impl General { + /// Name of the kernel as a Python extension. 
+ pub fn python_name(&self) -> String { + self.name.replace("-", "_") + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct Hub { + pub repo_id: Option, + pub branch: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub enum PythonDependency { + Einops, + NvidiaCutlassDsl, +} + +impl Display for PythonDependency { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PythonDependency::Einops => write!(f, "einops"), + PythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"), + } + } +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields)] +pub struct Torch { + pub include: Option>, + pub pyext: Option>, + + #[serde(default)] + pub src: Vec, +} + +impl Torch { + pub fn data_globs(&self) -> Option> { + match self.pyext.as_ref() { + Some(exts) => { + let globs = exts + .iter() + .filter(|&ext| ext != "py" && ext != "pyi") + .map(|ext| format!("\"**/*.{ext}\"")) + .collect_vec(); + if globs.is_empty() { + None + } else { + Some(globs) + } + } + + None => None, + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")] +pub enum Kernel { + #[serde(rename_all = "kebab-case")] + Cpu { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + #[serde(rename_all = "kebab-case")] + Cuda { + cuda_capabilities: Option>, + cuda_flags: Option>, + cuda_minver: Option, + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + #[serde(rename_all = "kebab-case")] + Metal { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + #[serde(rename_all = "kebab-case")] + Rocm { + cxx_flags: Option>, + depends: Vec, + rocm_archs: Option>, + hip_flags: Option>, + include: Option>, + src: Vec, + }, + #[serde(rename_all = "kebab-case")] + Xpu { + 
cxx_flags: Option>, + depends: Vec, + sycl_flags: Option>, + include: Option>, + src: Vec, + }, +} + +impl Kernel { + pub fn cxx_flags(&self) -> Option<&[String]> { + match self { + Kernel::Cpu { cxx_flags, .. } + | Kernel::Cuda { cxx_flags, .. } + | Kernel::Metal { cxx_flags, .. } + | Kernel::Rocm { cxx_flags, .. } + | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(), + } + } + + pub fn include(&self) -> Option<&[String]> { + match self { + Kernel::Cpu { include, .. } + | Kernel::Cuda { include, .. } + | Kernel::Metal { include, .. } + | Kernel::Rocm { include, .. } + | Kernel::Xpu { include, .. } => include.as_deref(), + } + } + + pub fn backend(&self) -> Backend { + match self { + Kernel::Cpu { .. } => Backend::Cpu, + Kernel::Cuda { .. } => Backend::Cuda, + Kernel::Metal { .. } => Backend::Metal, + Kernel::Rocm { .. } => Backend::Rocm, + Kernel::Xpu { .. } => Backend::Xpu, + } + } + + pub fn depends(&self) -> &[Dependency] { + match self { + Kernel::Cpu { depends, .. } + | Kernel::Cuda { depends, .. } + | Kernel::Metal { depends, .. } + | Kernel::Rocm { depends, .. } + | Kernel::Xpu { depends, .. } => depends, + } + } + + pub fn src(&self) -> &[String] { + match self { + Kernel::Cpu { src, .. } + | Kernel::Cuda { src, .. } + | Kernel::Metal { src, .. } + | Kernel::Rocm { src, .. } + | Kernel::Xpu { src, .. 
} => src, + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub enum Backend { + Cpu, + Cuda, + Metal, + Rocm, + Xpu, +} + +impl Display for Backend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Backend::Cpu => write!(f, "cpu"), + Backend::Cuda => write!(f, "cuda"), + Backend::Metal => write!(f, "metal"), + Backend::Rocm => write!(f, "rocm"), + Backend::Xpu => write!(f, "xpu"), + } + } +} + +impl FromStr for Backend { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "cpu" => Ok(Backend::Cpu), + "cuda" => Ok(Backend::Cuda), + "metal" => Ok(Backend::Metal), + "rocm" => Ok(Backend::Rocm), + "xpu" => Ok(Backend::Xpu), + _ => Err(format!("Unknown backend: {s}")), + } + } +} + +impl TryFrom for Build { + type Error = eyre::Error; + + fn try_from(build: v2::Build) -> Result { + Ok(Self { + general: build.general.into(), + torch: build.torch.map(Into::into), + kernels: build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(), + }) + } +} + +impl From for General { + fn from(general: v2::General) -> Self { + Self { + name: general.name, + universal: general.universal, + cuda_maxver: general.cuda_maxver, + cuda_minver: general.cuda_minver, + hub: general.hub.map(Into::into), + python_depends: general + .python_depends + .map(|deps| deps.into_iter().map(Into::into).collect()), + } + } +} + +impl From for Hub { + fn from(hub: v2::Hub) -> Self { + Self { + repo_id: hub.repo_id, + branch: hub.branch, + } + } +} + +impl From for PythonDependency { + fn from(dep: v2::PythonDependency) -> Self { + match dep { + v2::PythonDependency::Einops => PythonDependency::Einops, + v2::PythonDependency::NvidiaCutlassDsl => PythonDependency::NvidiaCutlassDsl, + } + } +} + +impl From for Torch { + fn from(torch: v2::Torch) -> Self { + Self { + include: torch.include, + pyext: 
torch.pyext, + src: torch.src, + } + } +} + +impl From for Kernel { + fn from(kernel: v2::Kernel) -> Self { + match kernel { + v2::Kernel::Cpu { + cxx_flags, + depends, + include, + src, + } => Kernel::Cpu { + cxx_flags, + depends, + include, + src, + }, + v2::Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + } => Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + }, + v2::Kernel::Metal { + cxx_flags, + depends, + include, + src, + } => Kernel::Metal { + cxx_flags, + depends, + include, + src, + }, + v2::Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + } => Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + }, + v2::Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + } => Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + }, + } + } +} From 12921f13e4f3c16322a36d21aed09cc8815f45e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 1 Dec 2025 15:47:58 +0000 Subject: [PATCH 02/23] Move cuda-minver/maxver to `general.cuda` --- build2cmake/src/config/v3.rs | 27 ++++++++++++++++++++++----- build2cmake/src/torch/cuda.rs | 4 ++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/build2cmake/src/config/v3.rs b/build2cmake/src/config/v3.rs index 90eadcc4..25ae3836 100644 --- a/build2cmake/src/config/v3.rs +++ b/build2cmake/src/config/v3.rs @@ -49,9 +49,7 @@ pub struct General { #[serde(default)] pub universal: bool, - pub cuda_maxver: Option, - - pub cuda_minver: Option, + pub cuda: Option, pub hub: Option, @@ -65,6 +63,13 @@ impl General { } } +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct CudaGeneral { + pub minver: Option, + pub maxver: Option, +} + #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = 
"kebab-case")] pub struct Hub { @@ -92,6 +97,8 @@ impl Display for PythonDependency { #[serde(deny_unknown_fields)] pub struct Torch { pub include: Option>, + pub minver: Option, + pub maxver: Option, pub pyext: Option>, #[serde(default)] @@ -272,11 +279,19 @@ impl TryFrom for Build { impl From for General { fn from(general: v2::General) -> Self { + let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() { + Some(CudaGeneral { + minver: general.cuda_minver, + maxver: general.cuda_maxver, + }) + } else { + None + }; + Self { name: general.name, universal: general.universal, - cuda_maxver: general.cuda_maxver, - cuda_minver: general.cuda_minver, + cuda, hub: general.hub.map(Into::into), python_depends: general .python_depends @@ -307,6 +322,8 @@ impl From for Torch { fn from(torch: v2::Torch) -> Self { Self { include: torch.include, + minver: torch.minver, + maxver: torch.maxver, pyext: torch.pyext, src: torch.src, } diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index 297df3d5..e96432fe 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -166,8 +166,8 @@ fn write_cmake( render_preamble( env, name, - build.general.cuda_minver.as_ref(), - build.general.cuda_maxver.as_ref(), + build.general.cuda.as_ref().and_then(|c| c.minver.as_ref()), + build.general.cuda.as_ref().and_then(|c| c.maxver.as_ref()), torch.minver.as_ref(), torch.maxver.as_ref(), cmake_writer, From 264d17d8b31ea30ddaa084a763e4593f2d3fea1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 1 Dec 2025 16:07:09 +0000 Subject: [PATCH 03/23] Remove `universal` from the v3 config. Universal kernels still exist in the project writers. 
--- build2cmake/src/config/v3.rs | 45 +++++++++++++++++++++++++++--------- build2cmake/src/main.rs | 4 ++-- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/build2cmake/src/config/v3.rs b/build2cmake/src/config/v3.rs index 25ae3836..e2bc693a 100644 --- a/build2cmake/src/config/v3.rs +++ b/build2cmake/src/config/v3.rs @@ -24,6 +24,11 @@ pub struct Build { } impl Build { + /// Check if this is a universal build (supports all backends). + pub fn is_universal(&self) -> bool { + self.kernels.is_empty() + } + pub fn has_kernel_with_backend(&self, backend: &Backend) -> bool { self.backends().contains(backend) } @@ -46,8 +51,8 @@ impl Build { #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub struct General { pub name: String, - #[serde(default)] - pub universal: bool, + + pub backends: Vec, pub cuda: Option, @@ -265,20 +270,38 @@ impl TryFrom for Build { type Error = eyre::Error; fn try_from(build: v2::Build) -> Result { + let kernels: HashMap = build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(); + + let backends = if build.general.universal { + vec![ + "cpu".to_string(), + "cuda".to_string(), + "metal".to_string(), + "rocm".to_string(), + "xpu".to_string(), + ] + } else { + let backend_set: BTreeSet = kernels + .values() + .map(|kernel| kernel.backend().to_string()) + .collect(); + backend_set.into_iter().collect() + }; + Ok(Self { - general: build.general.into(), + general: General::from_v2(build.general, backends), torch: build.torch.map(Into::into), - kernels: build - .kernels - .into_iter() - .map(|(k, v)| (k, v.into())) - .collect(), + kernels, }) } } -impl From for General { - fn from(general: v2::General) -> Self { +impl General { + fn from_v2(general: v2::General, backends: Vec) -> Self { let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() { Some(CudaGeneral { minver: general.cuda_minver, @@ -290,7 +313,7 @@ impl From for General { Self { name: general.name, - universal: general.universal, 
+ backends, cuda, hub: general.hub.map(Into::into), python_depends: general diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index ca419389..a6211b3c 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -140,7 +140,7 @@ fn generate_torch( env.set_trim_blocks(true); minijinja_embed::load_templates!(&mut env); - let backend = match (backend, build.general.universal) { + let backend = match (backend, build.is_universal()) { (None, true) => { let file_set = write_torch_ext_universal(&env, &build, target_dir.clone(), ops_id)?; file_set.write(&target_dir, force)?; @@ -393,7 +393,7 @@ fn get_generated_files( all_set.extend(set); } - if build.general.universal { + if build.is_universal() { let set = write_torch_ext_universal(env, build, target_dir, ops_id)?; all_set.extend(set); From 333da803d7f38ff28b4f8ba9869e11d02d2a7a4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 2 Dec 2025 09:44:29 +0000 Subject: [PATCH 04/23] build2cmake: fix backend handling --- build2cmake/src/config/v3.rs | 40 +++----- build2cmake/src/main.rs | 97 ++++++++++--------- build2cmake/src/torch/mod.rs | 4 +- .../src/torch/{universal.rs => noarch.rs} | 2 +- 4 files changed, 64 insertions(+), 79 deletions(-) rename build2cmake/src/torch/{universal.rs => noarch.rs} (98%) diff --git a/build2cmake/src/config/v3.rs b/build2cmake/src/config/v3.rs index e2bc693a..5043bc04 100644 --- a/build2cmake/src/config/v3.rs +++ b/build2cmake/src/config/v3.rs @@ -24,26 +24,12 @@ pub struct Build { } impl Build { - /// Check if this is a universal build (supports all backends). - pub fn is_universal(&self) -> bool { + pub fn is_noarch(&self) -> bool { self.kernels.is_empty() } - pub fn has_kernel_with_backend(&self, backend: &Backend) -> bool { - self.backends().contains(backend) - } - - pub fn backends(&self) -> BTreeSet { - self.kernels - .values() - .map(|kernel| match kernel { - Kernel::Cpu { .. } => Backend::Cpu, - Kernel::Cuda { .. 
} => Backend::Cuda, - Kernel::Metal { .. } => Backend::Metal, - Kernel::Rocm { .. } => Backend::Rocm, - Kernel::Xpu { .. } => Backend::Xpu, - }) - .collect() + pub fn supports_backend(&self, backend: &Backend) -> bool { + self.general.backends.contains(backend) } } @@ -52,7 +38,7 @@ impl Build { pub struct General { pub name: String, - pub backends: Vec, + pub backends: Vec, pub cuda: Option, @@ -278,17 +264,15 @@ impl TryFrom for Build { let backends = if build.general.universal { vec![ - "cpu".to_string(), - "cuda".to_string(), - "metal".to_string(), - "rocm".to_string(), - "xpu".to_string(), + Backend::Cpu, + Backend::Cuda, + Backend::Metal, + Backend::Rocm, + Backend::Xpu, ] } else { - let backend_set: BTreeSet = kernels - .values() - .map(|kernel| kernel.backend().to_string()) - .collect(); + let backend_set: BTreeSet = + kernels.values().map(|kernel| kernel.backend()).collect(); backend_set.into_iter().collect() }; @@ -301,7 +285,7 @@ impl TryFrom for Build { } impl General { - fn from_v2(general: v2::General, backends: Vec) -> Self { + fn from_v2(general: v2::General, backends: Vec) -> Self { let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() { Some(CudaGeneral { minver: general.cuda_minver, diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index a6211b3c..741424e8 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -10,7 +10,7 @@ use minijinja::Environment; mod torch; use torch::{ - write_torch_ext_cpu, write_torch_ext_cuda, write_torch_ext_metal, write_torch_ext_universal, + write_torch_ext_cpu, write_torch_ext_cuda, write_torch_ext_metal, write_torch_ext_noarch, write_torch_ext_xpu, }; @@ -140,51 +140,48 @@ fn generate_torch( env.set_trim_blocks(true); minijinja_embed::load_templates!(&mut env); - let backend = match (backend, build.is_universal()) { - (None, true) => { - let file_set = write_torch_ext_universal(&env, &build, target_dir.clone(), ops_id)?; - file_set.write(&target_dir, force)?; - 
return Ok(()); - } - (Some(backend), true) => bail!("Universal kernel, cannot generate for backend {}", backend), - (Some(backend), false) => { - if !build.has_kernel_with_backend(&backend) { - bail!("No kernels found for backend {}", backend); + let backend = match backend { + Some(backend) => { + if !build.supports_backend(&backend) { + bail!("Kernel does not support backend: {}", backend); } backend } - (None, false) => { - let mut kernel_backends = build.backends(); - let backend = if let Some(backend) = kernel_backends.pop_first() { - backend - } else { - bail!("No kernels found in build.toml"); - }; - - if !kernel_backends.is_empty() { - let kernel_backends: Vec<_> = build - .backends() - .into_iter() - .map(|backend| backend.to_string()) - .collect(); + None => { + let kernel_backends = &build.general.backends; + + if kernel_backends.len() > 1 { + let mut kernel_backends = kernel_backends + .iter() + .map(ToString::to_string) + .collect::>(); + kernel_backends.sort(); bail!( "Multiple supported backends found in build.toml: {}. Please specify one with --backend.", kernel_backends.join(", ") ); } - backend + if let Some(backend) = kernel_backends.first() { + *backend + } else { + bail!("No backends are specified in build.toml"); + } } }; - let file_set = match backend { - Backend::Cpu => write_torch_ext_cpu(&env, &build, target_dir.clone(), ops_id)?, - Backend::Cuda | Backend::Rocm => { - write_torch_ext_cuda(&env, backend, &build, target_dir.clone(), ops_id)? + let file_set = if build.is_noarch() { + write_torch_ext_noarch(&env, &build, target_dir.clone(), ops_id)? + } else { + match backend { + Backend::Cpu => write_torch_ext_cpu(&env, &build, target_dir.clone(), ops_id)?, + Backend::Cuda | Backend::Rocm => { + write_torch_ext_cuda(&env, backend, &build, target_dir.clone(), ops_id)? 
+ } + Backend::Metal => write_torch_ext_metal(&env, &build, target_dir.clone(), ops_id)?, + Backend::Xpu => write_torch_ext_xpu(&env, &build, target_dir.clone(), ops_id)?, } - Backend::Metal => write_torch_ext_metal(&env, &build, target_dir.clone(), ops_id)?, - Backend::Xpu => write_torch_ext_xpu(&env, &build, target_dir.clone(), ops_id)?, }; file_set.write(&target_dir, force)?; @@ -378,25 +375,29 @@ fn get_generated_files( ) -> Result> { let mut all_set = FileSet::new(); - for backend in build.backends() { - let set = match backend { - Backend::Cpu => write_torch_ext_cpu(env, build, target_dir.clone(), ops_id.clone())?, - Backend::Cuda | Backend::Rocm => { - write_torch_ext_cuda(env, backend, build, target_dir.clone(), ops_id.clone())? - } - Backend::Metal => { - write_torch_ext_metal(env, build, target_dir.clone(), ops_id.clone())? - } - Backend::Xpu => write_torch_ext_xpu(env, build, target_dir.clone(), ops_id.clone())?, - }; + if build.is_noarch() { + let set = write_torch_ext_noarch(env, build, target_dir.clone(), ops_id.clone())?; all_set.extend(set); - } - - if build.is_universal() { - let set = write_torch_ext_universal(env, build, target_dir, ops_id)?; + } else { + for backend in &build.general.backends { + let set = match backend { + Backend::Cpu => { + write_torch_ext_cpu(env, build, target_dir.clone(), ops_id.clone())? + } + Backend::Cuda | Backend::Rocm => { + write_torch_ext_cuda(env, *backend, build, target_dir.clone(), ops_id.clone())? + } + Backend::Metal => { + write_torch_ext_metal(env, build, target_dir.clone(), ops_id.clone())? + } + Backend::Xpu => { + write_torch_ext_xpu(env, build, target_dir.clone(), ops_id.clone())? 
+ } + }; - all_set.extend(set); + all_set.extend(set); + } } Ok(all_set.into_names()) diff --git a/build2cmake/src/torch/mod.rs b/build2cmake/src/torch/mod.rs index bc5ba1e2..f9cbfeaf 100644 --- a/build2cmake/src/torch/mod.rs +++ b/build2cmake/src/torch/mod.rs @@ -12,8 +12,8 @@ pub use metal::write_torch_ext_metal; mod ops_identifier; pub(crate) use ops_identifier::kernel_ops_identifier; -mod universal; -pub use universal::write_torch_ext_universal; +mod noarch; +pub use noarch::write_torch_ext_noarch; mod xpu; pub use xpu::write_torch_ext_xpu; diff --git a/build2cmake/src/torch/universal.rs b/build2cmake/src/torch/noarch.rs similarity index 98% rename from build2cmake/src/torch/universal.rs rename to build2cmake/src/torch/noarch.rs index 4622a234..f9821d61 100644 --- a/build2cmake/src/torch/universal.rs +++ b/build2cmake/src/torch/noarch.rs @@ -10,7 +10,7 @@ use crate::{ torch::kernel_ops_identifier, }; -pub fn write_torch_ext_universal( +pub fn write_torch_ext_noarch( env: &Environment, build: &Build, target_dir: PathBuf, From 3f89453887499e1d7cd0e3f0c093aec1e7f67077 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 2 Dec 2025 12:31:23 +0000 Subject: [PATCH 05/23] Make `update-build` update to config v3 --- build2cmake/src/main.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index 741424e8..b4223031 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -126,9 +126,9 @@ fn generate_torch( let build_compat = parse_and_validate(build_toml)?; - if matches!(build_compat, BuildCompat::V1(_)) { + if matches!(build_compat, BuildCompat::V1(_) | BuildCompat::V2(_)) { eprintln!( - "build.toml is in the deprecated V1 format, use `build2cmake update-build` to update." + "build.toml is in the deprecated V1 or V2 format, use `build2cmake update-build` to update." 
) } @@ -191,7 +191,7 @@ fn generate_torch( fn update_build(build_toml: PathBuf) -> Result<()> { let build_compat: BuildCompat = parse_and_validate(&build_toml)?; - if matches!(build_compat, BuildCompat::V2(_)) { + if matches!(build_compat, BuildCompat::V3(_)) { return Ok(()); } @@ -263,9 +263,9 @@ fn clean( let build_compat = parse_and_validate(build_toml)?; - if matches!(build_compat, BuildCompat::V1(_)) { + if matches!(build_compat, BuildCompat::V1(_) | BuildCompat::V2(_)) { eprintln!( - "build.toml is in the deprecated V1 format, use `build2cmake update-build` to update." + "build.toml is in the deprecated V1 or V2 format, use `build2cmake update-build` to update." ) } From d843829512d840f5174018034193dbade4d08e18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 2 Dec 2025 13:12:57 +0000 Subject: [PATCH 06/23] Nix: work with new build.toml format changes This still generates noarch variants with unnecessary bits (CUDA version, system, etc.). --- lib/build-sets.nix | 28 ++++++++--------- lib/build-variants.nix | 31 ++++--------------- lib/build.nix | 55 +++++++++++++++++---------------- lib/torch-extension/no-arch.nix | 26 +++++++++++++++- lib/torch-version-utils.nix | 29 +++++++++++++---- 5 files changed, 96 insertions(+), 73 deletions(-) diff --git a/lib/build-sets.nix b/lib/build-sets.nix index 919cd255..1d9e9c5c 100644 --- a/lib/build-sets.nix +++ b/lib/build-sets.nix @@ -10,12 +10,8 @@ let overlay = import ../overlay.nix; inherit (import ./torch-version-utils.nix { inherit lib; }) + backend flattenSystems - isCpu - isCuda - isMetal - isRocm - isXpu ; # All build configurations supported by Torch. 
@@ -23,10 +19,11 @@ let system: let filterMap = f: xs: builtins.filter (x: x != null) (builtins.map f xs); + systemBuildConfigs = filterMap (version: if version.system == system then version else null) ( + flattenSystems torchVersions + ); in - filterMap (version: if version.system == system then version else null) ( - flattenSystems torchVersions - ); + builtins.map (buildConfig: buildConfig // { backend = backend buildConfig; }) systemBuildConfigs; cudaVersions = let @@ -61,8 +58,9 @@ let xpuPackages = super."xpuPackages_${flattenVersion xpuVersion}"; }; # Construct the nixpkgs package set for the given versions. - pkgsForVersions = + mkBuildSet = buildConfig@{ + backend, cpu ? false, cudaVersion ? null, metal ? false, @@ -76,15 +74,15 @@ let }: let pkgs = - if isCpu buildConfig then + if buildConfig.backend == "cpu" then pkgsForCpu - else if isCuda buildConfig then + else if buildConfig.backend == "cuda" then pkgsByCudaVer.${cudaVersion} - else if isRocm buildConfig then + else if buildConfig.backend == "rocm" then pkgsByRocmVer.${rocmVersion} - else if isMetal buildConfig then + else if buildConfig.backend == "metal" then pkgsForMetal - else if isXpu buildConfig then + else if buildConfig.backend == "xpu" then pkgsByXpuVer.${xpuVersion} else throw "No compute framework set in Torch version"; @@ -185,4 +183,4 @@ let pkgsByRocmVer = pkgsForRocmVersions rocmVersions; in -map pkgsForVersions (buildConfigs system) +map mkBuildSet (buildConfigs system) diff --git a/lib/build-variants.nix b/lib/build-variants.nix index 98bb834d..741d471f 100644 --- a/lib/build-variants.nix +++ b/lib/build-variants.nix @@ -2,43 +2,24 @@ let inherit (import ./torch-version-utils.nix { inherit lib; }) flattenSystems - isCpu - isCuda - isMetal - isRocm - isXpu ; in rec { - computeFramework = - buildConfig: - if buildConfig.cpu or false then - "cpu" - else if buildConfig ? cudaVersion then - "cuda" - else if buildConfig.metal or false then - "metal" - else if buildConfig ? 
"rocmVersion" then - "rocm" - else if buildConfig ? xpuVersion then - "xpu" - else - throw "Could not find compute framework: no CUDA, ROCm, XPU version specified and CPU and Metal are not enabled"; buildName = let inherit (import ./version-utils.nix { inherit lib; }) abiString flattenVersion; computeString = version: - if isCpu version then + if version.backend == "cpu" then "cpu" - else if isCuda version then + else if version.backend == "cuda" then "cu${flattenVersion (lib.versions.majorMinor version.cudaVersion)}" - else if isRocm version then + else if version.backend == "rocm" then "rocm${flattenVersion (lib.versions.majorMinor version.rocmVersion)}" - else if isMetal version then + else if version.backend == "metal" then "metal" - else if isXpu version then + else if version.backend == "xpu" then "xpu${flattenVersion (lib.versions.majorMinor version.xpuVersion)}" else throw "No compute framework set in Torch version"; @@ -60,7 +41,7 @@ rec { let path = [ version.system - (computeFramework version) + version.backend ]; pathVersions = lib.attrByPath path [ ] acc ++ [ (buildName version) ]; in diff --git a/lib/build.nix b/lib/build.nix index a39e1a00..0ea8e1ea 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -17,14 +17,6 @@ let supportedCudaCapabilities = builtins.fromJSON ( builtins.readFile ../build2cmake/src/cuda_supported_archs.json ); - inherit (import ./torch-version-utils.nix { inherit lib; }) - isCpu - isCuda - isMetal - isRocm - isXpu - ; - inherit (import ./build-variants.nix { inherit lib; }) computeFramework; in rec { readToml = path: builtins.fromTOML (builtins.readFile path); @@ -42,7 +34,24 @@ rec { build2cmake update-build build.toml''; buildToml; + # Backends supported by the kernel. 
backends = + buildToml: + let + init = { + cpu = false; + cuda = false; + metal = false; + rocm = false; + xpu = false; + }; + in + lib.foldl (backends: backend: backends // { ${backend} = true; }) init ( + buildToml.general.backends or [ ] + ); + + # Backends for which there is a native (compiled kernel). + kernelBackends = buildToml: let kernels = lib.attrValues (buildToml.kernel or { }); @@ -71,8 +80,9 @@ rec { buildToml: buildSets: let backends' = backends buildToml; - minCuda = buildToml.general.cuda-minver or "11.8"; - maxCuda = buildToml.general.cuda-maxver or "99.9"; + # COMPAT: buildToml.general.cuda-{minver,maxver} are backwards compat for v2 build.toml. + minCuda = buildToml.general.cuda.cuda-minver or buildToml.general.cuda-minver or "11.8"; + maxCuda = buildToml.general.cuda.cuda-maxver or buildToml.general.cuda-maxver or "99.9"; minTorch = buildToml.torch.minver or "2.0"; maxTorch = buildToml.torch.maxver or "99.9"; versionBetween = @@ -81,15 +91,9 @@ rec { supportedBuildSet = buildSet: let - backendSupported = - (isCpu buildSet.buildConfig && backends'.cpu) - || (isCuda buildSet.buildConfig && backends'.cuda) - || (isRocm buildSet.buildConfig && backends'.rocm) - || (isMetal buildSet.buildConfig && backends'.metal) - || (isXpu buildSet.buildConfig && backends'.xpu) - || (buildToml.general.universal or false); + backendSupported = backends'.${buildSet.buildConfig.backend}; cudaVersionSupported = - !(isCuda buildSet.buildConfig) + buildSet.buildConfig.backend != "cuda" || versionBetween minCuda maxCuda buildSet.pkgs.cudaPackages.cudaMajorMinorVersion; torchVersionParts = lib.splitString "." buildSet.torch.version; torchMajorMinor = lib.concatStringsSep "." 
(lib.take 2 torchVersionParts);
@@ -120,7 +124,8 @@ rec {
     let
       inherit (lib) fileset;
       buildToml = readBuildConfig path;
-      kernels = lib.filterAttrs (_: kernel: computeFramework buildConfig == kernel.backend) (
+      kernelBackends' = kernelBackends buildToml;
+      kernels = lib.filterAttrs (_: kernel: buildConfig.backend == kernel.backend) (
         buildToml.kernel or { }
       );
       extraDeps =
@@ -141,11 +146,12 @@ rec {
         ) buildToml.kernel
       );
     in
-    if buildToml.general.universal then
-      # No torch extension sources? Treat it as a noarch package.
+    if !kernelBackends'.${buildConfig.backend} then
+      # No compiled kernel files? Treat it as a noarch package.
       extension.mkNoArchExtension {
         inherit
+          buildConfig
           src
           rev
           doGetKernelCheck
@@ -214,11 +220,8 @@ rec {
       };
       buildToml = readBuildConfig path;
       namePaths =
-        if buildToml.general.universal then
-          # Noarch, just get the first extension.
-          { "torch-universal" = builtins.head (builtins.attrValues extensions); }
-        else
-          lib.mapAttrs (name: pkg: toString pkg) extensions;
+        # TODO: treat kernels without compiled parts differently.
+        lib.mapAttrs (name: pkg: toString pkg) extensions;
     in
     import ./join-paths {
       inherit pkgs namePaths;
diff --git a/lib/torch-extension/no-arch.nix b/lib/torch-extension/no-arch.nix
index bae6471c..28c3d0a0 100644
--- a/lib/torch-extension/no-arch.nix
+++ b/lib/torch-extension/no-arch.nix
@@ -1,4 +1,8 @@
 {
+  cudaSupport ? torch.cudaSupport,
+  rocmSupport ? torch.rocmSupport,
+  xpuSupport ? torch.xpuSupport,
+
   lib,
   pkgs,
   stdenv,
@@ -14,6 +18,8 @@
 }:
 
 {
+  buildConfig,
+
   # Whether to run get-kernel-check.
   doGetKernelCheck ? true,
 
@@ -30,6 +36,12 @@
   pythonDeps,
 }:
 
+# Extra validation - the environment should correspond to the build config.
+assert (buildConfig ? cudaVersion) -> cudaSupport;
+assert (buildConfig ? rocmVersion) -> rocmSupport;
+assert (buildConfig ? 
xpuVersion) -> xpuSupport; +assert (buildConfig.metal or false) -> stdenv.hostPlatform.isDarwin; + let inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps; dependencies = resolvePythonDeps pythonDeps ++ [ torch ]; @@ -38,6 +50,7 @@ let python-depends = pythonDeps; }; metadataFile = writeText "metadata.json" metadata; + metalSupport = buildConfig.metal or false; in stdenv.mkDerivation (prevAttrs: { @@ -64,7 +77,18 @@ stdenv.mkDerivation (prevAttrs: { # build. But `build2cmake` does proper validation of the build.toml, so # we run it anyway. postPatch = '' - build2cmake generate-torch --ops-id ${rev} build.toml + build2cmake generate-torch --backend ${ + if cudaSupport then + "cuda" + else if rocmSupport then + "rocm" + else if xpuSupport then + "xpu" + else if metalSupport then + "metal" + else + "cpu" + } --ops-id ${rev} build.toml ''; installPhase = '' diff --git a/lib/torch-version-utils.nix b/lib/torch-version-utils.nix index 7bda0313..5b43482c 100644 --- a/lib/torch-version-utils.nix +++ b/lib/torch-version-utils.nix @@ -1,5 +1,13 @@ { lib }: -{ +let + isCpu = version: version.cpu or false; + isCuda = version: version ? cudaVersion; + isMetal = version: version.metal or false; + isRocm = version: version ? rocmVersion; + isXpu = version: version ? xpuVersion; + +in +rec { # Expand { systems = [ a b ]; .. } to [ { system = a; ..} { system = b; .. } ] flattenSystems = lib.foldl' ( acc: version: @@ -7,9 +15,18 @@ ++ map (system: (builtins.removeAttrs version [ "systems" ]) // { inherit system; }) version.systems ) [ ]; - isCpu = version: version.cpu or false; - isCuda = version: version ? cudaVersion; - isMetal = version: version.metal or false; - isRocm = version: version ? rocmVersion; - isXpu = version: version ? 
xpuVersion; + backend = + version: + if isCpu version then + "cpu" + else if isCuda version then + "cuda" + else if isMetal version then + "metal" + else if isRocm version then + "rocm" + else if isXpu version then + "xpu" + else + throw "Could not find compute framework: no CUDA, ROCm, XPU version specified and CPU and Metal are not enabled"; } From da53c14d701e56bda06366b85d52376905392de1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 2 Dec 2025 14:28:03 +0000 Subject: [PATCH 07/23] Support no-arch variants --- lib/build.nix | 4 ++-- lib/torch-extension/arch.nix | 37 +++++++++++++++++++++------------ lib/torch-extension/no-arch.nix | 18 +++++----------- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/lib/build.nix b/lib/build.nix index 0ea8e1ea..57968d62 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -188,8 +188,8 @@ rec { let extensionForTorch = { path, rev }: - buildSet: { - name = buildName buildSet.buildConfig; + buildSet: rec { + name = value.variant; value = mkTorchExtension buildSet { inherit path rev doGetKernelCheck; stripRPath = true; diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index 954e803a..22111cb1 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -74,6 +74,26 @@ assert (buildConfig.metal or false) -> stdenv.hostPlatform.isDarwin; let inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps; + abiString = cxx11Abi: if cxx11Abi then "cxx11" else "cxx98"; + + computeStrings = { + cpu = "cpu"; + cuda = "cu${flattenVersion (lib.versions.majorMinor buildConfig.cudaVersion)}"; + metal = "metal"; + rocm = "rocm${flattenVersion (lib.versions.majorMinor buildConfig.rocmVersion)}"; + xpu = "xpu${flattenVersion (lib.versions.majorMinor buildConfig.xpuVersion)}"; + }; + computeString = computeStrings.${buildConfig.backend}; + + flattenVersion = + version: lib.replaceStrings [ "." 
] [ "" ] (lib.versions.majorMinor (lib.versions.pad 2 version)); + + variant = + if buildConfig.system == "aarch64-darwin" then + "torch${flattenVersion buildConfig.torchVersion}-${computeString}-${buildConfig.system}" + else + "torch${flattenVersion buildConfig.torchVersion}-${abiString buildConfig.cxx11Abi}-${computeString}-${buildConfig.system}"; + dependencies = resolvePythonDeps pythonDeps ++ [ torch ]; moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName; @@ -108,18 +128,9 @@ stdenv.mkDerivation (prevAttrs: { # Generate build files. postPatch = '' - build2cmake generate-torch --backend ${ - if cudaSupport then - "cuda" - else if rocmSupport then - "rocm" - else if xpuSupport then - "xpu" - else if metalSupport then - "metal" - else - "cpu" - } --ops-id ${rev} build.toml + build2cmake generate-torch \ + --backend ${buildConfig.backend} \ + --ops-id ${rev} build.toml ''; preConfigure = @@ -280,6 +291,6 @@ stdenv.mkDerivation (prevAttrs: { __noChroot = metalSupport; passthru = { - inherit dependencies torch; + inherit dependencies torch variant; }; }) diff --git a/lib/torch-extension/no-arch.nix b/lib/torch-extension/no-arch.nix index 28c3d0a0..2fa79f6d 100644 --- a/lib/torch-extension/no-arch.nix +++ b/lib/torch-extension/no-arch.nix @@ -51,6 +51,7 @@ let }; metadataFile = writeText "metadata.json" metadata; metalSupport = buildConfig.metal or false; + variant = "torch-${buildConfig.backend}"; in stdenv.mkDerivation (prevAttrs: { @@ -77,18 +78,9 @@ stdenv.mkDerivation (prevAttrs: { # build. But `build2cmake` does proper validation of the build.toml, so # we run it anyway. 
postPatch = '' - build2cmake generate-torch --backend ${ - if cudaSupport then - "cuda" - else if rocmSupport then - "rocm" - else if xpuSupport then - "xpu" - else if metalSupport then - "metal" - else - "cpu" - } --ops-id ${rev} build.toml + build2cmake generate-torch \ + --backend ${buildConfig.backend} \ + --ops-id ${rev} build.toml ''; installPhase = '' @@ -102,6 +94,6 @@ stdenv.mkDerivation (prevAttrs: { doInstallCheck = true; passthru = { - inherit dependencies; + inherit dependencies variant; }; }) From 72df9794c846233084be9889b50004b1cdf86708 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 2 Dec 2025 15:48:24 +0000 Subject: [PATCH 08/23] Let Torch derivations create the variant names --- flake.nix | 9 +-- lib/build-sets.nix | 11 +++- lib/build.nix | 5 +- lib/gen-flake-outputs.nix | 4 +- lib/torch-extension/arch.nix | 23 +------- lib/torch-extension/no-arch.nix | 4 +- pkgs/python-modules/torch/binary/generic.nix | 8 ++- .../torch/source/2_8/default.nix | 6 +- .../torch/source/2_9/default.nix | 6 +- pkgs/python-modules/torch/variant.nix | 55 +++++++++++++++++++ 10 files changed, 92 insertions(+), 39 deletions(-) create mode 100644 pkgs/python-modules/torch/variant.nix diff --git a/flake.nix b/flake.nix index 66f4fb8f..315cbe9d 100644 --- a/flake.nix +++ b/flake.nix @@ -112,8 +112,6 @@ pkgs = nixpkgs.legacyPackages.${system}; inherit (nixpkgs) lib; - buildName = (import ./lib/build-variants.nix { inherit lib; }).buildName; - buildSets = defaultBuildSetsPerSystem.${system}; in @@ -151,12 +149,11 @@ ++ allOutputs python3Packages.kernels ++ lib.optionals stdenv.hostPlatform.isLinux (allOutputs stdenvGlibc_2_27) ); - buildSetLinkFarm = - buildSet: pkgs.linkFarm (buildName buildSet.buildConfig) (buildSetOutputs buildSet); + buildSetLinkFarm = buildSet: pkgs.linkFarm buildSet.torch.variant (buildSetOutputs buildSet); in pkgs.linkFarm "packages-for-cache" ( map (buildSet: { - name = buildName (buildSet.buildConfig); + name = 
buildSet.torch.variant; path = buildSetLinkFarm buildSet; }) buildSets ); @@ -182,7 +179,7 @@ # This package set is exposed so that we can prebuild the Torch versions. torch = builtins.listToAttrs ( map (buildSet: { - name = buildName (buildSet.buildConfig); + name = buildSet.torch.variant; value = buildSet.torch; }) buildSets ); diff --git a/lib/build-sets.nix b/lib/build-sets.nix index 1d9e9c5c..dc4df1ea 100644 --- a/lib/build-sets.nix +++ b/lib/build-sets.nix @@ -126,7 +126,16 @@ let ); pkgsByXpuVer = pkgsForXpuVersions xpuVersions; - pkgsForMetal = pkgsForCpu; + pkgsForMetal = import nixpkgs { + inherit system; + config = { + allowUnfree = true; + metalSupport = true; + }; + overlays = [ + overlay + ]; + }; pkgsForCpu = import nixpkgs { inherit system; diff --git a/lib/build.nix b/lib/build.nix index 57968d62..2015b9a4 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -13,7 +13,6 @@ let abi = torch: if torch.passthru.cxx11Abi then "cxx11" else "cxx98"; - buildName = (import ./build-variants.nix { inherit lib; }).buildName; supportedCudaCapabilities = builtins.fromJSON ( builtins.readFile ../build2cmake/src/cuda_supported_archs.json ); @@ -250,7 +249,7 @@ rec { extension = mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }; in { - name = buildName buildSet.buildConfig; + name = buildSet.torch.variant; value = mkShell { nativeBuildInputs = with pkgs; pythonNativeCheckInputs python3.pkgs; @@ -308,7 +307,7 @@ rec { ); in { - name = buildName buildSet.buildConfig; + name = buildSet.torch.variant; value = mkShell rec { nativeBuildInputs = with pkgs; diff --git a/lib/gen-flake-outputs.nix b/lib/gen-flake-outputs.nix index 4144a311..ca3db386 100644 --- a/lib/gen-flake-outputs.nix +++ b/lib/gen-flake-outputs.nix @@ -17,8 +17,6 @@ }: let - inherit (import ./build-variants.nix { inherit lib; }) buildName; - supportedFormat = '' kernel-builder.lib.genFlakeOutputs { inherit self; @@ -102,7 +100,7 @@ let throw "No build variant is compatible with this system" 
else builtins.head buildSetsSorted; - shellTorch = buildName bestBuildSet.buildConfig; + shellTorch = bestBuildSet.torch.variant; headOrEmpty = l: if l == [ ] then [ ] else [ (builtins.head l) ]; in { diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index 22111cb1..301353fb 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -74,26 +74,6 @@ assert (buildConfig.metal or false) -> stdenv.hostPlatform.isDarwin; let inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps; - abiString = cxx11Abi: if cxx11Abi then "cxx11" else "cxx98"; - - computeStrings = { - cpu = "cpu"; - cuda = "cu${flattenVersion (lib.versions.majorMinor buildConfig.cudaVersion)}"; - metal = "metal"; - rocm = "rocm${flattenVersion (lib.versions.majorMinor buildConfig.rocmVersion)}"; - xpu = "xpu${flattenVersion (lib.versions.majorMinor buildConfig.xpuVersion)}"; - }; - computeString = computeStrings.${buildConfig.backend}; - - flattenVersion = - version: lib.replaceStrings [ "." 
] [ "" ] (lib.versions.majorMinor (lib.versions.pad 2 version)); - - variant = - if buildConfig.system == "aarch64-darwin" then - "torch${flattenVersion buildConfig.torchVersion}-${computeString}-${buildConfig.system}" - else - "torch${flattenVersion buildConfig.torchVersion}-${abiString buildConfig.cxx11Abi}-${computeString}-${buildConfig.system}"; - dependencies = resolvePythonDeps pythonDeps ++ [ torch ]; moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName; @@ -291,6 +271,7 @@ stdenv.mkDerivation (prevAttrs: { __noChroot = metalSupport; passthru = { - inherit dependencies torch variant; + inherit dependencies torch; + inherit (torch) variant; }; }) diff --git a/lib/torch-extension/no-arch.nix b/lib/torch-extension/no-arch.nix index 2fa79f6d..56e291f7 100644 --- a/lib/torch-extension/no-arch.nix +++ b/lib/torch-extension/no-arch.nix @@ -51,7 +51,6 @@ let }; metadataFile = writeText "metadata.json" metadata; metalSupport = buildConfig.metal or false; - variant = "torch-${buildConfig.backend}"; in stdenv.mkDerivation (prevAttrs: { @@ -94,6 +93,7 @@ stdenv.mkDerivation (prevAttrs: { doInstallCheck = true; passthru = { - inherit dependencies variant; + inherit dependencies; + variant = torch.noarchVariant; }; }) diff --git a/pkgs/python-modules/torch/binary/generic.nix b/pkgs/python-modules/torch/binary/generic.nix index 4f4a90b0..83999272 100644 --- a/pkgs/python-modules/torch/binary/generic.nix +++ b/pkgs/python-modules/torch/binary/generic.nix @@ -1,4 +1,5 @@ { + callPackage, config, lib, stdenv, @@ -7,6 +8,7 @@ fetchurl, cudaSupport ? config.cudaSupport, + metalSupport ? config.metalSupport, rocmSupport ? config.rocmSupport, tritonSupport ? (!stdenv.hostPlatform.isDarwin), xpuSupport ? 
(config.xpuSupport or false), @@ -329,7 +331,11 @@ buildPythonPackage { cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ]; rocmArchs = if rocmSupport then supportedTorchRocmArchs else [ ]; - }; + } + // (callPackage ../variant.nix { + inherit cxx11Abi; + torchVersion = version; + }); meta = with lib; { description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration"; diff --git a/pkgs/python-modules/torch/source/2_8/default.nix b/pkgs/python-modules/torch/source/2_8/default.nix index fec710fd..7b64209c 100644 --- a/pkgs/python-modules/torch/source/2_8/default.nix +++ b/pkgs/python-modules/torch/source/2_8/default.nix @@ -741,7 +741,11 @@ buildPythonPackage rec { # To help debug when a package is broken due to CUDA support inherit brokenConditions; tests = callPackage ./tests.nix { }; - }; + } + // (callPackage ../variant.nix { + inherit cxx11Abi; + torchVersion = version; + }); meta = { changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}"; diff --git a/pkgs/python-modules/torch/source/2_9/default.nix b/pkgs/python-modules/torch/source/2_9/default.nix index 24b7c8f4..6b27b60a 100644 --- a/pkgs/python-modules/torch/source/2_9/default.nix +++ b/pkgs/python-modules/torch/source/2_9/default.nix @@ -724,7 +724,11 @@ buildPythonPackage rec { # To help debug when a package is broken due to CUDA support inherit brokenConditions; tests = callPackage ./tests.nix { }; - }; + } + // (callPackage ../variant.nix { + inherit cxx11Abi; + torchVersion = version; + }); meta = { changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}"; diff --git a/pkgs/python-modules/torch/variant.nix b/pkgs/python-modules/torch/variant.nix new file mode 100644 index 00000000..c783e434 --- /dev/null +++ b/pkgs/python-modules/torch/variant.nix @@ -0,0 +1,55 @@ +{ + config, + cudaSupport ? config.cudaSupport, + metalSupport ? config.metalSupport or false, + rocmSupport ? config.rocmSupport, + xpuSupport ? 
config.xpuSupport or false, + + cudaPackages, + rocmPackages, + xpuPackages, + + lib, + stdenv, + + cxx11Abi, + + torchVersion, +}: + +let + flattenVersion = + version: lib.replaceStrings [ "." ] [ "" ] (lib.versions.majorMinor (lib.versions.pad 2 version)); + abiString = cxx11Abi: if cxx11Abi then "cxx11" else "cxx98"; + backend = + if cudaSupport then + "cuda" + else if metalSupport then + "metal" + else if rocmSupport then + "rocm" + else if xpuSupport then + "xpu" + else + "cpu"; + computeString = + if cudaSupport then + "cu${flattenVersion cudaPackages.cudaMajorMinorVersion}" + else if metalSupport then + "metal" + else if rocmSupport then + "rocm${flattenVersion (lib.versions.majorMinor rocmPackages.rocm.version)}" + else if xpuSupport then + "xpu${flattenVersion (lib.versions.majorMinor xpuPackages.oneapi-torch-dev.version)}" + else + "cpu"; +in +{ + variant = + if stdenv.hostPlatform.system == "aarch64-darwin" then + "torch${flattenVersion (lib.versions.majorMinor torchVersion)}-${computeString}-${stdenv.hostPlatform.system}" + else + "torch${flattenVersion (lib.versions.majorMinor torchVersion)}-${abiString cxx11Abi}-${computeString}-${stdenv.hostPlatform.system}"; + + noarchVariant = "torch-${backend}"; +} From 67b01f0616ad22db8f8960c2a54c0fe22270b36d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 3 Dec 2025 13:46:16 +0000 Subject: [PATCH 09/23] Update build variants --- docs/build-variants.md | 10 +++++++--- scripts/gen_variants_markdown.py | 15 ++++++++------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/docs/build-variants.md b/docs/build-variants.md index c691958b..b46a5dd5 100644 --- a/docs/build-variants.md +++ b/docs/build-variants.md @@ -53,9 +53,13 @@ available. This list will be updated as new PyTorch versions are released. - `torch28-cxx11-xpu20251-x86_64-linux` - `torch29-cxx11-xpu20252-x86_64-linux` -## Universal +## Python-only kernels Kernels that are in pure Python (e.g. 
Triton kernels) only need to provide -a single build variant: +one or more of the following variants: -- `torch-universal` +- `torch-cpu` +- `torch-cuda` +- `torch-metal` +- `torch-rocm` +- `torch-xpu` diff --git a/scripts/gen_variants_markdown.py b/scripts/gen_variants_markdown.py index a6f8d8d9..500ed7c3 100755 --- a/scripts/gen_variants_markdown.py +++ b/scripts/gen_variants_markdown.py @@ -11,7 +11,7 @@ "xpu": "XPU", } -HEADER = """# Build variants +SPECIFIC_VARIANTS = """# Build variants A kernel can be compliant for a specific compute framework (e.g. CUDA) or architecture (e.g. x86_64). For compliance with a compute framework and @@ -19,12 +19,10 @@ available. This list will be updated as new PyTorch versions are released.\n """ -FOOTER = """## Universal +NOARCH_VARIANTS = """## Python-only kernels Kernels that are in pure Python (e.g. Triton kernels) only need to provide -a single build variant: - -- `torch-universal` +one or more of the following variants:\n """ @@ -35,7 +33,7 @@ def json_to_markdown(): data = json.load(f) with open(project_root / "docs" / "build-variants.md", "w") as f: - f.write(HEADER) + f.write(SPECIFIC_VARIANTS) for arch, platforms in data.items(): for platform, variants in platforms.items(): f.write(f"## {_PLATFORM_NAMES[platform]} {arch}\n\n") @@ -44,7 +42,10 @@ def json_to_markdown(): f.write(f"- `{variant}`\n") f.write("\n") - f.write(FOOTER) + f.write(NOARCH_VARIANTS) + backends = { backend for platforms in data.values() for backend in platforms.keys() } + for backend in sorted(backends): + f.write(f"- `torch-{backend}`\n") if __name__ == "__main__": From c99ee16a9b53697b197b14d1a06e22d5c9a8264e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 3 Dec 2025 14:22:15 +0000 Subject: [PATCH 10/23] Validate build.toml, recommend update when necessary --- lib/build.nix | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/lib/build.nix b/lib/build.nix index 2015b9a4..a1e8656d 100644 
--- a/lib/build.nix +++ b/lib/build.nix @@ -23,14 +23,13 @@ rec { validateBuildConfig = buildToml: let + hasBackends = buildToml.general ? backends; kernels = lib.attrValues (buildToml.kernel or { }); - hasOldUniversal = builtins.hasAttr "universal" (buildToml.torch or { }); - hasLanguage = lib.any (kernel: kernel ? language) kernels; in - assert lib.assertMsg (!hasOldUniversal && !hasLanguage) '' + assert lib.assertMsg hasBackends '' build.toml seems to be of an older version, update it with: - build2cmake update-build build.toml''; + nix run github:huggingface/kernel-builder#build2cmake update-build build.toml''; buildToml; # Backends supported by the kernel. @@ -46,7 +45,7 @@ rec { }; in lib.foldl (backends: backend: backends // { ${backend} = true; }) init ( - buildToml.general.backends or [ ] + buildToml.general.backends ); # Backends for which there is a native (compiled kernel). @@ -79,9 +78,8 @@ rec { buildToml: buildSets: let backends' = backends buildToml; - # COMPAT: buildToml.general.cuda-{minver,maxver} are backwards compat for v2 build.toml. 
- minCuda = buildToml.general.cuda.cuda-minver or buildToml.general.cuda-minver or "11.8"; - maxCuda = buildToml.general.cuda.cuda-maxver or buildToml.general.cuda-maxver or "99.9"; + minCuda = buildToml.general.cuda.cuda-minver or "11.8"; + maxCuda = buildToml.general.cuda.cuda-maxver or "99.9"; minTorch = buildToml.torch.minver or "2.0"; maxTorch = buildToml.torch.maxver or "99.9"; versionBetween = From f87f5fe5081df3837466eeb1f5dd7c5e97380451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 3 Dec 2025 14:26:45 +0000 Subject: [PATCH 11/23] Update example kernels build.toml --- examples/cutlass-gemm/build.toml | 5 ++++- examples/relu-backprop-compile/build.toml | 5 ++++- examples/relu-compiler-flags/build.toml | 6 +++++- examples/relu-metal-cpp/build.toml | 2 +- examples/relu-specific-torch/build.toml | 5 ++++- examples/relu/build.toml | 8 +++++++- examples/silu-and-mul-universal/build.toml | 8 +++++++- 7 files changed, 32 insertions(+), 7 deletions(-) diff --git a/examples/cutlass-gemm/build.toml b/examples/cutlass-gemm/build.toml index dc5d10c6..cd2c8066 100644 --- a/examples/cutlass-gemm/build.toml +++ b/examples/cutlass-gemm/build.toml @@ -1,6 +1,9 @@ [general] name = "cutlass-gemm" -universal = false +backends = [ + "cuda", + "xpu", +] [torch] src = [ diff --git a/examples/relu-backprop-compile/build.toml b/examples/relu-backprop-compile/build.toml index 130b6474..6de49b78 100644 --- a/examples/relu-backprop-compile/build.toml +++ b/examples/relu-backprop-compile/build.toml @@ -1,6 +1,9 @@ [general] name = "relu-backprop-compile" -universal = false +backends = [ + "cuda", + "rocm", +] [torch] src = [ diff --git a/examples/relu-compiler-flags/build.toml b/examples/relu-compiler-flags/build.toml index d595e09e..b358d67f 100644 --- a/examples/relu-compiler-flags/build.toml +++ b/examples/relu-compiler-flags/build.toml @@ -1,6 +1,10 @@ [general] name = "relu-compiler-flags" -universal = false +backends = [ + "cuda", + "rocm", + "xpu", 
+] [torch] src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"] diff --git a/examples/relu-metal-cpp/build.toml b/examples/relu-metal-cpp/build.toml index e0d6d487..8cf012bc 100644 --- a/examples/relu-metal-cpp/build.toml +++ b/examples/relu-metal-cpp/build.toml @@ -1,6 +1,6 @@ [general] name = "relu" -universal = false +backends = ["metal"] [torch] src = [ diff --git a/examples/relu-specific-torch/build.toml b/examples/relu-specific-torch/build.toml index 3db0e7e7..191c4081 100644 --- a/examples/relu-specific-torch/build.toml +++ b/examples/relu-specific-torch/build.toml @@ -1,6 +1,9 @@ [general] name = "relu-specific-torch" -universal = false +backends = [ + "cuda", + "rocm", +] [torch] src = [ diff --git a/examples/relu/build.toml b/examples/relu/build.toml index 84eb068a..3627027c 100644 --- a/examples/relu/build.toml +++ b/examples/relu/build.toml @@ -1,6 +1,12 @@ [general] name = "relu" -universal = false +backends = [ + "cpu", + "cuda", + "metal", + "rocm", + "xpu", +] [torch] src = [ diff --git a/examples/silu-and-mul-universal/build.toml b/examples/silu-and-mul-universal/build.toml index 826e880e..41f91fd5 100644 --- a/examples/silu-and-mul-universal/build.toml +++ b/examples/silu-and-mul-universal/build.toml @@ -1,3 +1,9 @@ [general] name = "silu-and-mul-universal" -universal = true +backends = [ + "cpu", + "cuda", + "metal", + "rocm", + "xpu", +] From 660571c13e8956bb90659d30b09a2cef4e1959e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 3 Dec 2025 14:27:12 +0000 Subject: [PATCH 12/23] nix fmt --- lib/build.nix | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/build.nix b/lib/build.nix index a1e8656d..943bc2ac 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -44,9 +44,7 @@ rec { xpu = false; }; in - lib.foldl (backends: backend: backends // { ${backend} = true; }) init ( - buildToml.general.backends - ); + lib.foldl (backends: backend: backends // { ${backend} = true; }) init 
(buildToml.general.backends); # Backends for which there is a native (compiled kernel). kernelBackends = From c66e7f34f3be33e989e59e950e4bddfa3b5d2caa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 3 Dec 2025 14:30:12 +0000 Subject: [PATCH 13/23] docs: update `build.toml` options for v3 --- docs/writing-kernels.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/writing-kernels.md b/docs/writing-kernels.md index 172f966e..c3e32368 100644 --- a/docs/writing-kernels.md +++ b/docs/writing-kernels.md @@ -93,11 +93,14 @@ depends = [ "torch" ] - `name` (required): the name of the kernel. The Python code for a Torch extension must be stored in `torch-ext/`. -- `universal`: the kernel is a universal kernel when set to `true`. A - universal kernel is a pure Python package (no compiled files). - Universal kernels do not use the other sections described below. - A good example of a universal kernel is a Triton kernel. - Default: `false` +- `backends` (required): a list of supported backends. Must be one or + more of `cpu`, `cuda`, `metal`, `rocm`, or `xpu`. +- `python-depends` (**experimental**): a list of additional Python dependencies + that the kernel requires. The only supported dependencies are `einops` + and `nvidia-cutlass-dsl`. + +### `general.cuda` + - `cuda-maxver`: the maximum CUDA toolkit version (inclusive). This option _must not_ be set under normal circumstances, since it can exclude Torch build variants that are [required for compliant kernels](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md). @@ -108,9 +111,6 @@ depends = [ "torch" ] build variants that are [required for compliant kernels](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md). This option is provided for kernels that require functionality only provided by newer CUDA toolkits. 
-- `python-depends` (**experimental**): a list of additional Python dependencies - that the kernel requires. The only supported dependencies are `einops` - and `nvidia-cutlass-dsl`. ### `torch` From 8065db08519328c1f431d3c9c1f64f01e2487a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 3 Dec 2025 14:34:18 +0000 Subject: [PATCH 14/23] Rename silu-and-mul-universal kernel to silu-and-mul --- .github/workflows/build_kernel.yaml | 10 +++++----- .github/workflows/build_kernel_windows.yaml | 4 ++-- .../build.toml | 2 +- .../{silu-and-mul-universal => silu-and-mul}/flake.nix | 0 .../tests/test_silu_and_mul.py | 2 +- .../torch-ext/silu_and_mul}/__init__.py | 0 .../torch-ext/silu_and_mul}/silu_and_mul.py | 0 tests/Dockerfile.test-kernel | 2 +- tests/run-tests.sh | 4 ++-- 9 files changed, 12 insertions(+), 12 deletions(-) rename examples/{silu-and-mul-universal => silu-and-mul}/build.toml (72%) rename examples/{silu-and-mul-universal => silu-and-mul}/flake.nix (100%) rename examples/{silu-and-mul-universal => silu-and-mul}/tests/test_silu_and_mul.py (96%) rename examples/{silu-and-mul-universal/torch-ext/silu_and_mul_universal => silu-and-mul/torch-ext/silu_and_mul}/__init__.py (100%) rename examples/{silu-and-mul-universal/torch-ext/silu_and_mul_universal => silu-and-mul/torch-ext/silu_and_mul}/silu_and_mul.py (100%) diff --git a/.github/workflows/build_kernel.yaml b/.github/workflows/build_kernel.yaml index 922bd5f5..7bf02a03 100644 --- a/.github/workflows/build_kernel.yaml +++ b/.github/workflows/build_kernel.yaml @@ -58,10 +58,10 @@ jobs: - name: Test that we can build a test shell (e.g. 
that gcc corresponds to CUDA-required) run: ( cd examples/relu && nix build .#devShells.x86_64-linux.test ) - - name: Build silu-and-mul-universal kernel - run: ( cd examples/silu-and-mul-universal && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux ) - - name: Copy silu-and-mul-universal kernel - run: cp -rL examples/silu-and-mul-universal/result silu-and-mul-universal-kernel + - name: Build silu-and-mul kernel + run: ( cd examples/silu-and-mul && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux ) + - name: Copy silu-and-mul kernel + run: cp -rL examples/silu-and-mul/result silu-and-mul-kernel - name: Upload kernel artifacts uses: actions/upload-artifact@v4 @@ -73,7 +73,7 @@ jobs: relu-kernel relu-kernel-cpu relu-backprop-compile-kernel - silu-and-mul-universal-kernel + silu-and-mul-kernel test: name: Test kernels diff --git a/.github/workflows/build_kernel_windows.yaml b/.github/workflows/build_kernel_windows.yaml index e803f688..4ca166a6 100644 --- a/.github/workflows/build_kernel_windows.yaml +++ b/.github/workflows/build_kernel_windows.yaml @@ -79,5 +79,5 @@ jobs: # - name: Build relu kernel (specific Torch version) # run: ( cd examples/relu-specific-torch && nix build . 
) - - name: Build silu-and-mul-universal kernel - run: ( scripts\windows\builder.ps1 -SourceFolder examples/silu-and-mul-universal -BuildConfig Release -Build -Force) \ No newline at end of file + - name: Build silu-and-mul kernel + run: ( scripts\windows\builder.ps1 -SourceFolder examples/silu-and-mul -BuildConfig Release -Build -Force) diff --git a/examples/silu-and-mul-universal/build.toml b/examples/silu-and-mul/build.toml similarity index 72% rename from examples/silu-and-mul-universal/build.toml rename to examples/silu-and-mul/build.toml index 41f91fd5..c42fbecc 100644 --- a/examples/silu-and-mul-universal/build.toml +++ b/examples/silu-and-mul/build.toml @@ -1,5 +1,5 @@ [general] -name = "silu-and-mul-universal" +name = "silu-and-mul" backends = [ "cpu", "cuda", diff --git a/examples/silu-and-mul-universal/flake.nix b/examples/silu-and-mul/flake.nix similarity index 100% rename from examples/silu-and-mul-universal/flake.nix rename to examples/silu-and-mul/flake.nix diff --git a/examples/silu-and-mul-universal/tests/test_silu_and_mul.py b/examples/silu-and-mul/tests/test_silu_and_mul.py similarity index 96% rename from examples/silu-and-mul-universal/tests/test_silu_and_mul.py rename to examples/silu-and-mul/tests/test_silu_and_mul.py index aded2b45..d98cf408 100644 --- a/examples/silu-and-mul-universal/tests/test_silu_and_mul.py +++ b/examples/silu-and-mul/tests/test_silu_and_mul.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from torch.library import opcheck -from silu_and_mul_universal import ops, silu_and_mul +from silu_and_mul import ops, silu_and_mul def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor: diff --git a/examples/silu-and-mul-universal/torch-ext/silu_and_mul_universal/__init__.py b/examples/silu-and-mul/torch-ext/silu_and_mul/__init__.py similarity index 100% rename from examples/silu-and-mul-universal/torch-ext/silu_and_mul_universal/__init__.py rename to examples/silu-and-mul/torch-ext/silu_and_mul/__init__.py diff --git 
a/examples/silu-and-mul-universal/torch-ext/silu_and_mul_universal/silu_and_mul.py b/examples/silu-and-mul/torch-ext/silu_and_mul/silu_and_mul.py similarity index 100% rename from examples/silu-and-mul-universal/torch-ext/silu_and_mul_universal/silu_and_mul.py rename to examples/silu-and-mul/torch-ext/silu_and_mul/silu_and_mul.py diff --git a/tests/Dockerfile.test-kernel b/tests/Dockerfile.test-kernel index 90a09ef5..5329685c 100644 --- a/tests/Dockerfile.test-kernel +++ b/tests/Dockerfile.test-kernel @@ -66,7 +66,7 @@ RUN uv add numpy pytest COPY relu-kernel ./relu-kernel COPY relu-kernel-cpu ./relu-kernel-cpu COPY cutlass-gemm-kernel ./cutlass-gemm-kernel -COPY silu-and-mul-universal-kernel ./silu-and-mul-universal-kernel +COPY silu-and-mul-kernel ./silu-and-mul-kernel COPY examples/relu/tests ./relu_tests COPY examples/cutlass-gemm/tests ./cutlass_gemm_tests diff --git a/tests/run-tests.sh b/tests/run-tests.sh index f449a2f6..3498a8c1 100644 --- a/tests/run-tests.sh +++ b/tests/run-tests.sh @@ -4,8 +4,8 @@ PYTHONPATH="relu-kernel:cutlass-gemm-kernel:$PYTHONPATH" \ .venv/bin/pytest relu_tests cutlass_gemm_tests # We only care about importing, the kernel is trivial. 
-PYTHONPATH="silu-and-mul-universal-kernel:$PYTHONPATH" \ - .venv/bin/python -c "import silu_and_mul_universal" +PYTHONPATH="silu-and-mul-kernel:$PYTHONPATH" \ + .venv/bin/python -c "import silu_and_mul" PYTHONPATH="relu-kernel-cpu:$PYTHONPATH" \ CUDA_VISIBLE_DEVICES="" \ From 44786f89b657df4f3151a10ab922fb314545b563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 4 Dec 2025 13:03:55 +0000 Subject: [PATCH 15/23] build2cmake: rename universal templates to noarch --- build2cmake/src/templates/{universal => noarch}/_ops.py | 0 .../src/templates/{universal => noarch}/pyproject.toml | 0 build2cmake/src/torch/noarch.rs | 8 ++++---- .../torch-ext/silu_and_mul/{silu_and_mul.py => op.py} | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename build2cmake/src/templates/{universal => noarch}/_ops.py (100%) rename build2cmake/src/templates/{universal => noarch}/pyproject.toml (100%) rename examples/silu-and-mul/torch-ext/silu_and_mul/{silu_and_mul.py => op.py} (100%) diff --git a/build2cmake/src/templates/universal/_ops.py b/build2cmake/src/templates/noarch/_ops.py similarity index 100% rename from build2cmake/src/templates/universal/_ops.py rename to build2cmake/src/templates/noarch/_ops.py diff --git a/build2cmake/src/templates/universal/pyproject.toml b/build2cmake/src/templates/noarch/pyproject.toml similarity index 100% rename from build2cmake/src/templates/universal/pyproject.toml rename to build2cmake/src/templates/noarch/pyproject.toml diff --git a/build2cmake/src/torch/noarch.rs b/build2cmake/src/torch/noarch.rs index f9821d61..3fd441e1 100644 --- a/build2cmake/src/torch/noarch.rs +++ b/build2cmake/src/torch/noarch.rs @@ -38,8 +38,8 @@ fn write_ops_py( path.push("_ops.py"); let writer = file_set.entry(path); - env.get_template("universal/_ops.py") - .wrap_err("Cannot get _ops-universal.py template")? + env.get_template("noarch/_ops.py") + .wrap_err("Cannot get noarch _ops.py template")? .render_to_write( context! 
{ ops_name => ops_name, @@ -69,8 +69,8 @@ fn write_pyproject_toml( .map(|d| format!("\"{d}\"")) .join(", "); - env.get_template("universal/pyproject.toml") - .wrap_err("Cannot get universal pyproject.toml template")? + env.get_template("noarch/pyproject.toml") + .wrap_err("Cannot get noarch pyproject.toml template")? .render_to_write( context! { data_globs => data_globs, diff --git a/examples/silu-and-mul/torch-ext/silu_and_mul/silu_and_mul.py b/examples/silu-and-mul/torch-ext/silu_and_mul/op.py similarity index 100% rename from examples/silu-and-mul/torch-ext/silu_and_mul/silu_and_mul.py rename to examples/silu-and-mul/torch-ext/silu_and_mul/op.py From 4a83239422360cc97a88168f52de018e3c101645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 4 Dec 2025 13:04:10 +0000 Subject: [PATCH 16/23] Move unused `v2::Backend` --- build2cmake/src/config/v2.rs | 39 +----------------------------------- 1 file changed, 1 insertion(+), 38 deletions(-) diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index b003bc33..20b906bc 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr}; +use std::{collections::HashMap, fmt::Display, path::PathBuf}; use eyre::{bail, Result}; use serde::{Deserialize, Serialize}; @@ -117,43 +117,6 @@ pub enum Kernel { }, } -#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] -#[serde(deny_unknown_fields, rename_all = "kebab-case")] -pub enum Backend { - Cpu, - Cuda, - Metal, - Rocm, - Xpu, -} - -impl Display for Backend { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Backend::Cpu => write!(f, "cpu"), - Backend::Cuda => write!(f, "cuda"), - Backend::Metal => write!(f, "metal"), - Backend::Rocm => write!(f, "rocm"), - Backend::Xpu => write!(f, "xpu"), - } - } -} - -impl FromStr for Backend { - type Err = String; - - fn 
from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "cpu" => Ok(Backend::Cpu), - "cuda" => Ok(Backend::Cuda), - "metal" => Ok(Backend::Metal), - "rocm" => Ok(Backend::Rocm), - "xpu" => Ok(Backend::Xpu), - _ => Err(format!("Unknown backend: {s}")), - } - } -} - impl TryFrom for Build { type Error = eyre::Error; From 5d697bae3bd865ac1c10e038ec23483ac761cee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 4 Dec 2025 13:04:22 +0000 Subject: [PATCH 17/23] examples/silu_and_mul: fix --- examples/silu-and-mul/torch-ext/silu_and_mul/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/silu-and-mul/torch-ext/silu_and_mul/__init__.py b/examples/silu-and-mul/torch-ext/silu_and_mul/__init__.py index f17b92fa..dd9ea8f7 100644 --- a/examples/silu-and-mul/torch-ext/silu_and_mul/__init__.py +++ b/examples/silu-and-mul/torch-ext/silu_and_mul/__init__.py @@ -1,7 +1,7 @@ import torch from ._ops import ops -from .silu_and_mul import _silu_and_mul +from .op import _silu_and_mul def silu_and_mul(x: torch.Tensor) -> torch.Tensor: From dc4867eb6ea4e124f6e996620c91c82e8f9c4c6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 4 Dec 2025 13:04:46 +0000 Subject: [PATCH 18/23] build2cmake: update flake.lock for newer Rust --- build2cmake/flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build2cmake/flake.lock b/build2cmake/flake.lock index 27fe85ca..7adc14fd 100644 --- a/build2cmake/flake.lock +++ b/build2cmake/flake.lock @@ -48,11 +48,11 @@ ] }, "locked": { - "lastModified": 1750991972, - "narHash": "sha256-jzadGZL1MtqmHb5AZcjZhHpNulOdMZPxf8Wifg8e5VA=", + "lastModified": 1764816035, + "narHash": "sha256-F0IQSmSj4t2ThkbWZooAhkCTO+YpZSd2Pqiv2uoYEHo=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "b6509555d8ffaa0727f998af6ace901c5b78dc26", + "rev": "74d9abb7c5c030469f90d97a67d127cc5d76c238", "type": "github" }, "original": { From 
db454c961be571f9247daa669ac57d0451a776da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 4 Dec 2025 13:49:54 +0000 Subject: [PATCH 19/23] CI: name silu-and-mul output changed --- .github/workflows/build_kernel.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_kernel.yaml b/.github/workflows/build_kernel.yaml index 7bf02a03..11ecb6f8 100644 --- a/.github/workflows/build_kernel.yaml +++ b/.github/workflows/build_kernel.yaml @@ -59,7 +59,7 @@ jobs: run: ( cd examples/relu && nix build .#devShells.x86_64-linux.test ) - name: Build silu-and-mul kernel - run: ( cd examples/silu-and-mul && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux ) + run: ( cd examples/silu-and-mul && nix build .\#redistributable.torch-cuda ) - name: Copy silu-and-mul kernel run: cp -rL examples/silu-and-mul/result silu-and-mul-kernel From 5d497e118d339cfbea656e6367c00e24b6780963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 4 Dec 2025 14:57:03 +0000 Subject: [PATCH 20/23] Fix variant generation --- lib/build-variants.nix | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/build-variants.nix b/lib/build-variants.nix index 741d471f..9ee8d7e7 100644 --- a/lib/build-variants.nix +++ b/lib/build-variants.nix @@ -1,7 +1,7 @@ { lib }: let inherit (import ./torch-version-utils.nix { inherit lib; }) - flattenSystems + backend flattenSystems ; in rec { @@ -11,15 +11,15 @@ rec { inherit (import ./version-utils.nix { inherit lib; }) abiString flattenVersion; computeString = version: - if version.backend == "cpu" then + if backend version == "cpu" then "cpu" - else if version.backend == "cuda" then + else if backend version == "cuda" then "cu${flattenVersion (lib.versions.majorMinor version.cudaVersion)}" - else if version.backend == "rocm" then + else if backend version == "rocm" then "rocm${flattenVersion (lib.versions.majorMinor version.rocmVersion)}" - else 
if version.backend == "metal" then + else if backend version == "metal" then "metal" - else if version.backend == "xpu" then + else if backend version == "xpu" then "xpu${flattenVersion (lib.versions.majorMinor version.xpuVersion)}" else throw "No compute framework set in Torch version"; @@ -41,7 +41,7 @@ rec { let path = [ version.system - version.backend + (backend version) ]; pathVersions = lib.attrByPath path [ ] acc ++ [ (buildName version) ]; in From b3a961bdd81976a21d5fc04b29a1f00a6fa6554b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 4 Dec 2025 15:03:07 +0000 Subject: [PATCH 21/23] nix fmt --- lib/build-variants.nix | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/build-variants.nix b/lib/build-variants.nix index 9ee8d7e7..7201211a 100644 --- a/lib/build-variants.nix +++ b/lib/build-variants.nix @@ -1,7 +1,8 @@ { lib }: let inherit (import ./torch-version-utils.nix { inherit lib; }) - backend flattenSystems + backend + flattenSystems ; in rec { From 01a0e2f4e9db510f1971d03978c79870576775ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 4 Dec 2025 15:14:05 +0000 Subject: [PATCH 22/23] Fixup CUDA minver/maxver in Nix and docs --- docs/writing-kernels.md | 4 ++-- lib/build.nix | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/writing-kernels.md b/docs/writing-kernels.md index c3e32368..d0223f8b 100644 --- a/docs/writing-kernels.md +++ b/docs/writing-kernels.md @@ -101,12 +101,12 @@ depends = [ "torch" ] ### `general.cuda` -- `cuda-maxver`: the maximum CUDA toolkit version (inclusive). This option +- `maxver`: the maximum CUDA toolkit version (inclusive). This option _must not_ be set under normal circumstances, since it can exclude Torch build variants that are [required for compliant kernels](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md). 
This option is provided for kernels that cause compiler errors on newer CUDA toolkit versions. -- `cuda-minver`: the minimum required CUDA toolkit version. This option +- `minver`: the minimum required CUDA toolkit version. This option _must not_ be set under normal circumstances, since it can exclude Torch build variants that are [required for compliant kernels](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md). This option is provided for kernels that require functionality only diff --git a/lib/build.nix b/lib/build.nix index 943bc2ac..c00ea764 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -76,8 +76,8 @@ rec { buildToml: buildSets: let backends' = backends buildToml; - minCuda = buildToml.general.cuda.cuda-minver or "11.8"; - maxCuda = buildToml.general.cuda.cuda-maxver or "99.9"; + minCuda = buildToml.general.cuda.minver or "11.8"; + maxCuda = buildToml.general.cuda.maxver or "99.9"; minTorch = buildToml.torch.minver or "2.0"; maxTorch = buildToml.torch.maxver or "99.9"; versionBetween = From ca5871931bed96e5f67ee88b4d41f711ee3412c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 5 Dec 2025 08:56:59 +0000 Subject: [PATCH 23/23] Update relu-torch-bounds `build.toml` --- examples/relu-torch-bounds/build.toml | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/relu-torch-bounds/build.toml b/examples/relu-torch-bounds/build.toml index 8cc96b51..be8b5396 100644 --- a/examples/relu-torch-bounds/build.toml +++ b/examples/relu-torch-bounds/build.toml @@ -1,11 +1,17 @@ [general] name = "relu" -universal = false +backends = [ + "cuda", + "rocm", +] [torch] -src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"] minver = "2.9" maxver = "2.9" +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", +] [kernel.relu] backend = "cuda" @@ -14,6 +20,7 @@ src = ["relu_cuda/relu.cu"] [kernel.relu_rocm] backend = "rocm" +depends = ["torch"] rocm-archs = [ 
"gfx906", "gfx908", @@ -25,5 +32,4 @@ rocm-archs = [ "gfx1100", "gfx1101", ] -depends = ["torch"] src = ["relu_cuda/relu.cu"]