From a844fa0a98e0f442c9edaee979e775cc0363dfbc Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 10 Nov 2025 10:23:43 -0500 Subject: [PATCH 01/10] feat: support metal cpp --- .github/workflows/build_kernel_macos.yaml | 3 + build2cmake/src/config/v2.rs | 2 + examples/relu-metal-cpp/build.toml | 20 +++ examples/relu-metal-cpp/common.h | 10 ++ examples/relu-metal-cpp/flake.lock | 164 ++++++++++++++++++ examples/relu-metal-cpp/flake.nix | 17 ++ examples/relu-metal-cpp/metallib_loader.mm | 41 +++++ examples/relu-metal-cpp/relu.cpp | 119 +++++++++++++ examples/relu-metal-cpp/relu_cpp.metal | 17 ++ examples/relu-metal-cpp/result | 1 + .../relu-metal-cpp/torch-ext/relu/__init__.py | 12 ++ .../torch-ext/torch_binding.cpp | 19 ++ .../relu-metal-cpp/torch-ext/torch_binding.h | 5 + examples/relu/flake.lock | 164 ++++++++++++++++++ examples/relu/result | 1 + lib/deps.nix | 3 + lib/torch-extension/arch.nix | 1 + 17 files changed, 599 insertions(+) create mode 100644 examples/relu-metal-cpp/build.toml create mode 100644 examples/relu-metal-cpp/common.h create mode 100644 examples/relu-metal-cpp/flake.lock create mode 100644 examples/relu-metal-cpp/flake.nix create mode 100644 examples/relu-metal-cpp/metallib_loader.mm create mode 100644 examples/relu-metal-cpp/relu.cpp create mode 100644 examples/relu-metal-cpp/relu_cpp.metal create mode 120000 examples/relu-metal-cpp/result create mode 100644 examples/relu-metal-cpp/torch-ext/relu/__init__.py create mode 100644 examples/relu-metal-cpp/torch-ext/torch_binding.cpp create mode 100644 examples/relu-metal-cpp/torch-ext/torch_binding.h create mode 100644 examples/relu/flake.lock create mode 120000 examples/relu/result diff --git a/.github/workflows/build_kernel_macos.yaml b/.github/workflows/build_kernel_macos.yaml index ceb966dd..2c6b1733 100644 --- a/.github/workflows/build_kernel_macos.yaml +++ b/.github/workflows/build_kernel_macos.yaml @@ -26,3 +26,6 @@ jobs: # kernels. Also run tests once we have a macOS runner. - name: Build relu kernel run: ( cd examples/relu && nix build .\#redistributable.torch29-metal-aarch64-darwin -L ) + + - name: Build relu metal cpp kernel + run: ( cd examples/relu-metal-cpp && nix build .\#redistributable.torch29-metal-aarch64-darwin -L ) \ No newline at end of file diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index ecbdd9ec..0f8457e9 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -247,6 +247,8 @@ pub enum Dependencies { Cutlass4_0, #[serde(rename = "cutlass_sycl")] CutlassSycl, + #[serde(rename = "metal-cpp")] + MetalCpp, Torch, } diff --git a/examples/relu-metal-cpp/build.toml b/examples/relu-metal-cpp/build.toml new file mode 100644 index 00000000..8fb84793 --- /dev/null +++ b/examples/relu-metal-cpp/build.toml @@ -0,0 +1,20 @@ +[general] +name = "relu" +universal = false + +[torch] +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", +] + + +[kernel.relu_metal] +backend = "metal" +src = [ + "relu.cpp", + "metallib_loader.mm", + "relu_cpp.metal", + "common.h", +] +depends = [ "torch", "metal-cpp" ] \ No newline at end of file diff --git a/examples/relu-metal-cpp/common.h b/examples/relu-metal-cpp/common.h new file mode 100644 index 00000000..1b891fad --- /dev/null +++ b/examples/relu-metal-cpp/common.h @@ -0,0 +1,10 @@ +#ifndef COMMON_H +#define COMMON_H + +#include +using namespace metal; + +// Common constants and utilities for Metal kernels +constant float RELU_THRESHOLD = 0.0f; + +#endif // COMMON_H \ No newline at end of file diff --git a/examples/relu-metal-cpp/flake.lock b/examples/relu-metal-cpp/flake.lock new file mode 100644 index 00000000..0182bc37 --- /dev/null +++ b/examples/relu-metal-cpp/flake.lock @@ -0,0 +1,164 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1762504832, + "narHash": "sha256-PIxh2ZFqq3CAkQNtupT0AfxA1n3raM/3enDLHn4a21k=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "3267e738faafa71bed7de9b75d74f6a90ec1bc57", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "path": "../..", + "type": "path" + }, + "original": { + "path": "../..", + "type": "path" + }, + "parent": [] + }, + "nixpkgs": { + "locked": { + "lastModified": 1762168314, + "narHash": "sha256-+DX6mIF47gRGoK0mqkTg1Jmcjcup0CAXJFHVkdUx8YA=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "94fc102d2c15d9c1a861e59de550807c65358e1b", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/examples/relu-metal-cpp/flake.nix b/examples/relu-metal-cpp/flake.nix new file mode 100644 index 00000000..758fb200 --- /dev/null +++ b/examples/relu-metal-cpp/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for ReLU metal cpp kernel"; + + inputs = { + kernel-builder.url = "path:../.."; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + inherit self; + path = ./.; + }; +} diff --git a/examples/relu-metal-cpp/metallib_loader.mm b/examples/relu-metal-cpp/metallib_loader.mm new file mode 100644 index 00000000..050ee791 --- /dev/null +++ b/examples/relu-metal-cpp/metallib_loader.mm @@ -0,0 +1,41 @@ +#import +#import +#include +#include + +#ifdef EMBEDDED_METALLIB_HEADER +#include EMBEDDED_METALLIB_HEADER +#else +#error "EMBEDDED_METALLIB_HEADER not defined" +#endif + +// C++ interface to load the embedded metallib without exposing ObjC types +extern "C" { + void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg) { + id mtlDevice = (__bridge id)device; + NSError* error = nil; + + id library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(mtlDevice, &error); + + if (!library && errorMsg && error) { + *errorMsg = strdup([error.localizedDescription UTF8String]); + } + + // Manually retain since we're not using ARC + // The caller will wrap in NS::TransferPtr which assumes ownership + if (library) { + [library retain]; + } + return (__bridge void*)library; + } + + // Get PyTorch's MPS device (returns id as void*) + void* getMPSDevice() { + return (__bridge void*)at::mps::MPSDevice::getInstance()->device(); + } + + // Get PyTorch's current MPS command queue (returns id as void*) + void* getMPSCommandQueue() { + return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue(); + } +} diff --git a/examples/relu-metal-cpp/relu.cpp b/examples/relu-metal-cpp/relu.cpp new file mode 100644 index 00000000..c07ad544 --- /dev/null +++ b/examples/relu-metal-cpp/relu.cpp @@ -0,0 +1,119 @@ +#define NS_PRIVATE_IMPLEMENTATION +#define MTL_PRIVATE_IMPLEMENTATION + +// Include metal-cpp headers from system +#include +#include +#include + +#include + +// C interface from metallib_loader.mm +extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg); +extern "C" void* getMPSDevice(); +extern "C" void* getMPSCommandQueue(); + +namespace { + +MTL::Buffer* getMTLBuffer(const torch::Tensor& tensor) { + return reinterpret_cast(const_cast(tensor.storage().data())); +} + +NS::String* makeNSString(const std::string& value) { + return NS::String::string(value.c_str(), NS::StringEncoding::UTF8StringEncoding); +} + +MTL::Library* loadLibrary(MTL::Device* device) { + const char* errorMsg = nullptr; + void* library = loadEmbeddedMetalLibrary(reinterpret_cast(device), &errorMsg); + + TORCH_CHECK(library != nullptr, "Failed to create Metal library from embedded data: ", + errorMsg ? errorMsg : "Unknown error"); + + if (errorMsg) { + free(const_cast(errorMsg)); + } + + return reinterpret_cast(library); +} + +} // namespace + +void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) { + // Use PyTorch's MPS device and command queue (these are borrowed references, not owned) + MTL::Device* device = reinterpret_cast(getMPSDevice()); + TORCH_CHECK(device != nullptr, "Failed to get MPS device"); + + MTL::CommandQueue* commandQueue = reinterpret_cast(getMPSCommandQueue()); + TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue"); + + MTL::Library* libraryPtr = reinterpret_cast(loadLibrary(device)); + NS::SharedPtr library = NS::TransferPtr(libraryPtr); + + const std::string kernelName = + std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half"); + NS::SharedPtr kernelNameString = NS::TransferPtr(makeNSString(kernelName)); + + NS::SharedPtr computeFunction = + NS::TransferPtr(library->newFunction(kernelNameString.get())); + TORCH_CHECK(computeFunction.get() != nullptr, "Failed to create Metal function for ", kernelName); + + NS::Error* pipelineError = nullptr; + NS::SharedPtr pipelineState = + NS::TransferPtr(device->newComputePipelineState(computeFunction.get(), &pipelineError)); + TORCH_CHECK(pipelineState.get() != nullptr, + "Failed to create compute pipeline state: ", + pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error"); + + // Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue + MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer(); + TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer"); + + MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder(); + TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder"); + + encoder->setComputePipelineState(pipelineState.get()); + + auto* inputBuffer = getMTLBuffer(input); + auto* outputBuffer = getMTLBuffer(output); + TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null"); + TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null"); + + encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0); + encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1); + + const NS::UInteger totalThreads = input.numel(); + NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup(); + if (threadGroupSize > totalThreads) { + threadGroupSize = totalThreads; + } + + const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1); + const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1); + + encoder->dispatchThreads(gridSize, threadsPerThreadgroup); + encoder->endEncoding(); + + commandBuffer->commit(); +} + +void relu(torch::Tensor& out, const torch::Tensor& input) { + TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(input.scalar_type() == torch::kFloat || input.scalar_type() == torch::kHalf, + "Unsupported data type: ", input.scalar_type()); + + TORCH_CHECK(input.sizes() == out.sizes(), + "Tensors must have the same shape. Got input shape: ", + input.sizes(), " and output shape: ", out.sizes()); + + TORCH_CHECK(input.scalar_type() == out.scalar_type(), + "Tensors must have the same data type. Got input dtype: ", + input.scalar_type(), " and output dtype: ", out.scalar_type()); + + TORCH_CHECK(input.device() == out.device(), + "Tensors must be on the same device. Got input device: ", + input.device(), " and output device: ", out.device()); + + dispatchReluKernel(input, out); +} diff --git a/examples/relu-metal-cpp/relu_cpp.metal b/examples/relu-metal-cpp/relu_cpp.metal new file mode 100644 index 00000000..969ec170 --- /dev/null +++ b/examples/relu-metal-cpp/relu_cpp.metal @@ -0,0 +1,17 @@ +#include +#include "common.h" +using namespace metal; + +kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]], + device float *outC [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + // Explicitly write to output + outC[index] = max(RELU_THRESHOLD, inA[index]); +} + +kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]], + device half *outC [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + // Explicitly write to output + outC[index] = max(static_cast(0.0), inA[index]); +} diff --git a/examples/relu-metal-cpp/result b/examples/relu-metal-cpp/result new file mode 120000 index 00000000..d8bc7df4 --- /dev/null +++ b/examples/relu-metal-cpp/result @@ -0,0 +1 @@ +/nix/store/khawpdnqpl0c14h14gg2g2abvxqclxp0-torch-ext-bundle \ No newline at end of file diff --git a/examples/relu-metal-cpp/torch-ext/relu/__init__.py b/examples/relu-metal-cpp/torch-ext/relu/__init__.py new file mode 100644 index 00000000..8050dfd7 --- /dev/null +++ b/examples/relu-metal-cpp/torch-ext/relu/__init__.py @@ -0,0 +1,12 @@ +from typing import Optional + +import torch + +from ._ops import ops + + +def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + if out is None: + out = torch.empty_like(x) + ops.relu(out, x) + return out \ No newline at end of file diff --git a/examples/relu-metal-cpp/torch-ext/torch_binding.cpp b/examples/relu-metal-cpp/torch-ext/torch_binding.cpp new file mode 100644 index 00000000..1765d92d --- /dev/null +++ b/examples/relu-metal-cpp/torch-ext/torch_binding.cpp @@ -0,0 +1,19 @@ +#include + +#include "registration.h" +#include "torch_binding.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("relu(Tensor! out, Tensor input) -> ()"); +#if defined(CPU_KERNEL) + ops.impl("relu", torch::kCPU, &relu); +#elif defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("relu", torch::kCUDA, &relu); +#elif defined(METAL_KERNEL) + ops.impl("relu", torch::kMPS, relu); +#elif defined(XPU_KERNEL) + ops.impl("relu", torch::kXPU, &relu); +#endif +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/examples/relu-metal-cpp/torch-ext/torch_binding.h b/examples/relu-metal-cpp/torch-ext/torch_binding.h new file mode 100644 index 00000000..3bcf2904 --- /dev/null +++ b/examples/relu-metal-cpp/torch-ext/torch_binding.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +void relu(torch::Tensor &out, torch::Tensor const &input); \ No newline at end of file diff --git a/examples/relu/flake.lock b/examples/relu/flake.lock new file mode 100644 index 00000000..0182bc37 --- /dev/null +++ b/examples/relu/flake.lock @@ -0,0 +1,164 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1762504832, + "narHash": "sha256-PIxh2ZFqq3CAkQNtupT0AfxA1n3raM/3enDLHn4a21k=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "3267e738faafa71bed7de9b75d74f6a90ec1bc57", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "path": "../..", + "type": "path" + }, + "original": { + "path": "../..", + "type": "path" + }, + "parent": [] + }, + "nixpkgs": { + "locked": { + "lastModified": 1762168314, + "narHash": "sha256-+DX6mIF47gRGoK0mqkTg1Jmcjcup0CAXJFHVkdUx8YA=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "94fc102d2c15d9c1a861e59de550807c65358e1b", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/examples/relu/result b/examples/relu/result new file mode 120000 index 00000000..49c0a773 --- /dev/null +++ b/examples/relu/result @@ -0,0 +1 @@ +/nix/store/5p3508x9iq4w69bwad427i99b7k0f5pj-torch-ext-bundle \ No newline at end of file diff --git a/lib/deps.nix b/lib/deps.nix index 9e8c6c81..2fb4e4fa 100644 --- a/lib/deps.nix +++ b/lib/deps.nix @@ -33,6 +33,9 @@ let #torch.cxxdev ]; "cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ]; + "metal-cpp" = lib.optionals pkgs.stdenv.hostPlatform.isDarwin [ + pkgs.metal-cpp.dev + ]; }; in let diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index 6e907029..34998cd0 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -30,6 +30,7 @@ oneapi-torch-dev, onednn-xpu, torch, + metal-cpp, }: { From 8632abccd08c0da6529c309d3252580e97af6824 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 10 Nov 2025 10:31:53 -0500 Subject: [PATCH 02/10] fix: adjust dep ordering --- lib/torch-extension/arch.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index 34998cd0..3a436198 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -27,10 +27,10 @@ # Build inputs apple-sdk_26, clr, + metal-cpp, oneapi-torch-dev, onednn-xpu, torch, - metal-cpp, }: { From 988ebd7140629c4ec591ed677784ac4028216b47 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 10 Nov 2025 15:05:49 -0500 Subject: [PATCH 03/10] fix: rebuild with latest metal cpp header url --- examples/relu-metal-cpp/flake.lock | 19 ++++++++++--------- flake.nix | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/relu-metal-cpp/flake.lock b/examples/relu-metal-cpp/flake.lock index 0182bc37..13d0cd41 100644 --- a/examples/relu-metal-cpp/flake.lock +++ b/examples/relu-metal-cpp/flake.lock @@ -17,11 +17,11 @@ }, "flake-compat_2": { "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", "owner": "edolstra", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", "type": "github" }, "original": { @@ -73,15 +73,16 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1762504832, - "narHash": "sha256-PIxh2ZFqq3CAkQNtupT0AfxA1n3raM/3enDLHn4a21k=", + "lastModified": 1762804970, + "narHash": "sha256-YPLcoqJBlYYGgmTG/J7bDx7lHqPnJ8beLs0+9OLfFhM=", "owner": "huggingface", "repo": "hf-nix", - "rev": "3267e738faafa71bed7de9b75d74f6a90ec1bc57", + "rev": "389aa480b495775803cf42894dc119299460016d", "type": "github" }, "original": { "owner": "huggingface", + "ref": "bump-metal-cpp-version", "repo": "hf-nix", "type": "github" } @@ -109,11 +110,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1762168314, - "narHash": "sha256-+DX6mIF47gRGoK0mqkTg1Jmcjcup0CAXJFHVkdUx8YA=", + "lastModified": 1762764791, + "narHash": "sha256-mWl8rYSYDFWD+zCR0VkBjEjD9jYj1/nlkDOfNNu44NA=", "owner": "nixos", "repo": "nixpkgs", - "rev": "94fc102d2c15d9c1a861e59de550807c65358e1b", + "rev": "b549734f6b3ec54bb9a611a4185d11ee31f52ee1", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 081180f2..d645d63a 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ flake-utils.url = "github:numtide/flake-utils"; nixpkgs.follows = "hf-nix/nixpkgs"; flake-compat.url = "github:edolstra/flake-compat"; - hf-nix.url = "github:huggingface/hf-nix"; + hf-nix.url = "github:huggingface/hf-nix/bump-metal-cpp-version"; }; outputs = From 92782fa27a46bd8557ac2c54178df862234cf74f Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 10 Nov 2025 22:58:38 -0500 Subject: [PATCH 04/10] fix: adjust hf-nix url to avoid branch --- examples/relu-metal-cpp/flake.lock | 5 ++--- flake.nix | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/relu-metal-cpp/flake.lock b/examples/relu-metal-cpp/flake.lock index 13d0cd41..4c779b54 100644 --- a/examples/relu-metal-cpp/flake.lock +++ b/examples/relu-metal-cpp/flake.lock @@ -73,16 +73,15 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1762804970, + "lastModified": 1762833384, "narHash": "sha256-YPLcoqJBlYYGgmTG/J7bDx7lHqPnJ8beLs0+9OLfFhM=", "owner": "huggingface", "repo": "hf-nix", - "rev": "389aa480b495775803cf42894dc119299460016d", + "rev": "752645bcda8793906249809319fa9b8dc11d7af6", "type": "github" }, "original": { "owner": "huggingface", - "ref": "bump-metal-cpp-version", "repo": "hf-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index d645d63a..081180f2 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ flake-utils.url = "github:numtide/flake-utils"; nixpkgs.follows = "hf-nix/nixpkgs"; flake-compat.url = "github:edolstra/flake-compat"; - hf-nix.url = "github:huggingface/hf-nix/bump-metal-cpp-version"; + hf-nix.url = "github:huggingface/hf-nix"; }; outputs = From 914ed28aeef63ed3f15e1792867d486b21e337d8 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 11 Nov 2025 09:52:37 -0500 Subject: [PATCH 05/10] fix: refactor example and remove unneeded files --- examples/relu-metal-cpp/build.toml | 8 +- examples/relu-metal-cpp/flake.lock | 164 ------------------ examples/relu-metal-cpp/{ => relu}/common.h | 0 .../{ => relu}/metallib_loader.mm | 0 examples/relu-metal-cpp/{ => relu}/relu.cpp | 0 .../relu-metal-cpp/{ => relu}/relu_cpp.metal | 0 examples/relu-metal-cpp/result | 1 - examples/relu-metal-cpp/tests/__init__.py | 0 examples/relu-metal-cpp/tests/test_relu.py | 19 ++ examples/relu/result | 1 - flake.lock | 18 +- test-kernel.py | 23 +++ 12 files changed, 55 insertions(+), 179 deletions(-) delete mode 100644 examples/relu-metal-cpp/flake.lock rename examples/relu-metal-cpp/{ => relu}/common.h (100%) rename examples/relu-metal-cpp/{ => relu}/metallib_loader.mm (100%) rename examples/relu-metal-cpp/{ => relu}/relu.cpp (100%) rename examples/relu-metal-cpp/{ => relu}/relu_cpp.metal (100%) delete mode 120000 examples/relu-metal-cpp/result create mode 100644 examples/relu-metal-cpp/tests/__init__.py create mode 100644 examples/relu-metal-cpp/tests/test_relu.py delete mode 120000 examples/relu/result create mode 100644 test-kernel.py diff --git a/examples/relu-metal-cpp/build.toml b/examples/relu-metal-cpp/build.toml index 8fb84793..e0d6d487 100644 --- a/examples/relu-metal-cpp/build.toml +++ b/examples/relu-metal-cpp/build.toml @@ -12,9 +12,9 @@ src = [ [kernel.relu_metal] backend = "metal" src = [ - "relu.cpp", - "metallib_loader.mm", - "relu_cpp.metal", - "common.h", + "relu/relu.cpp", + "relu/metallib_loader.mm", + "relu/relu_cpp.metal", + "relu/common.h", ] depends = [ "torch", "metal-cpp" ] \ No newline at end of file diff --git a/examples/relu-metal-cpp/flake.lock b/examples/relu-metal-cpp/flake.lock deleted file mode 100644 index 4c779b54..00000000 --- a/examples/relu-metal-cpp/flake.lock +++ /dev/null @@ -1,164 +0,0 @@ -{ - "nodes": { - "flake-compat": { - "locked": { - "lastModified": 1761588595, - "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-compat_2": { - "locked": { - "lastModified": 1761588595, - "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-utils": { - "inputs": { - "systems": "systems" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "flake-utils_2": { - "inputs": { - "systems": "systems_2" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "hf-nix": { - "inputs": { - "flake-compat": "flake-compat_2", - "flake-utils": "flake-utils_2", - "nixpkgs": "nixpkgs" - }, - "locked": { - "lastModified": 1762833384, - "narHash": "sha256-YPLcoqJBlYYGgmTG/J7bDx7lHqPnJ8beLs0+9OLfFhM=", - "owner": "huggingface", - "repo": "hf-nix", - "rev": "752645bcda8793906249809319fa9b8dc11d7af6", - "type": "github" - }, - "original": { - "owner": "huggingface", - "repo": "hf-nix", - "type": "github" - } - }, - "kernel-builder": { - "inputs": { - "flake-compat": "flake-compat", - "flake-utils": "flake-utils", - "hf-nix": "hf-nix", - "nixpkgs": [ - "kernel-builder", - "hf-nix", - "nixpkgs" - ] - }, - "locked": { - "path": "../..", - "type": "path" - }, - "original": { - "path": "../..", - "type": "path" - }, - "parent": [] - }, - "nixpkgs": { - "locked": { - "lastModified": 1762764791, - "narHash": "sha256-mWl8rYSYDFWD+zCR0VkBjEjD9jYj1/nlkDOfNNu44NA=", - "owner": "nixos", - "repo": "nixpkgs", - "rev": "b549734f6b3ec54bb9a611a4185d11ee31f52ee1", - "type": "github" - }, - "original": { - "owner": "nixos", - "ref": "nixos-unstable-small", - "repo": "nixpkgs", - "type": "github" - } - }, - "root": { - "inputs": { - "kernel-builder": "kernel-builder" - } - }, - "systems": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - }, - "systems_2": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - } - }, - "root": "root", - "version": 7 -} diff --git a/examples/relu-metal-cpp/common.h b/examples/relu-metal-cpp/relu/common.h similarity index 100% rename from examples/relu-metal-cpp/common.h rename to examples/relu-metal-cpp/relu/common.h diff --git a/examples/relu-metal-cpp/metallib_loader.mm b/examples/relu-metal-cpp/relu/metallib_loader.mm similarity index 100% rename from examples/relu-metal-cpp/metallib_loader.mm rename to examples/relu-metal-cpp/relu/metallib_loader.mm diff --git a/examples/relu-metal-cpp/relu.cpp b/examples/relu-metal-cpp/relu/relu.cpp similarity index 100% rename from examples/relu-metal-cpp/relu.cpp rename to examples/relu-metal-cpp/relu/relu.cpp diff --git a/examples/relu-metal-cpp/relu_cpp.metal b/examples/relu-metal-cpp/relu/relu_cpp.metal similarity index 100% rename from examples/relu-metal-cpp/relu_cpp.metal rename to examples/relu-metal-cpp/relu/relu_cpp.metal diff --git a/examples/relu-metal-cpp/result b/examples/relu-metal-cpp/result deleted file mode 120000 index d8bc7df4..00000000 --- a/examples/relu-metal-cpp/result +++ /dev/null @@ -1 +0,0 @@ -/nix/store/khawpdnqpl0c14h14gg2g2abvxqclxp0-torch-ext-bundle \ No newline at end of file diff --git a/examples/relu-metal-cpp/tests/__init__.py b/examples/relu-metal-cpp/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/relu-metal-cpp/tests/test_relu.py b/examples/relu-metal-cpp/tests/test_relu.py new file mode 100644 index 00000000..65544aa4 --- /dev/null +++ b/examples/relu-metal-cpp/tests/test_relu.py @@ -0,0 +1,19 @@ +import platform + +import torch +import torch.nn.functional as F + +import relu + + +def test_relu(): + if platform.system() == "Darwin": + device = torch.device("mps") + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + elif torch.version.cuda is not None and torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + x = torch.randn(1024, 1024, dtype=torch.float32, device=device) + torch.testing.assert_allclose(F.relu(x), relu.relu(x)) diff --git a/examples/relu/result b/examples/relu/result deleted file mode 120000 index 49c0a773..00000000 --- a/examples/relu/result +++ /dev/null @@ -1 +0,0 @@ -/nix/store/5p3508x9iq4w69bwad427i99b7k0f5pj-torch-ext-bundle \ No newline at end of file diff --git a/flake.lock b/flake.lock index 8ab650e2..006f1995 100644 --- a/flake.lock +++ b/flake.lock @@ -17,11 +17,11 @@ }, "flake-compat_2": { "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", "owner": "edolstra", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", "type": "github" }, "original": { @@ -73,11 +73,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1762504832, - "narHash": "sha256-PIxh2ZFqq3CAkQNtupT0AfxA1n3raM/3enDLHn4a21k=", + "lastModified": 1762833384, + "narHash": "sha256-YPLcoqJBlYYGgmTG/J7bDx7lHqPnJ8beLs0+9OLfFhM=", "owner": "huggingface", "repo": "hf-nix", - "rev": "3267e738faafa71bed7de9b75d74f6a90ec1bc57", + "rev": "752645bcda8793906249809319fa9b8dc11d7af6", "type": "github" }, "original": { @@ -88,11 +88,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1762168314, - "narHash": "sha256-+DX6mIF47gRGoK0mqkTg1Jmcjcup0CAXJFHVkdUx8YA=", + "lastModified": 1762764791, + "narHash": "sha256-mWl8rYSYDFWD+zCR0VkBjEjD9jYj1/nlkDOfNNu44NA=", "owner": "nixos", "repo": "nixpkgs", - "rev": "94fc102d2c15d9c1a861e59de550807c65358e1b", + "rev": "b549734f6b3ec54bb9a611a4185d11ee31f52ee1", "type": "github" }, "original": { diff --git a/test-kernel.py b/test-kernel.py new file mode 100644 index 00000000..072c4067 --- /dev/null +++ b/test-kernel.py @@ -0,0 +1,23 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = ["kernels", "torch", "numpy"] +# /// +from kernels import get_local_kernel +import torch +from pathlib import Path + +relu = get_local_kernel(Path("examples/relu-metal-cpp/result"), "relu").relu + +input = torch.tensor([-1.0, -1.5, 0.0, 2.0, 3.5], device="mps", dtype=torch.float16) +out = relu(input) +ref = torch.relu(input) + +assert torch.allclose(out, ref), f"Float16 failed: {out} != {ref}" + +print(out.cpu().numpy()) +print(ref.cpu().numpy()) + +print("PASS") +# [0. 0. 0. 2. 3.5] +# [0. 0. 0. 2. 3.5] +# PASS From 9a22bbbd0a8d51b4b34ccc5ebb60160e2db70573 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 11 Nov 2025 09:54:56 -0500 Subject: [PATCH 06/10] fix: remove relu flake lock --- examples/relu/flake.lock | 164 --------------------------------------- 1 file changed, 164 deletions(-) delete mode 100644 examples/relu/flake.lock diff --git a/examples/relu/flake.lock b/examples/relu/flake.lock deleted file mode 100644 index 0182bc37..00000000 --- a/examples/relu/flake.lock +++ /dev/null @@ -1,164 +0,0 @@ -{ - "nodes": { - "flake-compat": { - "locked": { - "lastModified": 1761588595, - "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-compat_2": { - "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-utils": { - "inputs": { - "systems": "systems" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "flake-utils_2": { - "inputs": { - "systems": "systems_2" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "hf-nix": { - "inputs": { - "flake-compat": "flake-compat_2", - "flake-utils": "flake-utils_2", - "nixpkgs": "nixpkgs" - }, - "locked": { - "lastModified": 1762504832, - "narHash": "sha256-PIxh2ZFqq3CAkQNtupT0AfxA1n3raM/3enDLHn4a21k=", - "owner": "huggingface", - "repo": "hf-nix", - "rev": "3267e738faafa71bed7de9b75d74f6a90ec1bc57", - "type": "github" - }, - "original": { - "owner": "huggingface", - "repo": "hf-nix", - "type": "github" - } - }, - "kernel-builder": { - "inputs": { - "flake-compat": "flake-compat", - "flake-utils": "flake-utils", - "hf-nix": "hf-nix", - "nixpkgs": [ - "kernel-builder", - "hf-nix", - "nixpkgs" - ] - }, - "locked": { - "path": "../..", - "type": "path" - }, - "original": { - "path": "../..", - "type": "path" - }, - "parent": [] - }, - "nixpkgs": { - "locked": { - "lastModified": 1762168314, - "narHash": "sha256-+DX6mIF47gRGoK0mqkTg1Jmcjcup0CAXJFHVkdUx8YA=", - "owner": "nixos", - "repo": "nixpkgs", - "rev": "94fc102d2c15d9c1a861e59de550807c65358e1b", - "type": "github" - }, - "original": { - "owner": "nixos", - "ref": "nixos-unstable-small", - "repo": "nixpkgs", - "type": "github" - } - }, - "root": { - "inputs": { - "kernel-builder": "kernel-builder" - } - }, - "systems": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - }, - "systems_2": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - } - }, - "root": "root", - "version": 7 -} From 67473e2d72ed17fc11193b5eb1a281fb45681c73 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 11 Nov 2025 09:57:25 -0500 Subject: [PATCH 07/10] fix: clean up conditional and arch input --- lib/deps.nix | 4 ++-- lib/torch-extension/arch.nix | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/deps.nix b/lib/deps.nix index 2fb4e4fa..0cb69e0f 100644 --- a/lib/deps.nix +++ b/lib/deps.nix @@ -33,8 +33,8 @@ let #torch.cxxdev ]; "cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ]; - "metal-cpp" = lib.optionals pkgs.stdenv.hostPlatform.isDarwin [ - pkgs.metal-cpp.dev + "metal-cpp" = [ + pkgs.metal-cpp.dev ]; }; in diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index 3a436198..6e907029 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -27,7 +27,6 @@ # Build inputs apple-sdk_26, clr, - metal-cpp, oneapi-torch-dev, onednn-xpu, torch, From 43aea1d03e82d179fddcb2a7fb257274618d53ae Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 11 Nov 2025 10:00:48 -0500 Subject: [PATCH 08/10] fix: run nix fmt --- lib/deps.nix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/deps.nix b/lib/deps.nix index 0cb69e0f..0da60528 100644 --- a/lib/deps.nix +++ b/lib/deps.nix @@ -34,7 +34,7 @@ let ]; "cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ]; "metal-cpp" = [ - pkgs.metal-cpp.dev + pkgs.metal-cpp.dev ]; }; in From 55a9904a79a1780d43ea94a7e9d5a5e0de200bbd Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 11 Nov 2025 13:03:01 -0500 Subject: [PATCH 09/10] fix: small cleanups --- examples/relu-metal-cpp/relu/common.h | 5 +---- examples/relu-metal-cpp/relu/metallib_loader.mm | 1 - examples/relu-metal-cpp/relu/relu.cpp | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/relu-metal-cpp/relu/common.h b/examples/relu-metal-cpp/relu/common.h index 1b891fad..a981a521 100644 --- a/examples/relu-metal-cpp/relu/common.h +++ b/examples/relu-metal-cpp/relu/common.h @@ -1,10 +1,7 @@ -#ifndef COMMON_H -#define COMMON_H +#pragma once #include using namespace metal; // Common constants and utilities for Metal kernels constant float RELU_THRESHOLD = 0.0f; - -#endif // COMMON_H \ No newline at end of file diff --git a/examples/relu-metal-cpp/relu/metallib_loader.mm b/examples/relu-metal-cpp/relu/metallib_loader.mm index 050ee791..9e63d909 100644 --- a/examples/relu-metal-cpp/relu/metallib_loader.mm +++ b/examples/relu-metal-cpp/relu/metallib_loader.mm @@ -1,5 +1,4 @@ #import -#import #include #include diff --git a/examples/relu-metal-cpp/relu/relu.cpp b/examples/relu-metal-cpp/relu/relu.cpp index c07ad544..85b9ac57 100644 --- a/examples/relu-metal-cpp/relu/relu.cpp +++ b/examples/relu-metal-cpp/relu/relu.cpp @@ -3,7 +3,6 @@ // Include metal-cpp headers from system #include -#include #include #include From 1b4fe2a6511da6cccb1efb30fbe4e085d3ddf951 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 12 Nov 2025 18:46:39 -0500 Subject: [PATCH 10/10] fix: remove dev test file --- test-kernel.py | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 test-kernel.py diff --git a/test-kernel.py b/test-kernel.py deleted file mode 100644 index 072c4067..00000000 --- a/test-kernel.py +++ /dev/null @@ -1,23 +0,0 @@ -# /// script -# requires-python = ">=3.10" -# dependencies = ["kernels", "torch", "numpy"] -# /// -from kernels import get_local_kernel -import torch -from pathlib import Path - -relu = get_local_kernel(Path("examples/relu-metal-cpp/result"), "relu").relu - -input = torch.tensor([-1.0, -1.5, 0.0, 2.0, 3.5], device="mps", dtype=torch.float16) -out = relu(input) -ref = torch.relu(input) - -assert torch.allclose(out, ref), f"Float16 failed: {out} != {ref}" - -print(out.cpu().numpy()) -print(ref.cpu().numpy()) - -print("PASS") -# [0. 0. 0. 2. 3.5] -# [0. 0. 0. 2. 3.5] -# PASS