diff --git a/.github/workflows/build_kernel_macos.yaml b/.github/workflows/build_kernel_macos.yaml index ceb966dd..2c6b1733 100644 --- a/.github/workflows/build_kernel_macos.yaml +++ b/.github/workflows/build_kernel_macos.yaml @@ -26,3 +26,6 @@ jobs: # kernels. Also run tests once we have a macOS runner. - name: Build relu kernel run: ( cd examples/relu && nix build .\#redistributable.torch29-metal-aarch64-darwin -L ) + + - name: Build relu metal cpp kernel + run: ( cd examples/relu-metal-cpp && nix build .\#redistributable.torch29-metal-aarch64-darwin -L ) \ No newline at end of file diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index ecbdd9ec..0f8457e9 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -247,6 +247,8 @@ pub enum Dependencies { Cutlass4_0, #[serde(rename = "cutlass_sycl")] CutlassSycl, + #[serde(rename = "metal-cpp")] + MetalCpp, Torch, } diff --git a/examples/relu-metal-cpp/build.toml b/examples/relu-metal-cpp/build.toml new file mode 100644 index 00000000..e0d6d487 --- /dev/null +++ b/examples/relu-metal-cpp/build.toml @@ -0,0 +1,20 @@ +[general] +name = "relu" +universal = false + +[torch] +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", +] + + +[kernel.relu_metal] +backend = "metal" +src = [ + "relu/relu.cpp", + "relu/metallib_loader.mm", + "relu/relu_cpp.metal", + "relu/common.h", +] +depends = [ "torch", "metal-cpp" ] \ No newline at end of file diff --git a/examples/relu-metal-cpp/flake.nix b/examples/relu-metal-cpp/flake.nix new file mode 100644 index 00000000..758fb200 --- /dev/null +++ b/examples/relu-metal-cpp/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for ReLU metal cpp kernel"; + + inputs = { + kernel-builder.url = "path:../.."; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + inherit self; + path = ./.; + }; +} diff --git a/examples/relu-metal-cpp/relu/common.h b/examples/relu-metal-cpp/relu/common.h 
new file mode 100644 index 00000000..a981a521 --- /dev/null +++ b/examples/relu-metal-cpp/relu/common.h @@ -0,0 +1,7 @@ +#pragma once + +#include &lt;metal_stdlib&gt; +using namespace metal; + +// Common constants and utilities for Metal kernels +constant float RELU_THRESHOLD = 0.0f; diff --git a/examples/relu-metal-cpp/relu/metallib_loader.mm b/examples/relu-metal-cpp/relu/metallib_loader.mm new file mode 100644 index 00000000..9e63d909 --- /dev/null +++ b/examples/relu-metal-cpp/relu/metallib_loader.mm @@ -0,0 +1,40 @@ +#import &lt;Metal/Metal.h&gt; +#include &lt;ATen/mps/MPSDevice.h&gt; +#include &lt;ATen/mps/MPSStream.h&gt; + +#ifdef EMBEDDED_METALLIB_HEADER +#include EMBEDDED_METALLIB_HEADER +#else +#error "EMBEDDED_METALLIB_HEADER not defined" +#endif + +// C++ interface to load the embedded metallib without exposing ObjC types +extern "C" { + void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg) { + id&lt;MTLDevice&gt; mtlDevice = (__bridge id&lt;MTLDevice&gt;)device; + NSError* error = nil; + + id&lt;MTLLibrary&gt; library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(mtlDevice, &error); + + if (!library && errorMsg && error) { + *errorMsg = strdup([error.localizedDescription UTF8String]); + } + + // Manually retain since we're not using ARC + // The caller will wrap in NS::TransferPtr which assumes ownership + if (library) { + [library retain]; + } + return (__bridge void*)library; + } + + // Get PyTorch's MPS device (returns id&lt;MTLDevice&gt; as void*) + void* getMPSDevice() { + return (__bridge void*)at::mps::MPSDevice::getInstance()->device(); + } + + // Get PyTorch's current MPS command queue (returns id&lt;MTLCommandQueue&gt; as void*) + void* getMPSCommandQueue() { + return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue(); + } +} diff --git a/examples/relu-metal-cpp/relu/relu.cpp b/examples/relu-metal-cpp/relu/relu.cpp new file mode 100644 index 00000000..85b9ac57 --- /dev/null +++ b/examples/relu-metal-cpp/relu/relu.cpp @@ -0,0 +1,118 @@ +#define NS_PRIVATE_IMPLEMENTATION +#define MTL_PRIVATE_IMPLEMENTATION + +// Include metal-cpp headers from system +#include &lt;Foundation/Foundation.hpp&gt; +#include &lt;Metal/Metal.hpp&gt; + +#include &lt;torch/torch.h&gt; + +// C interface from
metallib_loader.mm +extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg); +extern "C" void* getMPSDevice(); +extern "C" void* getMPSCommandQueue(); + +namespace { + +MTL::Buffer* getMTLBuffer(const torch::Tensor& tensor) { + return reinterpret_cast&lt;MTL::Buffer*&gt;(const_cast&lt;void*&gt;(tensor.storage().data())); +} + +NS::String* makeNSString(const std::string& value) { + return NS::String::string(value.c_str(), NS::StringEncoding::UTF8StringEncoding); +} + +MTL::Library* loadLibrary(MTL::Device* device) { + const char* errorMsg = nullptr; + void* library = loadEmbeddedMetalLibrary(reinterpret_cast&lt;void*&gt;(device), &errorMsg); + + TORCH_CHECK(library != nullptr, "Failed to create Metal library from embedded data: ", + errorMsg ? errorMsg : "Unknown error"); + + if (errorMsg) { + free(const_cast&lt;char*&gt;(errorMsg)); + } + + return reinterpret_cast&lt;MTL::Library*&gt;(library); +} + +} // namespace + +void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) { + // Use PyTorch's MPS device and command queue (these are borrowed references, not owned) + MTL::Device* device = reinterpret_cast&lt;MTL::Device*&gt;(getMPSDevice()); + TORCH_CHECK(device != nullptr, "Failed to get MPS device"); + + MTL::CommandQueue* commandQueue = reinterpret_cast&lt;MTL::CommandQueue*&gt;(getMPSCommandQueue()); + TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue"); + + MTL::Library* libraryPtr = reinterpret_cast&lt;MTL::Library*&gt;(loadLibrary(device)); + NS::SharedPtr&lt;MTL::Library&gt; library = NS::TransferPtr(libraryPtr); + + const std::string kernelName = + std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ?
"float" : "half"); + NS::SharedPtr&lt;NS::String&gt; kernelNameString = NS::TransferPtr(makeNSString(kernelName)); + + NS::SharedPtr&lt;MTL::Function&gt; computeFunction = + NS::TransferPtr(library->newFunction(kernelNameString.get())); + TORCH_CHECK(computeFunction.get() != nullptr, "Failed to create Metal function for ", kernelName); + + NS::Error* pipelineError = nullptr; + NS::SharedPtr&lt;MTL::ComputePipelineState&gt; pipelineState = + NS::TransferPtr(device->newComputePipelineState(computeFunction.get(), &pipelineError)); + TORCH_CHECK(pipelineState.get() != nullptr, + "Failed to create compute pipeline state: ", + pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error"); + + // Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue + MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer(); + TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer"); + + MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder(); + TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder"); + + encoder->setComputePipelineState(pipelineState.get()); + + auto* inputBuffer = getMTLBuffer(input); + auto* outputBuffer = getMTLBuffer(output); + TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null"); + TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null"); + + encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0); + encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1); + + const NS::UInteger totalThreads = input.numel(); + NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup(); + if (threadGroupSize > totalThreads) { + threadGroupSize = totalThreads; + } + + const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1); + const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1); + + encoder->dispatchThreads(gridSize, threadsPerThreadgroup); + encoder->endEncoding(); + + commandBuffer->commit(); +} + +void
relu(torch::Tensor& out, const torch::Tensor& input) { + TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(input.scalar_type() == torch::kFloat || input.scalar_type() == torch::kHalf, + "Unsupported data type: ", input.scalar_type()); + + TORCH_CHECK(input.sizes() == out.sizes(), + "Tensors must have the same shape. Got input shape: ", + input.sizes(), " and output shape: ", out.sizes()); + + TORCH_CHECK(input.scalar_type() == out.scalar_type(), + "Tensors must have the same data type. Got input dtype: ", + input.scalar_type(), " and output dtype: ", out.scalar_type()); + + TORCH_CHECK(input.device() == out.device(), + "Tensors must be on the same device. Got input device: ", + input.device(), " and output device: ", out.device()); + + dispatchReluKernel(input, out); +} diff --git a/examples/relu-metal-cpp/relu/relu_cpp.metal b/examples/relu-metal-cpp/relu/relu_cpp.metal new file mode 100644 index 00000000..969ec170 --- /dev/null +++ b/examples/relu-metal-cpp/relu/relu_cpp.metal @@ -0,0 +1,17 @@ +#include &lt;metal_stdlib&gt; +#include "common.h" +using namespace metal; + +kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]], + device float *outC [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + // Explicitly write to output + outC[index] = max(RELU_THRESHOLD, inA[index]); +} + +kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]], + device half *outC [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + // Explicitly write to output + outC[index] = max(static_cast&lt;half&gt;(0.0), inA[index]); +} diff --git a/examples/relu-metal-cpp/tests/__init__.py b/examples/relu-metal-cpp/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/relu-metal-cpp/tests/test_relu.py b/examples/relu-metal-cpp/tests/test_relu.py new file mode 100644 index 00000000..65544aa4 --- /dev/null +++
b/examples/relu-metal-cpp/tests/test_relu.py @@ -0,0 +1,19 @@ +import platform + +import torch +import torch.nn.functional as F + +import relu + + +def test_relu(): + if platform.system() == "Darwin": + device = torch.device("mps") + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + elif torch.version.cuda is not None and torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + x = torch.randn(1024, 1024, dtype=torch.float32, device=device) + torch.testing.assert_allclose(F.relu(x), relu.relu(x)) diff --git a/examples/relu-metal-cpp/torch-ext/relu/__init__.py b/examples/relu-metal-cpp/torch-ext/relu/__init__.py new file mode 100644 index 00000000..8050dfd7 --- /dev/null +++ b/examples/relu-metal-cpp/torch-ext/relu/__init__.py @@ -0,0 +1,12 @@ +from typing import Optional + +import torch + +from ._ops import ops + + +def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + if out is None: + out = torch.empty_like(x) + ops.relu(out, x) + return out \ No newline at end of file diff --git a/examples/relu-metal-cpp/torch-ext/torch_binding.cpp b/examples/relu-metal-cpp/torch-ext/torch_binding.cpp new file mode 100644 index 00000000..1765d92d --- /dev/null +++ b/examples/relu-metal-cpp/torch-ext/torch_binding.cpp @@ -0,0 +1,19 @@ +#include &lt;torch/library.h&gt; + +#include "registration.h" +#include "torch_binding.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("relu(Tensor!
out, Tensor input) -> ()"); +#if defined(CPU_KERNEL) + ops.impl("relu", torch::kCPU, &relu); +#elif defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("relu", torch::kCUDA, &relu); +#elif defined(METAL_KERNEL) + ops.impl("relu", torch::kMPS, relu); +#elif defined(XPU_KERNEL) + ops.impl("relu", torch::kXPU, &relu); +#endif +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/examples/relu-metal-cpp/torch-ext/torch_binding.h b/examples/relu-metal-cpp/torch-ext/torch_binding.h new file mode 100644 index 00000000..3bcf2904 --- /dev/null +++ b/examples/relu-metal-cpp/torch-ext/torch_binding.h @@ -0,0 +1,5 @@ +#pragma once + +#include &lt;torch/torch.h&gt; + +void relu(torch::Tensor &out, torch::Tensor const &input); \ No newline at end of file diff --git a/flake.lock b/flake.lock index 8ab650e2..006f1995 100644 --- a/flake.lock +++ b/flake.lock @@ -17,11 +17,11 @@ }, "flake-compat_2": { "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", "owner": "edolstra", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", "type": "github" }, "original": { @@ -73,11 +73,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1762504832, - "narHash": "sha256-PIxh2ZFqq3CAkQNtupT0AfxA1n3raM/3enDLHn4a21k=", + "lastModified": 1762833384, + "narHash": "sha256-YPLcoqJBlYYGgmTG/J7bDx7lHqPnJ8beLs0+9OLfFhM=", "owner": "huggingface", "repo": "hf-nix", - "rev": "3267e738faafa71bed7de9b75d74f6a90ec1bc57", + "rev": "752645bcda8793906249809319fa9b8dc11d7af6", "type": "github" }, "original": { @@ -88,11 +88,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1762168314, - "narHash": "sha256-+DX6mIF47gRGoK0mqkTg1Jmcjcup0CAXJFHVkdUx8YA=", + "lastModified": 1762764791, + "narHash": "sha256-mWl8rYSYDFWD+zCR0VkBjEjD9jYj1/nlkDOfNNu44NA=", "owner": "nixos", "repo": "nixpkgs", -
"rev": "94fc102d2c15d9c1a861e59de550807c65358e1b", + "rev": "b549734f6b3ec54bb9a611a4185d11ee31f52ee1", "type": "github" }, "original": { diff --git a/lib/deps.nix b/lib/deps.nix index 9e8c6c81..0da60528 100644 --- a/lib/deps.nix +++ b/lib/deps.nix @@ -33,6 +33,9 @@ let #torch.cxxdev ]; "cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ]; + "metal-cpp" = [ + pkgs.metal-cpp.dev + ]; }; in let