Skip to content

Commit 99596b0

Browse files
committed
Verified the CUDA shim API; the proof-of-concept for the CUDA shim is ready
1 parent f2d06f6 commit 99596b0

12 files changed

Lines changed: 1266 additions & 47 deletions

File tree

mlir/cuda-tile/.devcontainer/devcontainer.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@
6868
"llvm-vs-code-extensions.vscode-clangd",
6969
"llvm-vs-code-extensions.lldb-dap",
7070
"mutantdino.resourcemonitor",
71-
"hoovercj.vscode-power-mode"
71+
"hoovercj.vscode-power-mode",
72+
"GitHub.copilot-chat",
73+
"Codereviewforgithubcopilot.github-copilot-code-review"
7274
]
7375
}
7476
}

mlir/cuda-tile/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*.ptx
2+
*.cubin
3+
*.fatbin

mlir/cuda-tile/Toy/cuda_wrapper/cuda_shim.cpp

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
#include <cstdlib>
1516
#include <cuda.h>
1617
#include <cuda_runtime_api.h>
1718
#include <stdio.h>
1819
#include <stdlib.h>
20+
#include <sys/types.h>
1921

2022
#include "cuda.h"
2123
#include "cuda_bf16.h"
2224
#include "cuda_fp16.h"
25+
#include <vector>
2326

2427
// We assume the program runs on the linux platform if not on Windows.
2528
// Copy from
@@ -246,6 +249,8 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
246249
defaultDevice = device;
247250
}
248251

252+
// ===----------------------------------------------------------------------===//
253+
249254
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCtxSynchronize() {
250255
ScopedContext scopedContext;
251256
CUDA_REPORT_IF_ERROR(cuCtxSynchronize());
@@ -263,4 +268,261 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemcpyDtoH(void *dst, void *src,
263268
cuMemcpyDtoH(dst, reinterpret_cast<CUdeviceptr>(src), sizeBytes));
264269
}
265270

