Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/build_kernel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ jobs:
- name: Test that we can build a test shell (e.g. that gcc corresponds to CUDA-required)
run: ( cd examples/relu && nix build .#devShells.x86_64-linux.test )

- name: Build silu-and-mul-universal kernel
run: ( cd examples/silu-and-mul-universal && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
- name: Copy silu-and-mul-universal kernel
run: cp -rL examples/silu-and-mul-universal/result silu-and-mul-universal-kernel
- name: Build silu-and-mul kernel
run: ( cd examples/silu-and-mul && nix build .\#redistributable.torch-cuda )
- name: Copy silu-and-mul kernel
run: cp -rL examples/silu-and-mul/result silu-and-mul-kernel

- name: Upload kernel artifacts
uses: actions/upload-artifact@v4
Expand All @@ -73,7 +73,7 @@ jobs:
relu-kernel
relu-kernel-cpu
relu-backprop-compile-kernel
silu-and-mul-universal-kernel
silu-and-mul-kernel

test:
name: Test kernels
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/build_kernel_windows.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ jobs:
# - name: Build relu kernel (specific Torch version)
# run: ( cd examples/relu-specific-torch && nix build . )

- name: Build silu-and-mul-universal kernel
run: ( scripts\windows\builder.ps1 -SourceFolder examples/silu-and-mul-universal -BuildConfig Release -Build -Force)
- name: Build silu-and-mul kernel
run: ( scripts\windows\builder.ps1 -SourceFolder examples/silu-and-mul -BuildConfig Release -Build -Force)
6 changes: 3 additions & 3 deletions build2cmake/flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions build2cmake/src/config/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use serde::{Deserialize, Serialize};

/// Identifier of a dependency as it appears in serialized configuration.
///
/// Every variant except `Torch` carries an explicit `#[serde(rename = ...)]`
/// that fixes its serialized name; `Torch` has no explicit rename and is
/// therefore covered by `rename_all = "lowercase"`, i.e. `torch`.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum Dependency {
// Serialized as `cutlass_2_10`.
#[serde(rename = "cutlass_2_10")]
Cutlass2_10,
// Serialized as `cutlass_3_5`.
#[serde(rename = "cutlass_3_5")]
Cutlass3_5,
// Serialized as `cutlass_3_6`.
#[serde(rename = "cutlass_3_6")]
Cutlass3_6,
// Serialized as `cutlass_3_8`.
#[serde(rename = "cutlass_3_8")]
Cutlass3_8,
// Serialized as `cutlass_3_9`.
#[serde(rename = "cutlass_3_9")]
Cutlass3_9,
// Serialized as `cutlass_4_0`.
#[serde(rename = "cutlass_4_0")]
Cutlass4_0,
// Serialized as `cutlass_sycl`.
#[serde(rename = "cutlass_sycl")]
CutlassSycl,
// Serialized as `metal-cpp` (note the hyphen, unlike the underscore names above).
#[serde(rename = "metal-cpp")]
MetalCpp,
// No explicit rename: named by the enum-level `rename_all` rule.
Torch,
}
33 changes: 20 additions & 13 deletions build2cmake/src/config/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
use eyre::Result;
use serde::Deserialize;
use serde_value::Value;

pub mod v1;

mod common;

mod v2;
use serde_value::Value;
pub use v2::{Backend, Build, Dependency, General, Kernel, Torch};

mod v3;
pub use common::Dependency;
pub use v3::{Backend, Build, General, Kernel, Torch};

#[derive(Debug)]
pub enum BuildCompat {
V1(v1::Build),
V2(Build),
V2(v2::Build),
V3(Build),
}

impl<'de> Deserialize<'de> for BuildCompat {
Expand All @@ -20,14 +26,11 @@ impl<'de> Deserialize<'de> for BuildCompat {
{
let value = Value::deserialize(deserializer)?;

match v1::Build::deserialize(value.clone()) {
Ok(v1_build) => Ok(BuildCompat::V1(v1_build)),
Err(_) => {
let v2_build: Build =
Build::deserialize(value).map_err(serde::de::Error::custom)?;
Ok(BuildCompat::V2(v2_build))
}
}
v1::Build::deserialize(value.clone())
.map(BuildCompat::V1)
.or_else(|_| v2::Build::deserialize(value.clone()).map(BuildCompat::V2))
.or_else(|_| Build::deserialize(value.clone()).map(BuildCompat::V3))
.map_err(serde::de::Error::custom)
}
}

Expand All @@ -36,8 +39,12 @@ impl TryFrom<BuildCompat> for Build {

fn try_from(compat: BuildCompat) -> Result<Self> {
match compat {
BuildCompat::V1(v1_build) => v1_build.try_into(),
BuildCompat::V2(v2_build) => Ok(v2_build),
BuildCompat::V1(v1_build) => {
let v2_build: v2::Build = v1_build.try_into()?;
v2_build.try_into()
}
BuildCompat::V2(v2_build) => v2_build.try_into(),
BuildCompat::V3(v3_build) => Ok(v3_build),
}
}
}
2 changes: 1 addition & 1 deletion build2cmake/src/config/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{collections::HashMap, fmt::Display, path::PathBuf};

use serde::Deserialize;

use super::v2::Dependency;
use super::common::Dependency;

#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
Expand Down
172 changes: 5 additions & 167 deletions build2cmake/src/config/v2.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
use std::{
collections::{BTreeSet, HashMap},
fmt::Display,
path::PathBuf,
str::FromStr,
};
use std::{collections::HashMap, fmt::Display, path::PathBuf};

use eyre::{bail, Result};
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::version::Version;

use super::v1::{self, Language};
use super::{
common::Dependency,
v1::{self, Language},
};

#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
Expand All @@ -23,25 +20,6 @@ pub struct Build {
pub kernels: HashMap<String, Kernel>,
}

impl Build {
    /// Reports whether at least one configured kernel targets `backend`.
    pub fn has_kernel_with_backend(&self, backend: &Backend) -> bool {
        let available = self.backends();
        available.contains(backend)
    }

    /// Collects the distinct backends used by the configured kernels.
    ///
    /// The result is a `BTreeSet`, so each backend appears once and iteration
    /// follows `Backend`'s ordering.
    pub fn backends(&self) -> BTreeSet<Backend> {
        let mut found = BTreeSet::new();
        for kernel in self.kernels.values() {
            let backend = match kernel {
                Kernel::Cpu { .. } => Backend::Cpu,
                Kernel::Cuda { .. } => Backend::Cuda,
                Kernel::Metal { .. } => Backend::Metal,
                Kernel::Rocm { .. } => Backend::Rocm,
                Kernel::Xpu { .. } => Backend::Xpu,
            };
            found.insert(backend);
        }
        found
    }
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct General {
Expand All @@ -58,13 +36,6 @@ pub struct General {
pub python_depends: Option<Vec<PythonDependency>>,
}

impl General {
    /// Returns the kernel name in the form used for the Python extension:
    /// every `-` in `name` is replaced by `_`, all other characters are kept.
    pub fn python_name(&self) -> String {
        self.name
            .chars()
            .map(|ch| if ch == '-' { '_' } else { ch })
            .collect()
    }
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct Hub {
Expand Down Expand Up @@ -100,27 +71,6 @@ pub struct Torch {
pub src: Vec<PathBuf>,
}

impl Torch {
    /// Builds quoted glob patterns (`"**/*.<ext>"`) for the entries of `pyext`.
    ///
    /// The extensions `py` and `pyi` are skipped. Returns `None` when `pyext`
    /// is unset or when no extension remains after that filtering.
    pub fn data_globs(&self) -> Option<Vec<String>> {
        // No extension list configured: nothing to emit.
        let extensions = self.pyext.as_ref()?;

        let patterns = extensions
            .iter()
            .filter(|&ext| ext != "py" && ext != "pyi")
            .map(|ext| format!("\"**/*.{ext}\""))
            .collect_vec();

        // An empty list is reported as `None` rather than `Some(vec![])`.
        (!patterns.is_empty()).then_some(patterns)
    }
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")]
pub enum Kernel {
Expand Down Expand Up @@ -167,118 +117,6 @@ pub enum Kernel {
},
}

impl Kernel {
    /// C++ compiler flags of this kernel, or `None` when none are set.
    pub fn cxx_flags(&self) -> Option<&[String]> {
        // Every variant carries the field, so the or-pattern is irrefutable.
        let (Self::Cpu { cxx_flags: flags, .. }
        | Self::Cuda { cxx_flags: flags, .. }
        | Self::Metal { cxx_flags: flags, .. }
        | Self::Rocm { cxx_flags: flags, .. }
        | Self::Xpu { cxx_flags: flags, .. }) = self;
        flags.as_deref()
    }

    /// Include entries of this kernel, or `None` when none are set.
    pub fn include(&self) -> Option<&[String]> {
        let (Self::Cpu { include: dirs, .. }
        | Self::Cuda { include: dirs, .. }
        | Self::Metal { include: dirs, .. }
        | Self::Rocm { include: dirs, .. }
        | Self::Xpu { include: dirs, .. }) = self;
        dirs.as_deref()
    }

    /// The backend corresponding to this kernel's variant.
    pub fn backend(&self) -> Backend {
        match self {
            Self::Cpu { .. } => Backend::Cpu,
            Self::Cuda { .. } => Backend::Cuda,
            Self::Metal { .. } => Backend::Metal,
            Self::Rocm { .. } => Backend::Rocm,
            Self::Xpu { .. } => Backend::Xpu,
        }
    }

    /// Dependencies declared for this kernel.
    pub fn depends(&self) -> &[Dependency] {
        let (Self::Cpu { depends: deps, .. }
        | Self::Cuda { depends: deps, .. }
        | Self::Metal { depends: deps, .. }
        | Self::Rocm { depends: deps, .. }
        | Self::Xpu { depends: deps, .. }) = self;
        deps
    }

    /// Source entries declared for this kernel.
    pub fn src(&self) -> &[String] {
        let (Self::Cpu { src: sources, .. }
        | Self::Cuda { src: sources, .. }
        | Self::Metal { src: sources, .. }
        | Self::Rocm { src: sources, .. }
        | Self::Xpu { src: sources, .. }) = self;
        sources
    }
}

/// Backend that a kernel is built for.
///
/// Serde names follow `rename_all = "kebab-case"`. The `Display` and
/// `FromStr` impls in this file use the lowercase variant names
/// (`cpu`, `cuda`, `metal`, `rocm`, `xpu`).
///
/// The derived `Ord` follows declaration order, which determines iteration
/// order when values are stored in a `BTreeSet`.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub enum Backend {
Cpu,
Cuda,
Metal,
Rocm,
Xpu,
}

impl Display for Backend {
    /// Writes the lowercase backend name (`cpu`, `cuda`, `metal`, `rocm`, `xpu`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Backend::Cpu => "cpu",
            Backend::Cuda => "cuda",
            Backend::Metal => "metal",
            Backend::Rocm => "rocm",
            Backend::Xpu => "xpu",
        };
        f.write_str(name)
    }
}

impl FromStr for Backend {
    type Err = String;

    /// Parses a backend name, ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns `Unknown backend: <input>` (with the input as given, not
    /// lowercased) when the name matches none of the known backends.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let backend = match normalized.as_str() {
            "cpu" => Backend::Cpu,
            "cuda" => Backend::Cuda,
            "metal" => Backend::Metal,
            "rocm" => Backend::Rocm,
            "xpu" => Backend::Xpu,
            _ => return Err(format!("Unknown backend: {s}")),
        };
        Ok(backend)
    }
}

/// Identifier of a dependency as it appears in serialized configuration;
/// `Kernel::depends` returns a slice of these.
///
/// Every variant except `Torch` carries an explicit `#[serde(rename = ...)]`
/// that fixes its serialized name; `Torch` has no explicit rename and is
/// therefore covered by `rename_all = "lowercase"`, i.e. `torch`.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum Dependency {
// Serialized as `cutlass_2_10`.
#[serde(rename = "cutlass_2_10")]
Cutlass2_10,
// Serialized as `cutlass_3_5`.
#[serde(rename = "cutlass_3_5")]
Cutlass3_5,
// Serialized as `cutlass_3_6`.
#[serde(rename = "cutlass_3_6")]
Cutlass3_6,
// Serialized as `cutlass_3_8`.
#[serde(rename = "cutlass_3_8")]
Cutlass3_8,
// Serialized as `cutlass_3_9`.
#[serde(rename = "cutlass_3_9")]
Cutlass3_9,
// Serialized as `cutlass_4_0`.
#[serde(rename = "cutlass_4_0")]
Cutlass4_0,
// Serialized as `cutlass_sycl`.
#[serde(rename = "cutlass_sycl")]
CutlassSycl,
// Serialized as `metal-cpp` (note the hyphen, unlike the underscore names above).
#[serde(rename = "metal-cpp")]
MetalCpp,
// No explicit rename: named by the enum-level `rename_all` rule.
Torch,
}

impl TryFrom<v1::Build> for Build {
type Error = eyre::Error;

Expand Down
Loading
Loading