diff --git a/dub.json b/dub.json index 3154710..7f14bee 100644 --- a/dub.json +++ b/dub.json @@ -7,6 +7,7 @@ "dependencies": { "derelict-cl" : "~>3.2.0", "bindbc-cuda": "~>0.1.0", + "metal-d": "~>0.5.2", "taggedalgebraic": "~>0.10.7" }, "configurations": [ @@ -33,5 +34,11 @@ "targetType": "executable", "versions": ["DComputeTestOpenCL"], }, + { + "name" : "test-metal", + "dflags": ["-mdcompute-targets=metal-400", "-version=LDC_DCompute","-oq"], + "targetType": "executable", + "versions": ["DComputeTestMetal"], + }, ] } diff --git a/source/dcompute/driver/metal/buffer.d b/source/dcompute/driver/metal/buffer.d new file mode 100644 index 0000000..3700a97 --- /dev/null +++ b/source/dcompute/driver/metal/buffer.d @@ -0,0 +1,30 @@ +module dcompute.driver.metal.buffer; +import metal; +import dcompute.driver.metal.program; +import dcompute.driver.metal; +import core.stdc.string; + +struct Buffer(T) +{ + MTLBuffer mtlBuffer; + + // Host memory associated with this buffer + T[] hostMemory; + + this(MTLBuffer _mtlBuffer, T[] array) + { + mtlBuffer = _mtlBuffer; + hostMemory = array; + } + + T* contents() + { + return cast(T*) mtlBuffer.contents(); + } + + void release() + { + mtlBuffer = null; + hostMemory = null; + } +} diff --git a/source/dcompute/driver/metal/device.d b/source/dcompute/driver/metal/device.d new file mode 100644 index 0000000..847bf58 --- /dev/null +++ b/source/dcompute/driver/metal/device.d @@ -0,0 +1,44 @@ +module dcompute.driver.metal.device; +import dcompute.driver.metal.buffer; +import core.stdc.string; +import metal; + +struct Device +{ + /** + A pointer to $(D MTLDevice). It is $(D void*) because upon storing array of $(D Device), + linker look for the $(D MTLDevice) but fails to + find it as it is Objective-C binding hence had to wrap it as such + */ + void* raw; + + @property MTLDevice mtlDevice() + { + return cast(MTLDevice) raw; + } + + this(MTLDevice device) + { + raw = cast(void*)device; + } + + MTLBuffer newBuffer(size_t sizeInBytes) + { + return mtlDevice.newBuffer(sizeInBytes, MTLResourceOptions.StorageModeShared); + } + + Buffer!T makeBuffer(T)(T[] hostMemory) + { + size_t sizeInBytes = hostMemory.length * T.sizeof; + + auto mtlBuffer = newBuffer(sizeInBytes); + auto buffer = Buffer!T(mtlBuffer, hostMemory); + + if (buffer.hostMemory.ptr !is null && sizeInBytes > 0) + { + memcpy(buffer.mtlBuffer.contents(), buffer.hostMemory.ptr, sizeInBytes); + } + + return buffer; + } +} diff --git a/source/dcompute/driver/metal/kernel.d b/source/dcompute/driver/metal/kernel.d new file mode 100644 index 0000000..c0596dc --- /dev/null +++ b/source/dcompute/driver/metal/kernel.d @@ -0,0 +1,12 @@ +module dcompute.driver.metal.kernel; +import metal.library; + +struct Kernel(F) if (is(F==function) || is(F==void)) +{ + MTLFunction kernelFunction; + + this(MTLFunction _kernelFunction) + { + kernelFunction = _kernelFunction; + } +} diff --git a/source/dcompute/driver/metal/package.d b/source/dcompute/driver/metal/package.d new file mode 100644 index 0000000..1775bb4 --- /dev/null +++ b/source/dcompute/driver/metal/package.d @@ -0,0 +1,26 @@ +module dcompute.driver.metal; +import ldc.dcompute; +import std.range; +import std.meta; +import std.traits; + +public import dcompute.driver.metal.buffer; +public import dcompute.driver.metal.device; +public import dcompute.driver.metal.kernel; +public import dcompute.driver.metal.platform; +public import dcompute.driver.metal.program; +public import dcompute.driver.metal.queue; + + +template HostArgsOf(F) +{ + template toBuffer(T) + { + static if (is(T: Pointer!(n,U), uint n, U)) + alias toBuffer = Buffer!U; + else + alias toBuffer = T; + } + + alias HostArgsOf = staticMap!(toBuffer, Parameters!F); +} diff --git a/source/dcompute/driver/metal/platform.d b/source/dcompute/driver/metal/platform.d new file mode 100644 index 0000000..5caa7be --- /dev/null +++ b/source/dcompute/driver/metal/platform.d @@ -0,0 +1,26 @@ +module dcompute.driver.metal.platform; + +import dcompute.driver.metal.device; +import metal.device; + +struct Platform +{ + static Device[] getDevices() + { + auto mtlDevices = MTLCopyAllDevices(); + auto devices = new Device[mtlDevices.length]; + + for(int i=0;i < mtlDevices.length;i ++) + { + devices[i] = Device(mtlDevices[i]); + } + + return devices; + } + + static Device getDefaultDevice() + { + auto device = Device(MTLCreateSystemDefaultDevice()); + return device; + } +} diff --git a/source/dcompute/driver/metal/program.d b/source/dcompute/driver/metal/program.d new file mode 100644 index 0000000..bf48d23 --- /dev/null +++ b/source/dcompute/driver/metal/program.d @@ -0,0 +1,62 @@ +module dcompute.driver.metal.program; +import dcompute.driver.metal.device; +import dcompute.driver.metal.kernel; +import objc; +import foundation; +import core.stdc.stdio; +import std.string; +import std.path; +import metal.library; +import metal.device; + +struct Program +{ + MTLLibrary metalLibrary; + + Device device; + + Kernel!void getKernelByName(immutable(char)* name) + { + auto kName = fromStringz(name); + + auto kNameInNSString = NSString.create(kName); + + auto kernelFunction = metalLibrary.newFunctionWithName(kNameInNSString); + + if (kernelFunction is null) + { + printf("Error: Could not find kernel function %s in library.\n", name); + assert(0); + } + + return Kernel!void(kernelFunction); + } + + Kernel!(typeof(k)) getKernel(alias k)() + { + return cast(typeof(return)) getKernelByName(k.mangleof.ptr); + } + + static Program fromFile(Device device, string path) + { + NSError error; + auto nsPath = NSString.create(absolutePath(path)); + + auto library = device.mtlDevice.newLibrary(NSURL.fromPath(nsPath), error); + + if (library is null) + { + printf("Error loading .metallib: %s\n", error.localizedDescription().ptr); + assert(0); + } + + return Program(library, device); + } + + __gshared static Program globalProgram; + + void unload() + { + metalLibrary = null; + } +} diff --git a/source/dcompute/driver/metal/queue.d b/source/dcompute/driver/metal/queue.d new file mode 100644 index 0000000..d19d308 --- /dev/null +++ b/source/dcompute/driver/metal/queue.d @@ -0,0 +1,105 @@ +module dcompute.driver.metal.queue; +import dcompute.driver.metal.buffer; + +import dcompute.driver.metal; +import dcompute.driver.metal.device; +import dcompute.driver.metal.program; +import metal; +import metal.argument; +import metal.types; +import core.stdc.stdio; +import objc; +import foundation; + +struct Queue +{ + Device device; + MTLCommandQueue commandQueue; + MTLCommandBuffer lastActiveBuffer; + + // TODO(asadbek): explore options to make the use of async execution with events + this (Device _device /*bool async*/) + { + device = _device; + commandQueue = device.mtlDevice.newCommandQueue(); + } + + auto enqueue(alias k)(uint[3] _grid, uint[3] _block) + { + static struct Call + { + Queue* q; + uint[3] grid, block; + + this(Queue* _q, uint[3] _grid, uint[3] _block) + { + q = _q; + grid = _grid; + block = _block; + } + + void opCall(HostArgsOf!(typeof(k)) args) + { + NSError error; + + auto kernel = Program.globalProgram.getKernel!k(); + + auto pipelineState = q.device.mtlDevice.newComputePipelineStateWithFunction( + kernel.kernelFunction, + MTLPipelineOption.None, + null, + error + ); + + if (pipelineState is null) + { + printf("Error: Backend compilation failed: %s\n", error.localizedDescription().ptr); + assert(0); + } + + auto commandBuffer = q.commandQueue.commandBuffer(); + + auto computeEncoder = commandBuffer.computeCommandEncoder(); + + computeEncoder.setComputePipelineState(pipelineState); + + foreach (i, arg; args) + { + static if (is(typeof(arg): Buffer!U, U)) + { + computeEncoder.setBuffer(arg.mtlBuffer, 0, i); + } else static if (__traits(isScalar, typeof(arg))) + { + computeEncoder.setBytes(&arg, typeof(arg).sizeof, i); + } + else + { + static assert(0, "Unsupported argument type for Metal kernel dispatch!"); + } + } + + auto threadgroupsPerGrid = MTLSize(grid[0], grid[1], grid[2]); + + auto threadsPerThreadgroup = MTLSize(block[0], block[1], block[2]); + + computeEncoder.dispatchThreads(threadgroupsPerGrid, threadsPerThreadgroup); + + computeEncoder.endEncoding(); + commandBuffer.commit(); + + q.lastActiveBuffer = commandBuffer; + } + } + + return Call(&this, _grid, _block); + } + + void finish() { + if (lastActiveBuffer !is null) { + lastActiveBuffer.waitUntilCompleted(); + lastActiveBuffer.release(); + + lastActiveBuffer = null; + } + } +} diff --git a/source/dcompute/std/index.d b/source/dcompute/std/index.d index 60abf6c..54aefa8 100644 --- a/source/dcompute/std/index.d +++ b/source/dcompute/std/index.d @@ -4,6 +4,7 @@ import ldc.dcompute; private import ocl = dcompute.std.opencl.index; private import cuda = dcompute.std.cuda.index; +private import metal = dcompute.std.metal.index; /* Index Terminology @@ -46,6 +47,8 @@ struct GlobalDimension return ocl.get_global_size(0); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ntid_x()*cuda.nctaid_x(); + else if (__dcompute_reflect(ReflectTarget.Metal,0)) + return metal.threads_per_grid(0); else assert(0); } @@ -56,6 +59,8 @@ struct GlobalDimension return ocl.get_global_size(1); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ntid_y()*cuda.nctaid_y(); + else if (__dcompute_reflect(ReflectTarget.Metal,0)) + return metal.threads_per_grid(1); else assert(0); } @@ -66,6 +71,8 @@ struct GlobalDimension return ocl.get_global_size(2); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ntid_z()*cuda.nctaid_z(); + else if (__dcompute_reflect(ReflectTarget.Metal,0)) + return metal.threads_per_grid(2); else assert(0); } @@ -80,6 +87,8 @@ struct GlobalIndex return ocl.get_global_id(0); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ctaid_x()*cuda.ntid_x() + cuda.tid_x(); + else if(__dcompute_reflect(ReflectTarget.Metal,0)) + return metal.thread_position_in_grid(0); else assert(0); } @@ -90,6 +99,8 @@ struct GlobalIndex return ocl.get_global_id(1); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ctaid_y()*cuda.ntid_y() + cuda.tid_y(); + else if(__dcompute_reflect(ReflectTarget.Metal,0)) + return metal.thread_position_in_grid(1); else assert(0); } @@ -100,6 +111,8 @@ struct GlobalIndex return ocl.get_global_id(2); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ctaid_z()*cuda.ntid_z() + cuda.tid_z(); + else if(__dcompute_reflect(ReflectTarget.Metal,0)) + return metal.thread_position_in_grid(2); else assert(0); } @@ -139,6 +152,8 @@ struct GroupDimension return ocl.get_num_groups(0); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.nctaid_x(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threadgroups_per_grid(0); else assert(0); } @@ -149,6 +164,8 @@ struct GroupDimension return ocl.get_num_groups(1); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.nctaid_y(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threadgroups_per_grid(1); else assert(0); } @@ -159,6 +176,8 @@ struct GroupDimension return ocl.get_num_groups(2); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.nctaid_z(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threadgroups_per_grid(2); else assert(0); } @@ -173,6 +192,8 @@ struct GroupIndex return ocl.get_group_id(0); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ctaid_x(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threadgroup_position_in_grid(0); else assert(0); } @@ -183,6 +204,8 @@ struct GroupIndex return ocl.get_group_id(1); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ctaid_y(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threadgroup_position_in_grid(1); else assert(0); } @@ -193,6 +216,8 @@ struct GroupIndex return ocl.get_group_id(2); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ctaid_z(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threadgroup_position_in_grid(2); else assert(0); } @@ -207,6 +232,8 @@ struct SharedDimension return ocl.get_local_size(0); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ntid_x(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threads_per_threadgroup(0); else assert(0); } @@ -217,6 +244,8 @@ struct SharedDimension return ocl.get_local_size(1); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ntid_y(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threads_per_threadgroup(1); else assert(0); @@ -228,6 +257,8 @@ struct SharedDimension return ocl.get_local_size(2); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.ntid_z(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.threads_per_threadgroup(2); else assert(0); } @@ -242,6 +273,8 @@ struct SharedIndex return ocl.get_local_id(0); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.tid_x(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.thread_position_in_threadgroup(0); else assert(0); } @@ -252,6 +285,8 @@ struct SharedIndex return ocl.get_local_id(1); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.tid_y(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.thread_position_in_threadgroup(1); else assert(0); } @@ -262,6 +297,8 @@ struct SharedIndex return ocl.get_local_id(2); else if(__dcompute_reflect(ReflectTarget.CUDA,0)) return cuda.tid_z(); + else if(__dcompute_reflect(ReflectTarget.Metal, 0)) + return metal.thread_position_in_threadgroup(2); else assert(0); } diff --git a/source/dcompute/std/metal/index.d b/source/dcompute/std/metal/index.d new file mode 100644 index 0000000..fdbc86f --- /dev/null +++ b/source/dcompute/std/metal/index.d @@ -0,0 +1,22 @@ +@compute(CompileFor.deviceOnly) module dcompute.std.metal.index; + +import ldc.dcompute; +pure: nothrow: @nogc: + +pragma(LDC_intrinsic, "air.get_global_id.i32") +uint thread_position_in_grid(uint dim); + +pragma(LDC_intrinsic, "air.get_local_id.i32") +uint thread_position_in_threadgroup(uint dim); + +pragma(LDC_intrinsic, "air.get_local_size.i32") +uint threads_per_threadgroup(uint dim); + +pragma(LDC_intrinsic, "air.get_group_id.i32") +uint threadgroup_position_in_grid(uint dim); + +pragma(LDC_intrinsic, "air.get_global_size.i32") +uint threads_per_grid(uint dim); + +pragma(LDC_intrinsic, "air.get_num_groups.i32") +uint threadgroups_per_grid(uint dim); \ No newline at end of file diff --git a/source/dcompute/tests/dummykernels.d b/source/dcompute/tests/dummykernels.d index ab37d3a..daff4cd 100644 --- a/source/dcompute/tests/dummykernels.d +++ b/source/dcompute/tests/dummykernels.d @@ -5,21 +5,34 @@ pragma(LDC_no_moduleinfo); import ldc.dcompute; import dcompute.std.index; -@kernel() void saxpy(GlobalPointer!(float) res, - float alpha,GlobalPointer!(float) x, - GlobalPointer!(float) y, - size_t N) -{ - auto i = GlobalIndex.x; - if (i >= N) return; - res[i] = alpha*x[i] + y[i]; -} +version(DComputeTestMetal){ + @kernel() void saxpy(GlobalPointer!(float) res, + float alpha,GlobalPointer!(float) x, + GlobalPointer!(float) y, + size_t N) + { + auto i = GlobalIndex.x; + if (i >= N) return; + res[i] = alpha*x[i] + y[i]; + } +} else { + @kernel() void saxpy(GlobalPointer!(float) res, + float alpha,GlobalPointer!(float) x, + GlobalPointer!(float) y, + size_t N) + { + auto i = GlobalIndex.x; + if (i >= N) return; + res[i] = alpha*x[i] + y[i]; + } + -alias aagf = AutoIndexed!(GlobalPointer!(float)); + alias aagf = AutoIndexed!(GlobalPointer!(float)); -@kernel() void auto_index_test(aagf a, - aagf b, - aagf c) -{ - a = b + c; + @kernel() void auto_index_test(aagf a, + aagf b, + aagf c) + { + a = b + c; + } } diff --git a/source/dcompute/tests/main.d b/source/dcompute/tests/main.d index bc1a918..0f1e524 100644 --- a/source/dcompute/tests/main.d +++ b/source/dcompute/tests/main.d @@ -23,6 +23,8 @@ version(DComputeTestOpenCL) import dcompute.driver.ocl; else version(DComputeTestCUDA) import dcompute.driver.cuda; +else version(DComputeTestMetal) + import dcompute.driver.metal; else static assert(false, "Need to test something!"); @@ -170,6 +172,46 @@ int main(string[] args) } } + version(DComputeTestMetal) + { + auto devices = Platform.getDevices(); + + auto device = devices[0]; + + if (device.raw is null) + { + "Failed to fetch default device".writeln; + return 1; + } + + auto program = Program.fromFile(device, "./kernels_metal400_64.metallib"); + + Program.globalProgram = program; + + if (program.metalLibrary is null) + { + "Failed to load .metallibrary".writeln; + return 2; + } + + auto deviceX = device.makeBuffer!float(x); + auto deviceY = device.makeBuffer!float(y); + auto deviceRes = device.makeBuffer!float(res); + + auto queue = Queue(device); + + queue.enqueue!(saxpy) + ([N,1,1],[1,1,1]) + (deviceRes, alpha, deviceX, deviceY, N); + + queue.finish(); + + // Copy data from device buffer to host + auto contents = deviceRes.contents(); + + res = contents[0 .. res.length]; + } + foreach(i; 0 .. N) enforce(res[i] == alpha * x[i] + y[i]); writeln(res[]);