Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/build_kernel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ jobs:
- name: Copy cutlass GEMM kernel
run: cp -rL examples/kernels/cutlass-gemm/result cutlass-gemm-kernel

- name: Build cutlass-gemm-tvm-ffi kernel
run: ( cd examples/kernels/cutlass-gemm-tvm-ffi && nix build .\#redistributable.tvm-ffi01-cu126-${{ matrix.arch }} )
- name: Copy cutlass-gemm-tvm-ffi kernel
run: cp -rL examples/kernels/cutlass-gemm-tvm-ffi/result cutlass-gemm-tvm-ffi-kernel

- name: Build relu-backprop-compile kernel
run: ( cd examples/kernels/relu-backprop-compile && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
- name: Copy relu-backprop-compile kernel
Expand Down Expand Up @@ -91,6 +96,7 @@ jobs:
path: |
activation-kernel
cutlass-gemm-kernel
cutlass-gemm-tvm-ffi-kernel
extra-data
relu-kernel
relu-tvm-ffi-kernel
Expand Down
25 changes: 25 additions & 0 deletions examples/kernels/cutlass-gemm-tvm-ffi/build.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[general]
name = "cutlass-gemm-tvm-ffi"
backends = [
"cuda",
"xpu",
]

[tvm-ffi]
src = [
"tvm-ffi-ext/tvm_ffi_binding.cpp",
]

[kernel.gemm]
backend = "cuda"
depends = [
"cutlass_3_6",
]
src = ["gemm.cu", "util.hh"]

[kernel.gemm_xpu]
backend = "xpu"
depends = [
"sycl_tla",
]
src = ["gemm_sycl.cpp", "util.hh"]
17 changes: 17 additions & 0 deletions examples/kernels/cutlass-gemm-tvm-ffi/flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
description = "Flake for CUTLASS gemm tvm-ffi test kernel";

inputs = {
kernel-builder.url = "path:../../..";
};

outputs =
{
self,
kernel-builder,
}:
kernel-builder.lib.genKernelFlakeOutputs {
inherit self;
path = ./.;
};
}
58 changes: 58 additions & 0 deletions examples/kernels/cutlass-gemm-tvm-ffi/gemm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <cutlass/gemm/device/gemm.h>
#include <tvm/ffi/tvm_ffi.h>
#include <tvm/ffi/extra/cuda/device_guard.h>
#include <tvm/ffi/extra/c_env_api.h>

#include "util.hh"

using namespace tvm;

void cutlass_gemm(ffi::TensorView out, ffi::TensorView const A, ffi::TensorView const B) {
CHECK_INPUT_CUDA(A);
CHECK_INPUT_CUDA(B);
CHECK_INPUT_CUDA(out);
CHECK_DEVICE(A, out);
CHECK_DEVICE(B, out);

TVM_FFI_CHECK(A.dtype() == dl_float32, TypeError) << "A must be float32";
TVM_FFI_CHECK(B.dtype() == dl_float32, TypeError) << "B must be float32";
TVM_FFI_CHECK(out.dtype() == dl_float32, TypeError) << "out must be float32";

TVM_FFI_CHECK(A.ndim() == 2, ValueError) << "A must be 2D";
TVM_FFI_CHECK(B.ndim() == 2, ValueError) << "B must be 2D";
TVM_FFI_CHECK(out.ndim() == 2, ValueError) << "out must be 2D";

ffi::CUDADeviceGuard guard(A.device().device_id);
cudaStream_t stream = static_cast<cudaStream_t>(
TVMFFIEnvGetStream(A.device().device_type, A.device().device_id));

// Define the GEMM operation
using Gemm = cutlass::gemm::device::Gemm<float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;

// Create a GEMM object
Gemm gemm_op;

// Define the problem size
cutlass::gemm::GemmCoord problem_size(A.size(0), B.size(1), A.size(1));

// Define the arguments for the GEMM operation
typename Gemm::Arguments args(
problem_size,
{static_cast<float*>(A.data_ptr()), static_cast<int>(A.size(1))},
{static_cast<float*>(B.data_ptr()), static_cast<int>(B.size(1))},
{static_cast<float*>(out.data_ptr()), static_cast<int>(out.size(1))},
{static_cast<float*>(out.data_ptr()), static_cast<int>(out.size(1))},
{1.0f, 0.0f}
);

// Launch the GEMM operation
cutlass::Status status = gemm_op(args, nullptr, stream);

// Check for errors
if (status != cutlass::Status::kSuccess) {
TVM_FFI_THROW(RuntimeError) << "CUTLASS GEMM operation failed: "
<< cutlassGetStatusString(status);
}
}
174 changes: 174 additions & 0 deletions examples/kernels/cutlass-gemm-tvm-ffi/gemm_sycl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief CUTLASS Intel BMG Gemm Example.

This example constructs and executes a simple CUTLASS GEMM kernel on Intel BMG hardware, and
verifies its correctness with a reference implementation
(cutlass::reference::device::GemmComplex). The example also provides a performance measurement
for the GEMM in TFLOPS.

This example makes use of BMGs subgroup cooperative 2d-block copy operations and DPAS instructions.

The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the
batch size is defined by `options.l`. The tile shape, which defines how much work is executed by
a single work-group, is defined at compile time by:
```
using TileShape = Shape<_256, _256, _32>;
```
That is, each work-group processes a tile of M=256, N=256, and iterates over `options.k` in
blocks of K=32.

Performance of GEMM on BMG is heavily dependent on prefetching the A and B matrices. That is,
executing Intel specific prefetch instructions for future iterations to ensure that the required
blocks of A and B are resident in cache before they are needed.

To build & run this example (from your build dir):

$ ninja 00_bmg_gemm
$ ./examples/sycl/00_bmg_gemm/00_bmg_gemm

Call with `--help` for information about available options
*/

#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/collective/xe_epilogue.hpp"
#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/collective/collective_mma.hpp"
#include "cutlass/util/GPU_Clock.hpp"

#include <cute/tensor.hpp>
#include <random>

#include "cutlass/util/command_line.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include <tvm/ffi/tvm_ffi.h>
using namespace cute;
using namespace tvm;

#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
TVM_FFI_THROW(RuntimeError) << "Got cutlass error: " << cutlassGetStatusString(error) \
<< " at: " << __LINE__; \
} \
}

