diff --git a/.github/workflows/nix_fmt.yaml b/.github/workflows/nix_fmt.yaml
index fa620ea2..629d8a48 100644
--- a/.github/workflows/nix_fmt.yaml
+++ b/.github/workflows/nix_fmt.yaml
@@ -1,4 +1,4 @@
-name: "Check Nix formatting"
+name: "Nix checks"
 on:
   push:
     branches: [main]
@@ -9,7 +9,7 @@ on:
 
 jobs:
   build:
-    name: Check Nix formatting
+    name: Nix checks
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
@@ -18,3 +18,5 @@ jobs:
           nix_path: nixpkgs=channel:nixos-unstable
       - name: Check formatting
         run: nix fmt -- --ci
+      - name: Nix checks
+        run: nix build .\#checks.x86_64-linux.default
diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs
index 60b222f8..03cba382 100644
--- a/build2cmake/src/config/v2.rs
+++ b/build2cmake/src/config/v2.rs
@@ -92,6 +92,8 @@ impl Display for PythonDependency {
 #[serde(deny_unknown_fields)]
 pub struct Torch {
     pub include: Option<Vec<String>>,
+    pub minver: Option<Version>,
+    pub maxver: Option<Version>,
     pub pyext: Option<Vec<String>>,
 
     #[serde(default)]
@@ -352,6 +354,8 @@ impl From<v1::Torch> for Torch {
     fn from(torch: v1::Torch) -> Self {
         Self {
             include: torch.include,
+            minver: None,
+            maxver: None,
             pyext: torch.pyext,
             src: torch.src,
         }
diff --git a/build2cmake/src/templates/cpu/preamble.cmake b/build2cmake/src/templates/cpu/preamble.cmake
index bbd064f2..18046454 100644
--- a/build2cmake/src/templates/cpu/preamble.cmake
+++ b/build2cmake/src/templates/cpu/preamble.cmake
@@ -25,4 +25,20 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
 
 find_package(Torch REQUIRED)
 
+run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
+
+{% if torch_minver %}
+if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
+  message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
+                      "Minimum required version is {{ torch_minver }}.")
+endif()
+{% endif %}
+
+{% if torch_maxver %}
+if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
+  message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
+                      "Maximum supported version is {{ torch_maxver }}.")
+endif()
+{% endif %}
+
 add_compile_definitions(CPU_KERNEL)
diff --git a/build2cmake/src/templates/cuda/preamble.cmake b/build2cmake/src/templates/cuda/preamble.cmake
index 1f709da2..05024e06 100644
--- a/build2cmake/src/templates/cuda/preamble.cmake
+++ b/build2cmake/src/templates/cuda/preamble.cmake
@@ -29,6 +29,22 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
 
 find_package(Torch REQUIRED)
 
+run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
+
+{% if torch_minver %}
+if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
+  message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
+                      "Minimum required version is {{ torch_minver }}.")
+endif()
+{% endif %}
+
+{% if torch_maxver %}
+if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
+  message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
+                      "Maximum supported version is {{ torch_maxver }}.")
+endif()
+{% endif %}
+
 if (NOT TARGET_DEVICE STREQUAL "cuda"
     AND NOT TARGET_DEVICE STREQUAL "rocm")
   return()
diff --git a/build2cmake/src/templates/metal/preamble.cmake b/build2cmake/src/templates/metal/preamble.cmake
index 9871d286..524f037b 100644
--- a/build2cmake/src/templates/metal/preamble.cmake
+++ b/build2cmake/src/templates/metal/preamble.cmake
@@ -25,6 +25,22 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
 
 find_package(Torch REQUIRED)
 
+run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
+
+{% if torch_minver %}
+if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
+  message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
+                      "Minimum required version is {{ torch_minver }}.")
+endif()
+{% endif %}
+
+{% if torch_maxver %}
+if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
+  message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
+                      "Maximum supported version is {{ torch_maxver }}.")
+endif()
+{% endif %}
+
 add_compile_definitions(METAL_KERNEL)
 
 # Initialize list for Metal shader sources
