From 7d763ed524abd5b61ef112951a0f1b97a8904ae6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 20 Mar 2026 14:25:58 +0000 Subject: [PATCH 1/2] kernel-builder: add support for C++ dependencies in tvm-ffi kernels --- kernel-builder/src/pyproject/{torch => }/deps.rs | 14 +++++++------- kernel-builder/src/pyproject/mod.rs | 3 ++- .../templates/{torch => }/cuda/dep-cutlass.cmake | 0 .../templates/{torch => }/xpu/dep-sycl-tla.cmake | 0 kernel-builder/src/pyproject/torch/mod.rs | 3 +-- kernel-builder/src/pyproject/tvm_ffi/mod.rs | 3 +++ 6 files changed, 13 insertions(+), 10 deletions(-) rename kernel-builder/src/pyproject/{torch => }/deps.rs (87%) rename kernel-builder/src/pyproject/templates/{torch => }/cuda/dep-cutlass.cmake (100%) rename kernel-builder/src/pyproject/templates/{torch => }/xpu/dep-sycl-tla.cmake (100%) diff --git a/kernel-builder/src/pyproject/torch/deps.rs b/kernel-builder/src/pyproject/deps.rs similarity index 87% rename from kernel-builder/src/pyproject/torch/deps.rs rename to kernel-builder/src/pyproject/deps.rs index 41f16faa..54cb79ce 100644 --- a/kernel-builder/src/pyproject/torch/deps.rs +++ b/kernel-builder/src/pyproject/deps.rs @@ -16,7 +16,7 @@ pub fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> for dep in deps { match dep { Dependency::Cutlass2_10 => { - env.get_template("torch/cuda/dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -27,7 +27,7 @@ pub fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> .wrap_err("Cannot render CUTLASS dependency template")?; } Dependency::Cutlass3_5 => { - env.get_template("torch/cuda/dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -38,7 +38,7 @@ pub fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> .wrap_err("Cannot render CUTLASS dependency template")?; } Dependency::Cutlass3_6 => { - env.get_template("torch/cuda/dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -49,7 +49,7 @@ pub fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> .wrap_err("Cannot render CUTLASS dependency template")?; } Dependency::Cutlass3_8 => { - env.get_template("torch/cuda/dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -60,7 +60,7 @@ pub fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> .wrap_err("Cannot render CUTLASS dependency template")?; } Dependency::Cutlass3_9 => { - env.get_template("torch/cuda/dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -71,7 +71,7 @@ pub fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> .wrap_err("Cannot render CUTLASS dependency template")?; } Dependency::Cutlass4_0 => { - env.get_template("torch/cuda/dep-cutlass.cmake") + env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( context! { @@ -82,7 +82,7 @@ pub fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> .wrap_err("Cannot render CUTLASS dependency template")?; } Dependency::SyclTla => { - env.get_template("torch/xpu/dep-sycl-tla.cmake")? + env.get_template("xpu/dep-sycl-tla.cmake")? .render_to_write(context! {}, &mut *write)?; } Dependency::MetalCpp => { diff --git a/kernel-builder/src/pyproject/mod.rs b/kernel-builder/src/pyproject/mod.rs index 89db2b8c..5537d4c1 100644 --- a/kernel-builder/src/pyproject/mod.rs +++ b/kernel-builder/src/pyproject/mod.rs @@ -3,7 +3,8 @@ use std::path::Path; use eyre::Result; use minijinja::Environment; -mod common; +pub(crate) mod common; +pub mod deps; pub mod fileset; mod kernel; mod metadata; diff --git a/kernel-builder/src/pyproject/templates/torch/cuda/dep-cutlass.cmake b/kernel-builder/src/pyproject/templates/cuda/dep-cutlass.cmake similarity index 100% rename from kernel-builder/src/pyproject/templates/torch/cuda/dep-cutlass.cmake rename to kernel-builder/src/pyproject/templates/cuda/dep-cutlass.cmake diff --git a/kernel-builder/src/pyproject/templates/torch/xpu/dep-sycl-tla.cmake b/kernel-builder/src/pyproject/templates/xpu/dep-sycl-tla.cmake similarity index 100% rename from kernel-builder/src/pyproject/templates/torch/xpu/dep-sycl-tla.cmake rename to kernel-builder/src/pyproject/templates/xpu/dep-sycl-tla.cmake diff --git a/kernel-builder/src/pyproject/torch/mod.rs b/kernel-builder/src/pyproject/torch/mod.rs index e0183390..a7855936 100644 --- a/kernel-builder/src/pyproject/torch/mod.rs +++ b/kernel-builder/src/pyproject/torch/mod.rs @@ -12,8 +12,7 @@ use crate::pyproject::common::{ use crate::pyproject::ops_identifier::{git_identifier, random_identifier}; use crate::pyproject::FileSet; -mod deps; -use deps::render_deps; +use crate::pyproject::deps::render_deps; use crate::pyproject::kernel::render_kernel_components; diff --git a/kernel-builder/src/pyproject/tvm_ffi/mod.rs b/kernel-builder/src/pyproject/tvm_ffi/mod.rs index 683fb11e..e544b5a3 100644 --- a/kernel-builder/src/pyproject/tvm_ffi/mod.rs +++ b/kernel-builder/src/pyproject/tvm_ffi/mod.rs @@ -9,6 +9,7 @@ use crate::config::{Backend, Build, General, TvmFfi}; use crate::pyproject::common::{ prefix_and_join_includes, write_cmake_file, write_compat_py, write_metadata, }; +use crate::pyproject::deps::render_deps; use crate::pyproject::kernel::render_kernel_components; use crate::pyproject::ops_identifier::{git_identifier, random_identifier}; use crate::pyproject::FileSet; @@ -228,6 +229,8 @@ pub fn write_cmake( render_preamble(env, &build.general, revision, cmake_writer)?; + render_deps(env, build, cmake_writer)?; + render_binding(env, tvm_ffi, name, cmake_writer)?; render_kernel_components(env, build, cmake_writer)?; From 6e53633e4ca7edc7856e6d6eae52f01c112272b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 20 Mar 2026 14:51:28 +0000 Subject: [PATCH 2/2] Add cutlass-gemm-tvm-ffi example/test for tvm-ffi with deps --- .github/workflows/build_kernel.yaml | 6 + .../kernels/cutlass-gemm-tvm-ffi/build.toml | 25 +++ .../kernels/cutlass-gemm-tvm-ffi/flake.nix | 17 ++ examples/kernels/cutlass-gemm-tvm-ffi/gemm.cu | 58 ++++++ .../cutlass-gemm-tvm-ffi/gemm_sycl.cpp | 174 ++++++++++++++++++ .../cutlass-gemm-tvm-ffi/tests/conftest.py | 12 ++ .../cutlass-gemm-tvm-ffi/tests/test_gemm.py | 12 ++ .../cutlass_gemm_tvm_ffi/__init__.py | 19 ++ .../tvm-ffi-ext/tvm_ffi_binding.cpp | 7 + examples/kernels/cutlass-gemm-tvm-ffi/util.hh | 33 ++++ nix-builder/tests/Dockerfile.test-kernel | 2 + nix-builder/tests/run-tests.sh | 5 +- 12 files changed, 368 insertions(+), 2 deletions(-) create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/build.toml create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/flake.nix create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/gemm.cu create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/gemm_sycl.cpp create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/tests/conftest.py create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/tests/test_gemm.py create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/tvm-ffi-ext/cutlass_gemm_tvm_ffi/__init__.py create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/tvm-ffi-ext/tvm_ffi_binding.cpp create mode 100644 examples/kernels/cutlass-gemm-tvm-ffi/util.hh diff --git a/.github/workflows/build_kernel.yaml b/.github/workflows/build_kernel.yaml index ca50b12e..7b5ebd87 100644 --- a/.github/workflows/build_kernel.yaml +++ b/.github/workflows/build_kernel.yaml @@ -64,6 +64,11 @@ jobs: - name: Copy cutlass GEMM kernel run: cp -rL examples/kernels/cutlass-gemm/result cutlass-gemm-kernel + - name: Build cutlass-gemm-tvm-ffi kernel + run: ( cd examples/kernels/cutlass-gemm-tvm-ffi && nix build .\#redistributable.tvm-ffi01-cu126-${{ matrix.arch }} ) + - name: Copy cutlass-gemm-tvm-ffi kernel + run: cp -rL examples/kernels/cutlass-gemm-tvm-ffi/result cutlass-gemm-tvm-ffi-kernel + - name: Build relu-backprop-compile kernel run: ( cd examples/kernels/relu-backprop-compile && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} ) - name: Copy relu-backprop-compile kernel @@ -91,6 +96,7 @@ jobs: path: | activation-kernel cutlass-gemm-kernel + cutlass-gemm-tvm-ffi-kernel extra-data relu-kernel relu-tvm-ffi-kernel diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/build.toml b/examples/kernels/cutlass-gemm-tvm-ffi/build.toml new file mode 100644 index 00000000..e6d5087d --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/build.toml @@ -0,0 +1,25 @@ +[general] +name = "cutlass-gemm-tvm-ffi" +backends = [ + "cuda", + "xpu", +] + +[tvm-ffi] +src = [ + "tvm-ffi-ext/tvm_ffi_binding.cpp", +] + +[kernel.gemm] +backend = "cuda" +depends = [ + "cutlass_3_6", +] +src = ["gemm.cu", "util.hh"] + +[kernel.gemm_xpu] +backend = "xpu" +depends = [ + "sycl_tla", +] +src = ["gemm_sycl.cpp", "util.hh"] diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/flake.nix b/examples/kernels/cutlass-gemm-tvm-ffi/flake.nix new file mode 100644 index 00000000..c52f2e8a --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for CUTLASS gemm tvm-ffi test kernel"; + + inputs = { + kernel-builder.url = "path:../../.."; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genKernelFlakeOutputs { + inherit self; + path = ./.; + }; +} diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/gemm.cu b/examples/kernels/cutlass-gemm-tvm-ffi/gemm.cu new file mode 100644 index 00000000..0a1abf55 --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/gemm.cu @@ -0,0 +1,58 @@ +#include +#include +#include +#include + +#include "util.hh" + +using namespace tvm; + +void cutlass_gemm(ffi::TensorView out, ffi::TensorView const A, ffi::TensorView const B) { + CHECK_INPUT_CUDA(A); + CHECK_INPUT_CUDA(B); + CHECK_INPUT_CUDA(out); + CHECK_DEVICE(A, out); + CHECK_DEVICE(B, out); + + TVM_FFI_CHECK(A.dtype() == dl_float32, TypeError) << "A must be float32"; + TVM_FFI_CHECK(B.dtype() == dl_float32, TypeError) << "B must be float32"; + TVM_FFI_CHECK(out.dtype() == dl_float32, TypeError) << "out must be float32"; + + TVM_FFI_CHECK(A.ndim() == 2, ValueError) << "A must be 2D"; + TVM_FFI_CHECK(B.ndim() == 2, ValueError) << "B must be 2D"; + TVM_FFI_CHECK(out.ndim() == 2, ValueError) << "out must be 2D"; + + ffi::CUDADeviceGuard guard(A.device().device_id); + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(A.device().device_type, A.device().device_id)); + + // Define the GEMM operation + using Gemm = cutlass::gemm::device::Gemm; + + // Create a GEMM object + Gemm gemm_op; + + // Define the problem size + cutlass::gemm::GemmCoord problem_size(A.size(0), B.size(1), A.size(1)); + + // Define the arguments for the GEMM operation + typename Gemm::Arguments args( + problem_size, + {static_cast(A.data_ptr()), static_cast(A.size(1))}, + {static_cast(B.data_ptr()), static_cast(B.size(1))}, + {static_cast(out.data_ptr()), static_cast(out.size(1))}, + {static_cast(out.data_ptr()), static_cast(out.size(1))}, + {1.0f, 0.0f} + ); + + // Launch the GEMM operation + cutlass::Status status = gemm_op(args, nullptr, stream); + + // Check for errors + if (status != cutlass::Status::kSuccess) { + TVM_FFI_THROW(RuntimeError) << "CUTLASS GEMM operation failed: " + << cutlassGetStatusString(status); + } +} \ No newline at end of file diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/gemm_sycl.cpp b/examples/kernels/cutlass-gemm-tvm-ffi/gemm_sycl.cpp new file mode 100644 index 00000000..a8866b2d --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/gemm_sycl.cpp @@ -0,0 +1,174 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Gemm Example. + + This example constructs and executes a simple CUTLASS GEMM kernel on Intel BMG hardware, and + verifies its correctness with a reference implementation + (cutlass::reference::device::GemmComplex). The example also provides a performance measurement + for the GEMM in TFLOPS. + + This example makes use of BMGs subgroup cooperative 2d-block copy operations and DPAS instructions. + + The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the + batch size is defined by `options.l`. The tile shape, which defines how much work is executed by + a single work-group, is defined at compile time by: + ``` + using TileShape = Shape<_256, _256, _32>; + ``` + That is, each work-group processes a tile of M=256, N=256, and iterates over `options.k` in + blocks of K=32. + + Performance of GEMM on BMG is heavily dependent on prefetching the A and B matrices. That is, + executing Intel specific prefetch instructions for future iterations to ensure that the required + blocks of A and B are resident in cache before they are needed. + + To build & run this example (from your build dir): + + $ ninja 00_bmg_gemm + $ ./examples/sycl/00_bmg_gemm/00_bmg_gemm + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include +using namespace cute; +using namespace tvm; + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + TVM_FFI_THROW(RuntimeError) << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__; \ + } \ + } + +void cutlass_gemm(ffi::TensorView out, ffi::TensorView const A, ffi::TensorView const B) { + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = bfloat16_t; + using ElementInputB = bfloat16_t; + using ElementOutput = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using TileShape = Shape<_256, _256, _32>; + using TiledMma = typename TiledMMAHelper, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, + GmemTiledCopyB, void, void, cute::identity>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // get shape + int M = A.size(0); + int K = A.size(1); + int N = B.size(1); + int L = 1; // batch size + + auto stride_A = cutlass::make_cute_packed_stride(GemmKernel::StrideA{}, cute::make_shape(M, K, L)); + auto stride_B = cutlass::make_cute_packed_stride(GemmKernel::StrideB{}, cute::make_shape(N, K, L)); + auto stride_C = cutlass::make_cute_packed_stride(GemmKernel::StrideC{}, cute::make_shape(M, N, L)); + auto stride_D = cutlass::make_cute_packed_stride(GemmKernel::StrideD{}, cute::make_shape(M, N, L)); + + GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + GemmKernel::ProblemShape{M, N, K, L}, + {reinterpret_cast(A.data_ptr()), stride_A, reinterpret_cast(B.data_ptr()), stride_B}, + {{1.0f, 0.0f}, reinterpret_cast(out.data_ptr()), stride_C, reinterpret_cast(out.data_ptr()), stride_D}, + hw_info + }; + + Gemm gemm_op; + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + TVM_FFI_CHECK(gemm_op.can_implement(arguments) == cutlass::Status::kSuccess, ValueError) << "Invalid GEMM problem size or configuration"; + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_op.run()); +#if defined(OLD_API) + syclcompat::wait(); +#else + compat::wait(); +#endif +} diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/tests/conftest.py b/examples/kernels/cutlass-gemm-tvm-ffi/tests/conftest.py new file mode 100644 index 00000000..6873749d --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/tests/conftest.py @@ -0,0 +1,12 @@ +import pytest +import torch + + +@pytest.fixture(scope="session") +def device() -> torch.device: + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return torch.device("xpu") + elif torch.version.cuda is not None and torch.cuda.is_available(): + return torch.device("cuda") + else: + pytest.skip("Neither CUDA nor XPU device available") diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/tests/test_gemm.py b/examples/kernels/cutlass-gemm-tvm-ffi/tests/test_gemm.py new file mode 100644 index 00000000..20ed571f --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/tests/test_gemm.py @@ -0,0 +1,12 @@ +import torch +from cutlass_gemm_tvm_ffi import cutlass_gemm + + +def test_gemm(device): + A = torch.randn((64, 32), device=device, dtype=torch.float32) + B = torch.randn((32, 64), device=device, dtype=torch.float32) + out = torch.zeros((64, 64), device=device, dtype=torch.float32) + + cutlass_gemm(out, A, B) + + torch.testing.assert_close(out, torch.mm(A, B)) diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/tvm-ffi-ext/cutlass_gemm_tvm_ffi/__init__.py b/examples/kernels/cutlass-gemm-tvm-ffi/tvm-ffi-ext/cutlass_gemm_tvm_ffi/__init__.py new file mode 100644 index 00000000..46d32cd9 --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/tvm-ffi-ext/cutlass_gemm_tvm_ffi/__init__.py @@ -0,0 +1,19 @@ +import tvm_ffi + +from ._ops import ops + + +def cutlass_gemm( + out: tvm_ffi.Tensor, A: tvm_ffi.Tensor, B: tvm_ffi.Tensor +) -> tvm_ffi.Tensor: + device = A.device + if device.type == "cuda": + ops.cutlass_gemm(out, A, B) + elif device.type == "xpu": + ops.cutlass_gemm(out, A, B) + else: + raise NotImplementedError(f"Unsupported device type: {device.type}") + return out + + +__all__ = ["cutlass_gemm"] diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/tvm-ffi-ext/tvm_ffi_binding.cpp b/examples/kernels/cutlass-gemm-tvm-ffi/tvm-ffi-ext/tvm_ffi_binding.cpp new file mode 100644 index 00000000..d859a949 --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/tvm-ffi-ext/tvm_ffi_binding.cpp @@ -0,0 +1,7 @@ +#include + +void cutlass_gemm(tvm::ffi::TensorView out, tvm::ffi::TensorView const A, tvm::ffi::TensorView const B); + +#if defined(CUDA_KERNEL) || defined(XPU_KERNEL) +TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_gemm, cutlass_gemm); +#endif \ No newline at end of file diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/util.hh b/examples/kernels/cutlass-gemm-tvm-ffi/util.hh new file mode 100644 index 00000000..4bd5a477 --- /dev/null +++ b/examples/kernels/cutlass-gemm-tvm-ffi/util.hh @@ -0,0 +1,33 @@ +#include + +#define CHECK_CUDA(x) \ + TVM_FFI_CHECK((x).device().device_type == kDLCUDA, ValueError) << #x " must be a CUDA tensor" +#define CHECK_CONTIGUOUS(x) \ + TVM_FFI_CHECK((x).IsContiguous(), ValueError) << #x " must be contiguous" +#define CHECK_INPUT(x) \ + do { \ + CHECK_CONTIGUOUS(x); \ + } while (0) +#define CHECK_INPUT_CUDA(x) \ + do { \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x); \ + } while (0) +#define CHECK_DEVICE(a, b) \ + do { \ + TVM_FFI_CHECK((a).device().device_type == (b).device().device_type, ValueError) \ + << #a " and " #b " must be on the same device type"; \ + TVM_FFI_CHECK((a).device().device_id == (b).device().device_id, ValueError) \ + << #a " and " #b " must be on the same device"; \ + } while (0) +#define CHECK_XPU(x) \ + TVM_FFI_CHECK((x).device().device_type == kDLOneAPI, ValueError) << #x " must be an XPU tensor" +#define CHECK_INPUT_XPU(x) \ + do { \ + CHECK_XPU(x); \ + CHECK_CONTIGUOUS(x); \ + } while (0) + +constexpr DLDataType dl_float32 = DLDataType{kDLFloat, 32, 1}; + + diff --git a/nix-builder/tests/Dockerfile.test-kernel b/nix-builder/tests/Dockerfile.test-kernel index cb99d4f1..aa46ad1a 100644 --- a/nix-builder/tests/Dockerfile.test-kernel +++ b/nix-builder/tests/Dockerfile.test-kernel @@ -67,11 +67,13 @@ COPY relu-kernel ./relu-kernel COPY relu-tvm-ffi-kernel ./relu-tvm-ffi-kernel COPY relu-kernel-cpu ./relu-kernel-cpu COPY cutlass-gemm-kernel ./cutlass-gemm-kernel +COPY cutlass-gemm-tvm-ffi-kernel ./cutlass-gemm-tvm-ffi-kernel COPY silu-and-mul-kernel ./silu-and-mul-kernel COPY examples/kernels/extra-data/tests ./extra_data_tests COPY examples/kernels/relu/tests ./relu_tests COPY examples/kernels/relu-tvm-ffi/tests ./relu_tvm_ffi_tests COPY examples/kernels/cutlass-gemm/tests ./cutlass_gemm_tests +COPY examples/kernels/cutlass-gemm-tvm-ffi/tests ./cutlass_gemm_tvm_ffi_tests # Run tests ADD nix-builder/tests/run-tests.sh ./run-tests.sh diff --git a/nix-builder/tests/run-tests.sh b/nix-builder/tests/run-tests.sh index 30c36163..b28325b7 100644 --- a/nix-builder/tests/run-tests.sh +++ b/nix-builder/tests/run-tests.sh @@ -5,11 +5,12 @@ EXTRA_DATA_PATH=$(echo extra-data/torch*) RELU_PATH=$(echo relu-kernel/torch*) RELU_TVM_FFI_PATH=$(echo relu-tvm-ffi-kernel/tvm-ffi*) CUTLASS_PATH=$(echo cutlass-gemm-kernel/torch*) +CUTLASS_TVM_FFI_PATH=$(echo cutlass-gemm-tvm-ffi-kernel/tvm-ffi*) SILU_MUL_PATH=$(echo silu-and-mul-kernel/torch*) RELU_CPU_PATH=$(echo relu-kernel-cpu/torch*) -PYTHONPATH="$EXTRA_DATA_PATH:$RELU_PATH:$RELU_TVM_FFI_PATH:$CUTLASS_PATH:$PYTHONPATH" \ - .venv/bin/pytest extra_data_tests relu_tests relu_tvm_ffi_tests cutlass_gemm_tests +PYTHONPATH="$EXTRA_DATA_PATH:$RELU_PATH:$RELU_TVM_FFI_PATH:$CUTLASS_PATH:$CUTLASS_TVM_FFI_PATH:$PYTHONPATH" \ + .venv/bin/pytest extra_data_tests relu_tests relu_tvm_ffi_tests cutlass_gemm_tests cutlass_gemm_tvm_ffi_tests # We only care about importing, the kernel is trivial. PYTHONPATH="$SILU_MUL_PATH:$PYTHONPATH" \