void cutlass_gemm(ffi::TensorView out, ffi::TensorView const A, ffi::TensorView const B) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = bfloat16_t;
using ElementInputB = bfloat16_t;
using ElementOutput = float;

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
using TileShape = Shape<_256, _256, _32>;
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue, ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape, decltype(tile_shape(TiledMma()))>;
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
EpilogueDispatchPolicy,
TileShape,
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutC>,
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>,
FusionCallBacks,
XE_2D_U32x8x16_LD_N,
void, void,
XE_2D_U32x8x16_ST_N,
void, void>;
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
TileShape,
ElementInputA,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementInputB,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA, void, void, cute::identity,
GmemTiledCopyB, void, void, cute::identity>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

cutlass::KernelHardwareInfo hw_info;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

// get shape
int M = A.size(0);
int K = A.size(1);
int N = B.size(1);
int L = 1; // batch size

auto stride_A = cutlass::make_cute_packed_stride(GemmKernel::StrideA{}, cute::make_shape(M, K, L));
auto stride_B = cutlass::make_cute_packed_stride(GemmKernel::StrideB{}, cute::make_shape(N, K, L));
auto stride_C = cutlass::make_cute_packed_stride(GemmKernel::StrideC{}, cute::make_shape(M, N, L));
auto stride_D = cutlass::make_cute_packed_stride(GemmKernel::StrideD{}, cute::make_shape(M, N, L));

GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
GemmKernel::ProblemShape{M, N, K, L},
{reinterpret_cast<ElementInputA*>(A.data_ptr()), stride_A, reinterpret_cast<ElementInputB*>(B.data_ptr()), stride_B},
{{1.0f, 0.0f}, reinterpret_cast<ElementOutput*>(out.data_ptr()), stride_C, reinterpret_cast<ElementOutput*>(out.data_ptr()), stride_D},
hw_info
};

Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

TVM_FFI_CHECK(gemm_op.can_implement(arguments) == cutlass::Status::kSuccess, ValueError) << "Invalid GEMM problem size or configuration";
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_op.run());
#if defined(OLD_API)
syclcompat::wait();
#else
compat::wait();
#endif
}
12 changes: 12 additions & 0 deletions examples/kernels/cutlass-gemm-tvm-ffi/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest
import torch


@pytest.fixture(scope="session")
def device() -> torch.device:
if hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.device("xpu")
elif torch.version.cuda is not None and torch.cuda.is_available():
return torch.device("cuda")
else:
pytest.skip("Neither CUDA nor XPU device available")
12 changes: 12 additions & 0 deletions examples/kernels/cutlass-gemm-tvm-ffi/tests/test_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
from cutlass_gemm_tvm_ffi import cutlass_gemm


def test_gemm(device):
A = torch.randn((64, 32), device=device, dtype=torch.float32)
B = torch.randn((32, 64), device=device, dtype=torch.float32)
out = torch.zeros((64, 64), device=device, dtype=torch.float32)

cutlass_gemm(out, A, B)

torch.testing.assert_close(out, torch.mm(A, B))
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import tvm_ffi

from ._ops import ops


def cutlass_gemm(
out: tvm_ffi.Tensor, A: tvm_ffi.Tensor, B: tvm_ffi.Tensor
) -> tvm_ffi.Tensor:
device = A.device
if device.type == "cuda":
ops.cutlass_gemm(out, A, B)
elif device.type == "xpu":
ops.cutlass_gemm(out, A, B)
else:
raise NotImplementedError(f"Unsupported device type: {device.type}")
return out


__all__ = ["cutlass_gemm"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include <tvm/ffi/tvm_ffi.h>

void cutlass_gemm(tvm::ffi::TensorView out, tvm::ffi::TensorView const A, tvm::ffi::TensorView const B);

#if defined(CUDA_KERNEL) || defined(XPU_KERNEL)
TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_gemm, cutlass_gemm);
#endif
33 changes: 33 additions & 0 deletions examples/kernels/cutlass-gemm-tvm-ffi/util.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <tvm/ffi/tvm_ffi.h>

#define CHECK_CUDA(x) \
TVM_FFI_CHECK((x).device().device_type == kDLCUDA, ValueError) << #x " must be a CUDA tensor"
#define CHECK_CONTIGUOUS(x) \
TVM_FFI_CHECK((x).IsContiguous(), ValueError) << #x " must be contiguous"
#define CHECK_INPUT(x) \
do { \
CHECK_CONTIGUOUS(x); \
} while (0)
#define CHECK_INPUT_CUDA(x) \
do { \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x); \
} while (0)
#define CHECK_DEVICE(a, b) \
do { \
TVM_FFI_CHECK((a).device().device_type == (b).device().device_type, ValueError) \
<< #a " and " #b " must be on the same device type"; \
TVM_FFI_CHECK((a).device().device_id == (b).device().device_id, ValueError) \
<< #a " and " #b " must be on the same device"; \
} while (0)
#define CHECK_XPU(x) \
TVM_FFI_CHECK((x).device().device_type == kDLOneAPI, ValueError) << #x " must be an XPU tensor"
#define CHECK_INPUT_XPU(x) \
do { \
CHECK_XPU(x); \
CHECK_CONTIGUOUS(x); \
} while (0)

constexpr DLDataType dl_float32 = DLDataType{kDLFloat, 32, 1};


Loading
Loading