diff --git a/build2cmake/src/templates/xpu/preamble.cmake b/build2cmake/src/templates/xpu/preamble.cmake
index c05afcf9..54dca570 100644
--- a/build2cmake/src/templates/xpu/preamble.cmake
+++ b/build2cmake/src/templates/xpu/preamble.cmake
@@ -43,9 +43,23 @@ find_package(Torch REQUIRED)
 
 # Intel XPU backend detection and setup
 if(NOT TORCH_VERSION)
-  run_python(TORCH_VERSION "import torch; print(torch.__version__)" "Failed to get Torch version")
+  run_python(TORCH_VERSION "import torch; print(torch.__version__.split('+')[0])" "Failed to get Torch version")
 endif()
 
+{% if torch_minver %}
+if (TORCH_VERSION VERSION_LESS {{ torch_minver }})
+  message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too old. "
+                      "Minimum required version is {{ torch_minver }}.")
+endif()
+{% endif %}
+
+{% if torch_maxver %}
+if (TORCH_VERSION VERSION_GREATER {{ torch_maxver }})
+  message(FATAL_ERROR "Torch version ${TORCH_VERSION} is too new. "
+                      "Maximum supported version is {{ torch_maxver }}.")
+endif()
+{% endif %}
+
 # Check for Intel XPU support in PyTorch
 run_python(XPU_AVAILABLE
   "import torch; print('true' if hasattr(torch, 'xpu') else 'false')"
diff --git a/build2cmake/src/torch/cpu.rs b/build2cmake/src/torch/cpu.rs
index 9bec4bd3..5b2829f4 100644
--- a/build2cmake/src/torch/cpu.rs
+++ b/build2cmake/src/torch/cpu.rs
@@ -8,6 +8,7 @@ use super::{common::write_pyproject_toml, kernel_ops_identifier};
 use crate::{
     config::{Build, Kernel, Torch},
     fileset::FileSet,
+    version::Version,
 };
 
 static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
@@ -71,7 +72,13 @@ fn write_cmake(
 
     let cmake_writer = file_set.entry("CMakeLists.txt");
 
-    render_preamble(env, name, cmake_writer)?;
+    render_preamble(
+        env,
+        name,
+        torch.minver.as_ref(),
+        torch.maxver.as_ref(),
+        cmake_writer,
+    )?;
 
     // Add deps once we have any non-CUDA deps.
     // render_deps(env, build, cmake_writer)?;
@@ -168,12 +175,20 @@ pub fn render_kernel(
     Ok(())
 }
 
-fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
+fn render_preamble(
+    env: &Environment,
+    name: &str,
+    torch_minver: Option<&Version>,
+    torch_maxver: Option<&Version>,
+    write: &mut impl Write,
+) -> Result<()> {
     env.get_template("cpu/preamble.cmake")
         .wrap_err("Cannot get CMake prelude template")?
         .render_to_write(
             context! {
                 name => name,
+                torch_minver => torch_minver.map(|v| v.to_string()),
+                torch_maxver => torch_maxver.map(|v| v.to_string()),
             },
            &mut *write,
        )
diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs
index faf3eb56..297df3d5 100644
--- a/build2cmake/src/torch/cuda.rs
+++ b/build2cmake/src/torch/cuda.rs
@@ -168,6 +168,8 @@ fn write_cmake(
         name,
         build.general.cuda_minver.as_ref(),
         build.general.cuda_maxver.as_ref(),
+        torch.minver.as_ref(),
+        torch.maxver.as_ref(),
         cmake_writer,
     )?;
 
@@ -390,6 +392,8 @@ pub fn render_preamble(
     name: &str,
     cuda_minver: Option<&Version>,
     cuda_maxver: Option<&Version>,
+    torch_minver: Option<&Version>,
+    torch_maxver: Option<&Version>,
     write: &mut impl Write,
 ) -> Result<()> {
     env.get_template("cuda/preamble.cmake")
@@ -399,6 +403,8 @@ pub fn render_preamble(
                 name => name,
                 cuda_minver => cuda_minver.map(|v| v.to_string()),
                 cuda_maxver => cuda_maxver.map(|v| v.to_string()),
+                torch_minver => torch_minver.map(|v| v.to_string()),
+                torch_maxver => torch_maxver.map(|v| v.to_string()),
                 cuda_supported_archs => cuda_supported_archs(),
                 platform => env::consts::OS
             },
diff --git a/build2cmake/src/torch/metal.rs b/build2cmake/src/torch/metal.rs
index ad09ac65..b4317c68 100644
--- a/build2cmake/src/torch/metal.rs
+++ b/build2cmake/src/torch/metal.rs
@@ -8,6 +8,7 @@ use super::{common::write_pyproject_toml, kernel_ops_identifier};
 use crate::{
     config::{Build, Kernel, Torch},
     fileset::FileSet,
+    version::Version,
 };
 
 static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
@@ -87,7 +88,13 @@ fn write_cmake(
 
     let cmake_writer = file_set.entry("CMakeLists.txt");
 
-    render_preamble(env, name, cmake_writer)?;
+    render_preamble(
+        env,
+        name,
+        torch.minver.as_ref(),
+        torch.maxver.as_ref(),
+        cmake_writer,
+    )?;
 
     // Add deps once we have any non-CUDA deps.
     // render_deps(env, build, cmake_writer)?;
@@ -184,12 +191,20 @@ pub fn render_kernel(
     Ok(())
 }
 
-fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
+fn render_preamble(
+    env: &Environment,
+    name: &str,
+    torch_minver: Option<&Version>,
+    torch_maxver: Option<&Version>,
+    write: &mut impl Write,
+) -> Result<()> {
     env.get_template("metal/preamble.cmake")
         .wrap_err("Cannot get CMake prelude template")?
         .render_to_write(
             context! {
                 name => name,
+                torch_minver => torch_minver.map(|v| v.to_string()),
+                torch_maxver => torch_maxver.map(|v| v.to_string()),
             },
            &mut *write,
        )
diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs
index 5f045b4e..3ef285b1 100644
--- a/build2cmake/src/torch/xpu.rs
+++ b/build2cmake/src/torch/xpu.rs
@@ -9,6 +9,7 @@ use minijinja::{context, Environment};
 
 use super::common::write_pyproject_toml;
 use super::kernel_ops_identifier;
 use crate::config::{Build, Dependency, Kernel, Torch};
+use crate::version::Version;
 use crate::FileSet;
 
 static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
@@ -135,7 +136,13 @@ fn write_cmake(
 
     let cmake_writer = file_set.entry("CMakeLists.txt");
 
-    render_preamble(env, name, cmake_writer)?;
+    render_preamble(
+        env,
+        name,
+        torch.minver.as_ref(),
+        torch.maxver.as_ref(),
+        cmake_writer,
+    )?;
 
     render_deps(env, build, cmake_writer)?;
 
@@ -263,12 +270,20 @@ pub fn render_extension(
     Ok(())
 }
 
-pub fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> {
+pub fn render_preamble(
+    env: &Environment,
+    name: &str,
+    torch_minver: Option<&Version>,
+    torch_maxver: Option<&Version>,
+    write: &mut impl Write,
+) -> Result<()> {
     env.get_template("xpu/preamble.cmake")
         .wrap_err("Cannot get CMake prelude template")?
         .render_to_write(
             context! {
                 name => name,
+                torch_minver => torch_minver.map(|v| v.to_string()),
+                torch_maxver => torch_maxver.map(|v| v.to_string()),
             },
            &mut *write,
        )
diff --git a/docs/writing-kernels.md b/docs/writing-kernels.md
index 8d66db2e..172f966e 100644
--- a/docs/writing-kernels.md
+++ b/docs/writing-kernels.md
@@ -123,6 +123,10 @@
   options: `["py", "pyi"]`.
 - `include` (optional): include directories relative to the project root. Default:
   `[]`.
+- `maxver` (optional): only build for this Torch version and earlier. Use cautiously, since this option produces
+  non-compliant kernels if the version range does not correspond to the [required variants](build-variants.md).
+- `minver` (optional): only build for this Torch version and later. Use cautiously, since this option produces
+  non-compliant kernels if the version range does not correspond to the [required variants](build-variants.md).
 
 ### `kernel.<name>`
 
diff --git a/examples/relu-torch-bounds/build.toml b/examples/relu-torch-bounds/build.toml
new file mode 100644
index 00000000..8cc96b51
--- /dev/null
+++ b/examples/relu-torch-bounds/build.toml
@@ -0,0 +1,29 @@
+[general]
+name = "relu"
+universal = false
+
+[torch]
+src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
+minver = "2.9"
+maxver = "2.9"
+
+[kernel.relu]
+backend = "cuda"
+depends = ["torch"]
+src = ["relu_cuda/relu.cu"]
+
+[kernel.relu_rocm]
+backend = "rocm"
+rocm-archs = [
+    "gfx906",
+    "gfx908",
+    "gfx90a",
+    "gfx940",
+    "gfx941",
+    "gfx942",
+    "gfx1030",
+    "gfx1100",
+    "gfx1101",
+]
+depends = ["torch"]
+src = ["relu_cuda/relu.cu"]
diff --git a/examples/relu-torch-bounds/flake.nix b/examples/relu-torch-bounds/flake.nix
new file mode 100644
index 00000000..bfe8717d
--- /dev/null
+++ b/examples/relu-torch-bounds/flake.nix
@@ -0,0 +1,17 @@
+{
+  description = "Flake for ReLU kernel";
+
+  inputs = {
+    kernel-builder.url = "path:../..";
+  };
+
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
+    kernel-builder.lib.genFlakeOutputs {
+      inherit self;
+      path = ./.;
+    };
+}
diff --git a/examples/relu-torch-bounds/relu_cuda/relu.cu b/examples/relu-torch-bounds/relu_cuda/relu.cu
new file mode 100644
index 00000000..6bbe3160
--- /dev/null
+++ b/examples/relu-torch-bounds/relu_cuda/relu.cu
@@ -0,0 +1,43 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
+
+#include <cmath>
+
+__global__ void relu_kernel(float *__restrict__ out,
+                            float const *__restrict__ input, const int d) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    auto x = input[token_idx * d + idx];
+    out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
+  }
+}
+
+void relu(torch::Tensor &out, torch::Tensor const &input) {
+  TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
+  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
+  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
+                  input.scalar_type() == at::ScalarType::Float,
+              "relu_kernel only supports float32");
+
+  TORCH_CHECK(input.sizes() == out.sizes(),
+              "Tensors must have the same shape. Got input shape: ",
+              input.sizes(), " and output shape: ", out.sizes());
+
+  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
+              "Tensors must have the same data type. Got input dtype: ",
+              input.scalar_type(), " and output dtype: ", out.scalar_type());
+
+  TORCH_CHECK(input.device() == out.device(),
+              "Tensors must be on the same device. Got input device: ",
+              input.device(), " and output device: ", out.device());
+
+  int d = input.size(-1);
+  int64_t num_tokens = input.numel() / d;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(d, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
+                                          input.data_ptr<float>(), d);
+}
diff --git a/examples/relu-torch-bounds/tests/__init__.py b/examples/relu-torch-bounds/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/relu-torch-bounds/tests/test_relu.py b/examples/relu-torch-bounds/tests/test_relu.py
new file mode 100644
index 00000000..98b292b9
--- /dev/null
+++ b/examples/relu-torch-bounds/tests/test_relu.py
@@ -0,0 +1,15 @@
+import platform
+
+import torch
+import torch.nn.functional as F
+
+import relu
+
+
+def test_relu():
+    if platform.system() == "Darwin":
+        device = torch.device("mps")
+    else:
+        device = torch.device("cuda")
+    x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
+    torch.testing.assert_allclose(F.relu(x), relu.relu(x))
diff --git a/examples/relu-torch-bounds/torch-ext/relu/__init__.py b/examples/relu-torch-bounds/torch-ext/relu/__init__.py
new file mode 100644
index 00000000..d801867e
--- /dev/null
+++ b/examples/relu-torch-bounds/torch-ext/relu/__init__.py
@@ -0,0 +1,12 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+
+def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if out is None:
+        out = torch.empty_like(x)
+    ops.relu(out, x)
+    return out
diff --git a/examples/relu-torch-bounds/torch-ext/torch_binding.cpp b/examples/relu-torch-bounds/torch-ext/torch_binding.cpp
new file mode 100644
index 00000000..4f75d886
--- /dev/null
+++ b/examples/relu-torch-bounds/torch-ext/torch_binding.cpp
@@ -0,0 +1,15 @@
+#include <torch/all.h>
+
+#include "registration.h"
+#include "torch_binding.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("relu(Tensor! out, Tensor input) -> ()");
+#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
+  ops.impl("relu", torch::kCUDA, &relu);
+#elif defined(METAL_KERNEL)
+  ops.impl("relu", torch::kMPS, relu);
+#endif
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/examples/relu-torch-bounds/torch-ext/torch_binding.h b/examples/relu-torch-bounds/torch-ext/torch_binding.h
new file mode 100644
index 00000000..c1186254
--- /dev/null
+++ b/examples/relu-torch-bounds/torch-ext/torch_binding.h
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <torch/torch.h>
+
+void relu(torch::Tensor &out, torch::Tensor const &input);
diff --git a/flake.nix b/flake.nix
index eebeeb01..66f4fb8f 100644
--- a/flake.nix
+++ b/flake.nix
@@ -118,6 +118,11 @@
 
         in
         rec {
+          checks.default = pkgs.callPackage ./lib/checks.nix {
+            inherit buildSets;
+            build = defaultBuildPerSystem.${system};
+          };
+
           formatter = pkgs.nixfmt-tree;
 
           packages =
diff --git a/lib/build.nix b/lib/build.nix
index 5b6c6f3c..a39e1a00 100644
--- a/lib/build.nix
+++ b/lib/build.nix
@@ -73,6 +73,8 @@ rec {
       backends' = backends buildToml;
       minCuda = buildToml.general.cuda-minver or "11.8";
       maxCuda = buildToml.general.cuda-maxver or "99.9";
+      minTorch = buildToml.torch.minver or "2.0";
+      maxTorch = buildToml.torch.maxver or "99.9";
       versionBetween =
         minver: maxver: ver:
         builtins.compareVersions ver minver >= 0 && builtins.compareVersions ver maxver <= 0;
@@ -89,8 +91,11 @@ rec {
       cudaVersionSupported =
         !(isCuda buildSet.buildConfig)
         || versionBetween minCuda maxCuda buildSet.pkgs.cudaPackages.cudaMajorMinorVersion;
+      torchVersionParts = lib.splitString "." buildSet.torch.version;
+      torchMajorMinor = lib.concatStringsSep "." (lib.take 2 torchVersionParts);
+      torchVersionSupported = versionBetween minTorch maxTorch torchMajorMinor;
     in
-    backendSupported && cudaVersionSupported;
+    backendSupported && cudaVersionSupported && torchVersionSupported;
 
 in
 builtins.filter supportedBuildSet buildSets;
diff --git a/lib/checks.nix b/lib/checks.nix
new file mode 100644
index 00000000..7ce6716e
--- /dev/null
+++ b/lib/checks.nix
@@ -0,0 +1,20 @@
+{
+  lib,
+  runCommand,
+
+  build,
+  buildSets,
+}:
+
+let
+  kernelBuildSets = build.applicableBuildSets {
+    inherit buildSets;
+    path = ../examples/relu-torch-bounds;
+  };
+in
+assert lib.assertMsg (builtins.all (buildSet: buildSet.torch.version == "2.9.0") kernelBuildSets) ''
+  Torch minver/maxver filtering does not work.
+'';
+runCommand "builder-nix-checks" { } ''
+  touch $out
+''