Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/nix_fmt.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: "Check Nix formatting"
name: "Nix checks"
on:
push:
branches: [main]
Expand All @@ -9,7 +9,7 @@ on:

jobs:
build:
name: Check Nix formatting
name: Nix checks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -18,3 +18,5 @@ jobs:
nix_path: nixpkgs=channel:nixos-unstable
- name: Check formatting
run: nix fmt -- --ci
- name: Nix checks
run: nix build .\#checks.x86_64-linux.default
4 changes: 4 additions & 0 deletions build2cmake/src/config/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ impl Display for PythonDependency {
#[serde(deny_unknown_fields)]
pub struct Torch {
pub include: Option<Vec<String>>,
pub minver: Option<Version>,
pub maxver: Option<Version>,
pub pyext: Option<Vec<String>>,

#[serde(default)]
Expand Down Expand Up @@ -352,6 +354,8 @@ impl From<v1::Torch> for Torch {
fn from(torch: v1::Torch) -> Self {
Self {
include: torch.include,
minver: None,
maxver: None,
pyext: torch.pyext,
src: torch.src,
}
Expand Down
16 changes: 16 additions & 0 deletions build2cmake/src/templates/cpu/preamble.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,20 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

find_package(Torch REQUIRED)

run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")

{% if torch_minver %}
if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
"Minimum required version is {{ torch_minver }}.")
endif()
{% endif %}

{% if torch_maxver %}
if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
"Maximum supported version is {{ torch_maxver }}.")
endif()
{% endif %}

add_compile_definitions(CPU_KERNEL)
16 changes: 16 additions & 0 deletions build2cmake/src/templates/cuda/preamble.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

find_package(Torch REQUIRED)

run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")

{% if torch_minver %}
if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
"Minimum required version is {{ torch_minver }}.")
endif()
{% endif %}

{% if torch_maxver %}
if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
"Maximum supported version is {{ torch_maxver }}.")
endif()
{% endif %}

if (NOT TARGET_DEVICE STREQUAL "cuda" AND
NOT TARGET_DEVICE STREQUAL "rocm")
return()
Expand Down
16 changes: 16 additions & 0 deletions build2cmake/src/templates/metal/preamble.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

find_package(Torch REQUIRED)

run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")

{% if torch_minver %}
if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
"Minimum required version is {{ torch_minver }}.")
endif()
{% endif %}

{% if torch_maxver %}
if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
"Maximum supported version is {{ torch_maxver }}.")
endif()
{% endif %}

add_compile_definitions(METAL_KERNEL)

# Initialize list for Metal shader sources
Expand Down
16 changes: 15 additions & 1 deletion build2cmake/src/templates/xpu/preamble.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,23 @@ find_package(Torch REQUIRED)

# Intel XPU backend detection and setup
if(NOT TORCH_VERSION)
run_python(TORCH_VERSION "import torch; print(torch.__version__)" "Failed to get Torch version")
run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
endif()

{% if torch_minver %}
if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
"Minimum required version is {{ torch_minver }}.")
endif()
{% endif %}

{% if torch_maxver %}
if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
"Maximum supported version is {{ torch_maxver }}.")
endif()
{% endif %}