271+
//===----------------------------------------------------------------------===//
272+
273+
static inline CUdeviceptr asDevPtr(uint64_t h) {
274+
return static_cast<CUdeviceptr>(h);
275+
}
276+
static inline uint64_t asHandle(CUdeviceptr p) {
277+
return static_cast<uint64_t>(p);
278+
}
279+
280+
static inline CUstream asStream(uint64_t h) {
281+
return reinterpret_cast<CUstream>(static_cast<uintptr_t>(h));
282+
}
283+
static inline uint64_t asStreamHandle(CUstream s) {
284+
return static_cast<uint64_t>(reinterpret_cast<uintptr_t>(s));
285+
}
286+
287+
static inline CUevent asEvent(uint64_t h) {
288+
return reinterpret_cast<CUevent>(static_cast<uintptr_t>(h));
289+
}
290+
static inline uint64_t asEventHandle(CUevent e) {
291+
return static_cast<uint64_t>(reinterpret_cast<uintptr_t>(e));
292+
}
293+
294+
// Reinterpret a 64-bit handle as a host pointer (mutable and const views).
// Rounding through uintptr_t keeps the conversion well-defined on both
// LP64 and LLP64 platforms.
static inline void *asHostPtr(uint64_t h) {
  const uintptr_t bits = static_cast<uintptr_t>(h);
  return reinterpret_cast<void *>(bits);
}
static inline const void *asHostCPtr(uint64_t h) {
  const uintptr_t bits = static_cast<uintptr_t>(h);
  return reinterpret_cast<const void *>(bits);
}
300+
301+
// Round `x` up to the next multiple of `a`.
// Precondition: `a` is a nonzero power of two (the mask trick below is only
// correct in that case; all callers in this file pass 8).
static inline uint64_t alignUp(uint64_t x, uint64_t a) {
  const uint64_t mask = a - 1;
  return (x + mask) & ~mask;
}
305+
306+
// Load module from PTX or CUBIN image in memory.
307+
// Driver API supports cuModuleLoadDataEx for both PTX and cubin (it
308+
// auto-detects).
309+
extern "C" uint64_t cuda_shim_load_module_from_image(uint64_t image_ptr,
310+
uint64_t image_nbytes) {
311+
312+
(void)image_nbytes;
313+
auto data = const_cast<void *>(asHostCPtr(image_ptr));
314+
CUmodule mod = mgpuModuleLoad(data, image_nbytes);
315+
return static_cast<uint64_t>(reinterpret_cast<uintptr_t>(mod));
316+
}
317+
318+
// JIT-load a module image at the given optimization level.
// image_nbytes is accepted for ABI symmetry with the non-JIT loader but is
// not consumed here (mgpuModuleLoadJIT takes only the image and opt level).
extern "C" uint64_t cuda_shim_load_module_jit_from_image(uint64_t image_ptr,
                                                         uint64_t image_nbytes,
                                                         int opt_level) {
  (void)image_nbytes; // intentionally unused on the JIT path
  void *image = const_cast<void *>(asHostCPtr(image_ptr));
  CUmodule jitModule = mgpuModuleLoadJIT(image, opt_level);
  return static_cast<uint64_t>(reinterpret_cast<uintptr_t>(jitModule));
}
327+
328+
// Load a module (PTX/cubin/fatbin) from a file on disk.
// The second argument (path length) is ignored: cuModuleLoad expects a
// NUL-terminated path string.
extern "C" uint64_t
cuda_shim_load_module_from_file(uint64_t file_path_ptr,
                                uint64_t /*file_path_nbytes*/) {
  const char *path = reinterpret_cast<const char *>(asHostCPtr(file_path_ptr));
  CUmodule module = nullptr;
  ScopedContext scopedContext; // ensure a current CUDA context for the driver call
  CUDA_REPORT_IF_ERROR(cuModuleLoad(&module, path));
  return static_cast<uint64_t>(reinterpret_cast<uintptr_t>(module));
}
339+
340+
// Release a module previously returned by one of the load entry points above.
extern "C" void cuda_shim_unload_module(uint64_t module_handle) {
  auto module =
      reinterpret_cast<CUmodule>(static_cast<uintptr_t>(module_handle));
  mgpuModuleUnload(module);
}
345+
346+
// Allocate device (or host-shared) memory via mgpuMemAlloc and return the
// pointer as an opaque uint64_t. A stream handle of 0 selects the null
// (legacy default) stream.
extern "C" uint64_t cuda_shim_malloc(uint64_t nbytes, uint64_t stream,
                                     bool is_host_shared) {
  CUstream cuStream = (stream == 0) ? nullptr : asStream(stream);
  void *devPtr = mgpuMemAlloc(nbytes, /*stream=*/cuStream,
                              /*isHostShared=*/is_host_shared);
  return static_cast<uint64_t>(reinterpret_cast<uintptr_t>(devPtr));
}
355+
356+
// Free memory allocated by cuda_shim_malloc. A stream handle of 0 selects
// the null (legacy default) stream.
extern "C" void cuda_shim_free(uint64_t dptr, uint64_t stream) {
  void *ptr = reinterpret_cast<void *>(static_cast<uintptr_t>(dptr));
  CUstream cuStream = (stream == 0) ? nullptr : asStream(stream);
  mgpuMemFree(ptr, /*stream=*/cuStream);
}
364+
365+
// Fill `count_dwords` 32-bit words at device address `dptr` with `value`,
// queued on the given stream (0 → null stream via asStream).
extern "C" void cuda_shim_memset32(uint64_t dptr, uint32_t value,
                                   uint64_t count_dwords, uint64_t stream) {
  auto *dst = reinterpret_cast<void *>(static_cast<uintptr_t>(dptr));
  mgpuMemset32(dst, value, count_dwords, asStream(stream));
}
371+
372+
// Fill a run of 16-bit elements at device address `dptr` with `value`,
// queued on the given stream (0 → null stream via asStream).
// NOTE(review): despite its name, `count_dwords` is forwarded to
// mgpuMemset16 and so presumably counts 16-bit elements, not 32-bit dwords —
// likely a copy/paste from cuda_shim_memset32; confirm with callers before
// renaming. Likewise only the low 16 bits of `value` can matter for a 16-bit
// fill — verify mgpuMemset16 truncates as expected.
extern "C" void cuda_shim_memset16(uint64_t dptr, uint32_t value,
                                   uint64_t count_dwords, uint64_t stream) {
  void *ptr = reinterpret_cast<void *>(static_cast<uintptr_t>(dptr));
  CUstream cu_stream = asStream(stream);
  mgpuMemset16(ptr, value, count_dwords, cu_stream);
}
378+
379+
// Create a new CUDA stream and return it as an opaque uint64_t handle.
extern "C" uint64_t cuda_shim_stream_create(void) {
  return asStreamHandle(mgpuStreamCreate());
}
383+
384+
// Destroy a stream created by cuda_shim_stream_create.
extern "C" void cuda_shim_stream_destroy(uint64_t stream) {
  mgpuStreamDestroy(asStream(stream));
}
388+
389+
// Block the host until all work queued on `stream` has completed.
extern "C" void cuda_shim_stream_synchronize(uint64_t stream) {
  mgpuStreamSynchronize(asStream(stream));
}
393+
394+
// Create a new CUDA event and return it as an opaque uint64_t handle.
extern "C" uint64_t cuda_shim_event_create(void) {
  return asEventHandle(mgpuEventCreate());
}
398+
399+
// Destroy an event created by cuda_shim_event_create.
extern "C" void cuda_shim_event_destroy(uint64_t ev) {
  mgpuEventDestroy(asEvent(ev));
}
403+
404+
// Record event `ev` on `stream` (a handle of 0 maps to the null stream
// through asStream).
extern "C" void cuda_shim_event_record(uint64_t ev, uint64_t stream) {
  mgpuEventRecord(asEvent(ev), asStream(stream));
}
409+
410+
// Block the host until event `ev` has been reached on its stream.
extern "C" void cuda_shim_event_synchronize(uint64_t ev) {
  mgpuEventSynchronize(asEvent(ev));
}
414+
415+
// Make all future work on `stream` wait until event `ev` has completed.
extern "C" void cuda_shim_stream_wait_event(uint64_t stream, uint64_t ev) {
  mgpuStreamWaitEvent(asStream(stream), asEvent(ev));
}
420+
421+
// ----------------------------- Memcpy (raw ABI) --------------------------
422+
// Host and device addresses are both passed as uint64_t so the MLIR side
// never has to materialize pointer types (design option "2A" of the shim ABI).
423+
424+
// Synchronous host→device copy of `nbytes` bytes. dst_dptr is a device
// address and src_hptr a host address, both carried as raw uint64_t.
extern "C" void cuda_shim_memcpy_h2d(uint64_t dst_dptr, uint64_t src_hptr,
                                     uint64_t nbytes) {
  ScopedContext scopedContext; // ensure a current CUDA context
  mgpuMemcpyHtoD(asHostPtr(dst_dptr), asHostPtr(src_hptr),
                 static_cast<size_t>(nbytes));
}
431+
432+
// Synchronous device→host copy of `nbytes` bytes. dst_hptr is a host
// address and src_dptr a device address, both carried as raw uint64_t.
extern "C" void cuda_shim_memcpy_d2h(uint64_t dst_hptr, uint64_t src_dptr,
                                     uint64_t nbytes) {
  ScopedContext scopedContext; // ensure a current CUDA context
  mgpuMemcpyDtoH(asHostPtr(dst_hptr), asHostPtr(src_dptr),
                 static_cast<size_t>(nbytes));
}
439+
440+
// ----------------------------- Kernel launch -----------------------------
441+
// The hardest part is kernelParams (void**).
442+
// We avoid building it in MLIR. Instead MLIR passes:
443+
// - arg_data_ptr: host pointer to a packed buffer containing raw argument bytes
444+
// - arg_sizes_ptr: host pointer to uint64_t[num_args], each is the byte-size of
445+
// that argument The shim constructs kernelParams[i] = &arg_data[offset_i] with
446+
// 8-byte alignment. This matches typical ABI expectations for scalar/pointer
447+
// args. If you have special alignment requirements, extend this (e.g., per-arg
448+
// alignment array).
449+
450+
extern "C" void cuda_shim_launch_packed(
451+
uint64_t module_handle, uint64_t kernel_name_ptr, uint32_t gridX,
452+
uint32_t gridY, uint32_t gridZ, uint32_t blockX, uint32_t blockY,
453+
uint32_t blockZ, uint32_t sharedMemBytes, uint64_t stream,
454+
uint64_t arg_data_ptr, uint64_t arg_sizes_ptr, uint32_t num_args) {
455+
456+
auto mh = reinterpret_cast<CUmodule>(static_cast<uintptr_t>(module_handle));
457+
if (!mh) {
458+
fprintf(stderr, "[cuda_shim] launch_packed: invalid module handle\n");
459+
abort();
460+
}
461+
462+
const char *kname =
463+
reinterpret_cast<const char *>(asHostCPtr(kernel_name_ptr));
464+
if (!kname) {
465+
fprintf(stderr, "[cuda_shim] launch_packed: null kernel name\n");
466+
abort();
467+
}
468+
469+
CUfunction fn = mgpuModuleGetFunction(mh, kname);
470+
471+
auto *argData = reinterpret_cast<uint8_t *>(asHostPtr(arg_data_ptr));
472+
auto *argSizes =
473+
reinterpret_cast<const uint64_t *>(asHostCPtr(arg_sizes_ptr));
474+
475+
if (num_args > 0 && (!argData || !argSizes)) {
476+
fprintf(stderr, "[cuda_shim] launch_packed: argData/argSizes null\n");
477+
abort();
478+
}
479+
480+
// Build kernelParams array on heap (safe for large num_args).
481+
std::vector<void *> params;
482+
params.resize(num_args);
483+
484+
uint64_t off = 0;
485+
for (uint32_t i = 0; i < num_args; ++i) {
486+
// 8-byte align each argument start (common safe default).
487+
off = alignUp(off, 8);
488+
params[i] = argData + off;
489+
off += argSizes[i];
490+
}
491+
492+
auto cu_stream = asStream(stream);
493+
494+
if (stream == 0) {
495+
cu_stream = nullptr;
496+
}
497+
498+
mgpuLaunchKernel(fn, static_cast<intptr_t>(gridX),
499+
static_cast<intptr_t>(gridY), static_cast<intptr_t>(gridZ),
500+
static_cast<intptr_t>(blockX), static_cast<intptr_t>(blockY),
501+
static_cast<intptr_t>(blockZ),
502+
static_cast<int32_t>(sharedMemBytes), cu_stream,
503+
params.data(), nullptr, static_cast<size_t>(num_args));
504+
}
505+
506+
// Convenience: 1D launch, shared=0, stream optional
507+
extern "C" void
508+
cuda_shim_launch_block_packed(uint64_t module_handle, uint64_t kernel_name_ptr,
509+
uint32_t blockX, uint32_t blockY, uint32_t blockZ,
510+
uint64_t stream, uint64_t arg_data_ptr,
511+
uint64_t arg_sizes_ptr, uint32_t num_args) {
512+
cuda_shim_launch_packed(module_handle, kernel_name_ptr, 1, 1, 1, blockX,
513+
blockY, blockZ, 0, stream, arg_data_ptr,
514+
arg_sizes_ptr, num_args);
515+
}
516+
517+
// Optional: global sync (avoid in async pipeline; prefer event/stream sync)
518+
extern "C" void cuda_shim_ctx_synchronize(void) { mgpuCtxSynchronize(); }
519+
520+
// Only for debugging. NOTE(review): as written, the commented-out dump below
// dereferences a device pointer on the host, which is invalid unless the
// allocation is host-accessible (unified/host-shared memory); it also
// compares signed `n` against unsigned `i`.
// extern "C" void cuda_debug_dump_float(uint64_t dptr, int n) {
522+
// auto *p = reinterpret_cast<const float *>(static_cast<uintptr_t>(dptr));
523+
// for (uint32_t i = 0; i < n; ++i) {
524+
// fprintf(stderr, "i=%u v=%f\n", i, p[i]);
525+
// }
526+
// }
527+
266528
#endif

mlir/cuda-tile/Toy/toyc.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,10 @@ static int loadAndProcessMLIRGPU(mlir::MLIRContext &context,
342342
if (isLoweringToAffine) {
343343
pm.addPass(mlir::toy::createEmbedCudaTileBinaryPass(
344344
"/usr/local/cuda/bin/tileiras", "sm_120"));
345-
// // Partially lower the toy dialect.
346-
// optPM.addPass(mlir::toy::createLowerToAffinePass());
345+
346+
// mlir::OpPassManager &gpuOptPM = pm.nest<mlir::toy::FuncOp>();
347+
// // Partially lower the toy dialect.
348+
// pm.addPass(mlir::toy::createLowerToAffinePass());
347349

348350
// // Add a few cleanups post lowering.
349351
// mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();

0 commit comments

Comments
 (0)