# Check for Intel XPU support in PyTorch
run_python(XPU_AVAILABLE
"import torch; print('true' if hasattr(torch, 'xpu') else 'false')"
Expand Down
19 changes: 17 additions & 2 deletions build2cmake/src/torch/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::{common::write_pyproject_toml, kernel_ops_identifier};
use crate::{
config::{Build, Kernel, Torch},
fileset::FileSet,
version::Version,
};

static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
Expand Down Expand Up @@ -71,7 +72,13 @@ fn write_cmake(

let cmake_writer = file_set.entry("CMakeLists.txt");

render_preamble(env, name, cmake_writer)?;
render_preamble(
env,
name,
torch.minver.as_ref(),
torch.maxver.as_ref(),
cmake_writer,
)?;

// Add deps once we have any non-CUDA deps.
// render_deps(env, build, cmake_writer)?;
Expand Down Expand Up @@ -168,12 +175,20 @@ pub fn render_kernel(
Ok(())
}

fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
fn render_preamble(
env: &Environment,
name: &str,
torch_minver: Option<&Version>,
torch_maxver: Option<&Version>,
write: &mut impl Write,
) -> Result<()> {
env.get_template("cpu/preamble.cmake")
.wrap_err("Cannot get CMake prelude template")?
.render_to_write(
context! {
name => name,
torch_minver => torch_minver.map(|v| v.to_string()),
torch_maxver => torch_maxver.map(|v| v.to_string()),
},
&mut *write,
)
Expand Down
6 changes: 6 additions & 0 deletions build2cmake/src/torch/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ fn write_cmake(
name,
build.general.cuda_minver.as_ref(),
build.general.cuda_maxver.as_ref(),
torch.minver.as_ref(),
torch.maxver.as_ref(),
cmake_writer,
)?;

Expand Down Expand Up @@ -390,6 +392,8 @@ pub fn render_preamble(
name: &str,
cuda_minver: Option<&Version>,
cuda_maxver: Option<&Version>,
torch_minver: Option<&Version>,
torch_maxver: Option<&Version>,
write: &mut impl Write,
) -> Result<()> {
env.get_template("cuda/preamble.cmake")
Expand All @@ -399,6 +403,8 @@ pub fn render_preamble(
name => name,
cuda_minver => cuda_minver.map(|v| v.to_string()),
cuda_maxver => cuda_maxver.map(|v| v.to_string()),
torch_minver => torch_minver.map(|v| v.to_string()),
torch_maxver => torch_maxver.map(|v| v.to_string()),
cuda_supported_archs => cuda_supported_archs(),
platform => env::consts::OS
},
Expand Down
19 changes: 17 additions & 2 deletions build2cmake/src/torch/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::{common::write_pyproject_toml, kernel_ops_identifier};
use crate::{
config::{Build, Kernel, Torch},
fileset::FileSet,
version::Version,
};

static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
Expand Down Expand Up @@ -87,7 +88,13 @@ fn write_cmake(

let cmake_writer = file_set.entry("CMakeLists.txt");

render_preamble(env, name, cmake_writer)?;
render_preamble(
env,
name,
torch.minver.as_ref(),
torch.maxver.as_ref(),
cmake_writer,
)?;

// Add deps once we have any non-CUDA deps.
// render_deps(env, build, cmake_writer)?;
Expand Down Expand Up @@ -184,12 +191,20 @@ pub fn render_kernel(
Ok(())
}

fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
fn render_preamble(
env: &Environment,
name: &str,
torch_minver: Option<&Version>,
torch_maxver: Option<&Version>,
write: &mut impl Write,
) -> Result<()> {
env.get_template("metal/preamble.cmake")
.wrap_err("Cannot get CMake prelude template")?
.render_to_write(
context! {
name => name,
torch_minver => torch_minver.map(|v| v.to_string()),
torch_maxver => torch_maxver.map(|v| v.to_string()),
},
&mut *write,
)
Expand Down
19 changes: 17 additions & 2 deletions build2cmake/src/torch/xpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use minijinja::{context, Environment};
use super::common::write_pyproject_toml;
use super::kernel_ops_identifier;
use crate::config::{Build, Dependency, Kernel, Torch};
use crate::version::Version;
use crate::FileSet;

static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
Expand Down Expand Up @@ -135,7 +136,13 @@ fn write_cmake(

let cmake_writer = file_set.entry("CMakeLists.txt");

render_preamble(env, name, cmake_writer)?;
render_preamble(
env,
name,
torch.minver.as_ref(),
torch.maxver.as_ref(),
cmake_writer,
)?;

render_deps(env, build, cmake_writer)?;

Expand Down Expand Up @@ -263,12 +270,20 @@ pub fn render_extension(
Ok(())
}

pub fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
pub fn render_preamble(
env: &Environment,
name: &str,
torch_minver: Option<&Version>,
torch_maxver: Option<&Version>,
write: &mut impl Write,
) -> Result<()> {
env.get_template("xpu/preamble.cmake")
.wrap_err("Cannot get CMake prelude template")?
.render_to_write(
context! {
name => name,
torch_minver => torch_minver.map(|v| v.to_string()),
torch_maxver => torch_maxver.map(|v| v.to_string()),
},
&mut *write,
)
Expand Down
4 changes: 4 additions & 0 deletions docs/writing-kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ options:
`["py", "pyi"]`.
- `include` (optional): include directories relative to the project root.
Default: `[]`.
- `maxver` (optional): only build for this Torch version and earlier. Use cautiously, since this option produces
non-compliant kernels if the version range does not correspond to the [required variants](build-variants.md).
- `minver` (optional): only build for this Torch version and later. Use cautiously, since this option produces
non-compliant kernels if the version range does not correspond to the [required variants](build-variants.md).

### `kernel.<name>`

Expand Down
29 changes: 29 additions & 0 deletions examples/relu-torch-bounds/build.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Example build configuration demonstrating Torch version bounds
# (`torch.minver` / `torch.maxver`) for a simple ReLU kernel.
[general]
name = "relu"
universal = false

[torch]
src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
# Pin the supported Torch range; min == max restricts the build to a
# single Torch version (2.9) — see docs/writing-kernels.md for caveats.
minver = "2.9"
maxver = "2.9"

# CUDA variant of the kernel.
[kernel.relu]
backend = "cuda"
depends = ["torch"]
src = ["relu_cuda/relu.cu"]

# ROCm variant, built from the same source via HIP.
[kernel.relu_rocm]
backend = "rocm"
# AMD GPU architectures to compile for.
rocm-archs = [
"gfx906",
"gfx908",
"gfx90a",
"gfx940",
"gfx941",
"gfx942",
"gfx1030",
"gfx1100",
"gfx1101",
]
depends = ["torch"]
src = ["relu_cuda/relu.cu"]
17 changes: 17 additions & 0 deletions examples/relu-torch-bounds/flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
  description = "Flake for ReLU kernel";

  inputs = {
    # kernel-builder is the repository root, two directories up from
    # this example.
    kernel-builder.url = "path:../..";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    # Generate the standard kernel-builder outputs (packages, checks, ...)
    # for the kernel defined in this directory.
    kernel-builder.lib.genFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
43 changes: 43 additions & 0 deletions examples/relu-torch-bounds/relu_cuda/relu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <cmath>

// Element-wise ReLU over a 2-D view of the input: one block per row of
// `d` elements, with the block's threads striding across the row so any
// row length is handled regardless of blockDim.
__global__ void relu_kernel(float *__restrict__ out,
                            float const *__restrict__ input, const int d) {
  const int64_t row = blockIdx.x;
  const int64_t base = row * static_cast<int64_t>(d);
  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const float value = input[base + col];
    out[base + col] = value > 0.0f ? value : 0.0f;
  }
}

// Host-side launcher: computes out[i] = max(input[i], 0) element-wise on
// the current CUDA stream. Both tensors must be contiguous float32 CUDA
// tensors of the same shape on the same device; `out` is written in place.
void relu(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  // Fix: the original checked input's dtype twice; validate both tensors.
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
                  out.scalar_type() == at::ScalarType::Float,
              "relu_kernel only supports float32");

  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape. Got input shape: ",
              input.sizes(), " and output shape: ", out.sizes());

  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same data type. Got input dtype: ",
              input.scalar_type(), " and output dtype: ", out.scalar_type());

  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device. Got input device: ",
              input.device(), " and output device: ", out.device());

  // Fix: guard empty tensors — otherwise `numel() / d` divides by zero
  // when the last dimension is 0, and a grid of 0 blocks is an invalid
  // launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
                                          input.data_ptr<float>(), d);
}
Empty file.
Loading
Loading