1212//
1313// ===----------------------------------------------------------------------===//
1414
15+ #include < cstdlib>
1516#include < cuda.h>
1617#include < cuda_runtime_api.h>
1718#include < stdio.h>
1819#include < stdlib.h>
20+ #include < sys/types.h>
1921
2022#include " cuda.h"
2123#include " cuda_bf16.h"
2224#include " cuda_fp16.h"
25+ #include < vector>
2326
2427// We assume the program runs on the linux platform if not on Windows.
2528// Copy from
@@ -246,6 +249,8 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
246249 defaultDevice = device;
247250}
248251
252+ // ===----------------------------------------------------------------------===//
253+
249254extern " C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCtxSynchronize () {
250255 ScopedContext scopedContext;
251256 CUDA_REPORT_IF_ERROR (cuCtxSynchronize ());
@@ -263,4 +268,261 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemcpyDtoH(void *dst, void *src,
263268 cuMemcpyDtoH (dst, reinterpret_cast <CUdeviceptr>(src), sizeBytes));
264269}
265270
271+ // ===----------------------------------------------------------------------===//
272+
273+ static inline CUdeviceptr asDevPtr (uint64_t h) {
274+ return static_cast <CUdeviceptr>(h);
275+ }
276+ static inline uint64_t asHandle (CUdeviceptr p) {
277+ return static_cast <uint64_t >(p);
278+ }
279+
280+ static inline CUstream asStream (uint64_t h) {
281+ return reinterpret_cast <CUstream>(static_cast <uintptr_t >(h));
282+ }
283+ static inline uint64_t asStreamHandle (CUstream s) {
284+ return static_cast <uint64_t >(reinterpret_cast <uintptr_t >(s));
285+ }
286+
287+ static inline CUevent asEvent (uint64_t h) {
288+ return reinterpret_cast <CUevent>(static_cast <uintptr_t >(h));
289+ }
290+ static inline uint64_t asEventHandle (CUevent e) {
291+ return static_cast <uint64_t >(reinterpret_cast <uintptr_t >(e));
292+ }
293+
294+ static inline void *asHostPtr (uint64_t h) {
295+ return reinterpret_cast <void *>(static_cast <uintptr_t >(h));
296+ }
297+ static inline const void *asHostCPtr (uint64_t h) {
298+ return reinterpret_cast <const void *>(static_cast <uintptr_t >(h));
299+ }
300+
301+ // Align up helper
302+ static inline uint64_t alignUp (uint64_t x, uint64_t a) {
303+ return (x + (a - 1 )) & ~(a - 1 );
304+ }
305+
306+ // Load module from PTX or CUBIN image in memory.
307+ // Driver API supports cuModuleLoadDataEx for both PTX and cubin (it
308+ // auto-detects).
309+ extern " C" uint64_t cuda_shim_load_module_from_image (uint64_t image_ptr,
310+ uint64_t image_nbytes) {
311+
312+ (void )image_nbytes;
313+ auto data = const_cast <void *>(asHostCPtr (image_ptr));
314+ CUmodule mod = mgpuModuleLoad (data, image_nbytes);
315+ return static_cast <uint64_t >(reinterpret_cast <uintptr_t >(mod));
316+ }
317+
318+ extern " C" uint64_t cuda_shim_load_module_jit_from_image (uint64_t image_ptr,
319+ uint64_t image_nbytes,
320+ int opt_level) {
321+
322+ (void )image_nbytes;
323+ auto data = const_cast <void *>(asHostCPtr (image_ptr));
324+ CUmodule mod = mgpuModuleLoadJIT (data, opt_level);
325+ return static_cast <uint64_t >(reinterpret_cast <uintptr_t >(mod));
326+ }
327+
328+ extern " C" uint64_t
329+ cuda_shim_load_module_from_file (uint64_t file_path_ptr,
330+ uint64_t /* file_path_nbytes*/ ) {
331+ auto file_path_cstr =
332+ reinterpret_cast <const char *>(asHostCPtr (file_path_ptr));
333+ // fprintf(stdout, "%s", file_path_cstr);
334+ CUmodule module = nullptr ;
335+ ScopedContext scopedContext;
336+ CUDA_REPORT_IF_ERROR (cuModuleLoad (&module , file_path_cstr));
337+ return static_cast <uint64_t >(reinterpret_cast <uintptr_t >(module ));
338+ }
339+
340+ extern " C" void cuda_shim_unload_module (uint64_t module_handle) {
341+ CUmodule module =
342+ reinterpret_cast <CUmodule>(static_cast <uintptr_t >(module_handle));
343+ mgpuModuleUnload (module );
344+ }
345+
346+ extern " C" uint64_t cuda_shim_malloc (uint64_t nbytes, uint64_t stream,
347+ bool is_host_shared) {
348+ CUstream cu_stream = asStream (stream);
349+ if (stream == 0 )
350+ cu_stream = nullptr ;
351+ void *ptr = mgpuMemAlloc (nbytes, /* stream=*/ cu_stream,
352+ /* isHostShared=*/ is_host_shared);
353+ return static_cast <uint64_t >(reinterpret_cast <uintptr_t >(ptr));
354+ }
355+
356+ extern " C" void cuda_shim_free (uint64_t dptr, uint64_t stream) {
357+ CUstream cu_stream = asStream (stream);
358+ void *ptr = reinterpret_cast <void *>(static_cast <uintptr_t >(dptr));
359+ if (stream == 0 ) {
360+ cu_stream = nullptr ;
361+ }
362+ mgpuMemFree (ptr, /* stream=*/ cu_stream);
363+ }
364+
365+ extern " C" void cuda_shim_memset32 (uint64_t dptr, uint32_t value,
366+ uint64_t count_dwords, uint64_t stream) {
367+ void *ptr = reinterpret_cast <void *>(static_cast <uintptr_t >(dptr));
368+ CUstream cu_stream = asStream (stream);
369+ mgpuMemset32 (ptr, value, count_dwords, cu_stream);
370+ }
371+
372+ extern " C" void cuda_shim_memset16 (uint64_t dptr, uint32_t value,
373+ uint64_t count_dwords, uint64_t stream) {
374+ void *ptr = reinterpret_cast <void *>(static_cast <uintptr_t >(dptr));
375+ CUstream cu_stream = asStream (stream);
376+ mgpuMemset16 (ptr, value, count_dwords, cu_stream);
377+ }
378+
379+ extern " C" uint64_t cuda_shim_stream_create (void ) {
380+ CUstream stream = mgpuStreamCreate ();
381+ return asStreamHandle (stream);
382+ }
383+
384+ extern " C" void cuda_shim_stream_destroy (uint64_t stream) {
385+ CUstream cu_stream = asStream (stream);
386+ mgpuStreamDestroy (cu_stream);
387+ }
388+
389+ extern " C" void cuda_shim_stream_synchronize (uint64_t stream) {
390+ CUstream cu_stream = asStream (stream);
391+ mgpuStreamSynchronize (cu_stream);
392+ }
393+
394+ extern " C" uint64_t cuda_shim_event_create (void ) {
395+ CUevent event = mgpuEventCreate ();
396+ return asEventHandle (event);
397+ }
398+
399+ extern " C" void cuda_shim_event_destroy (uint64_t ev) {
400+ CUevent event = asEvent (ev);
401+ mgpuEventDestroy (event);
402+ }
403+
404+ extern " C" void cuda_shim_event_record (uint64_t ev, uint64_t stream) {
405+ CUevent event = asEvent (ev);
406+ CUstream cu_stream = asStream (stream);
407+ mgpuEventRecord (event, cu_stream);
408+ }
409+
410+ extern " C" void cuda_shim_event_synchronize (uint64_t ev) {
411+ CUevent event = asEvent (ev);
412+ mgpuEventSynchronize (event);
413+ }
414+
415+ extern " C" void cuda_shim_stream_wait_event (uint64_t stream, uint64_t ev) {
416+ CUstream cu_stream = asStream (stream);
417+ CUevent event = asEvent (ev);
418+ mgpuStreamWaitEvent (cu_stream, event);
419+ }
420+
421+ // ----------------------------- Memcpy (raw ABI) --------------------------
422+ // Host pointers are passed as uint64_t. This is the key of 2A.
423+
424+ extern " C" void cuda_shim_memcpy_h2d (uint64_t dst_dptr, uint64_t src_hptr,
425+ uint64_t nbytes) {
426+ ScopedContext scopedContext;
427+ auto dst = asHostPtr (dst_dptr);
428+ auto src = asHostPtr (src_hptr);
429+ mgpuMemcpyHtoD (dst, src, static_cast <size_t >(nbytes));
430+ }
431+
432+ extern " C" void cuda_shim_memcpy_d2h (uint64_t dst_hptr, uint64_t src_dptr,
433+ uint64_t nbytes) {
434+ ScopedContext scopedContext;
435+ auto dst = asHostPtr (dst_hptr);
436+ auto src = asHostPtr (src_dptr);
437+ mgpuMemcpyDtoH (dst, src, static_cast <size_t >(nbytes));
438+ }
439+
440+ // ----------------------------- Kernel launch -----------------------------
441+ // The hardest part is kernelParams (void**).
442+ // We avoid building it in MLIR. Instead MLIR passes:
443+ // - arg_data_ptr: host pointer to a packed buffer containing raw argument bytes
444+ // - arg_sizes_ptr: host pointer to uint64_t[num_args], each is the byte-size of
445+ // that argument The shim constructs kernelParams[i] = &arg_data[offset_i] with
446+ // 8-byte alignment. This matches typical ABI expectations for scalar/pointer
447+ // args. If you have special alignment requirements, extend this (e.g., per-arg
448+ // alignment array).
449+
450+ extern " C" void cuda_shim_launch_packed (
451+ uint64_t module_handle, uint64_t kernel_name_ptr, uint32_t gridX,
452+ uint32_t gridY, uint32_t gridZ, uint32_t blockX, uint32_t blockY,
453+ uint32_t blockZ, uint32_t sharedMemBytes, uint64_t stream,
454+ uint64_t arg_data_ptr, uint64_t arg_sizes_ptr, uint32_t num_args) {
455+
456+ auto mh = reinterpret_cast <CUmodule>(static_cast <uintptr_t >(module_handle));
457+ if (!mh) {
458+ fprintf (stderr, " [cuda_shim] launch_packed: invalid module handle\n " );
459+ abort ();
460+ }
461+
462+ const char *kname =
463+ reinterpret_cast <const char *>(asHostCPtr (kernel_name_ptr));
464+ if (!kname) {
465+ fprintf (stderr, " [cuda_shim] launch_packed: null kernel name\n " );
466+ abort ();
467+ }
468+
469+ CUfunction fn = mgpuModuleGetFunction (mh, kname);
470+
471+ auto *argData = reinterpret_cast <uint8_t *>(asHostPtr (arg_data_ptr));
472+ auto *argSizes =
473+ reinterpret_cast <const uint64_t *>(asHostCPtr (arg_sizes_ptr));
474+
475+ if (num_args > 0 && (!argData || !argSizes)) {
476+ fprintf (stderr, " [cuda_shim] launch_packed: argData/argSizes null\n " );
477+ abort ();
478+ }
479+
480+ // Build kernelParams array on heap (safe for large num_args).
481+ std::vector<void *> params;
482+ params.resize (num_args);
483+
484+ uint64_t off = 0 ;
485+ for (uint32_t i = 0 ; i < num_args; ++i) {
486+ // 8-byte align each argument start (common safe default).
487+ off = alignUp (off, 8 );
488+ params[i] = argData + off;
489+ off += argSizes[i];
490+ }
491+
492+ auto cu_stream = asStream (stream);
493+
494+ if (stream == 0 ) {
495+ cu_stream = nullptr ;
496+ }
497+
498+ mgpuLaunchKernel (fn, static_cast <intptr_t >(gridX),
499+ static_cast <intptr_t >(gridY), static_cast <intptr_t >(gridZ),
500+ static_cast <intptr_t >(blockX), static_cast <intptr_t >(blockY),
501+ static_cast <intptr_t >(blockZ),
502+ static_cast <int32_t >(sharedMemBytes), cu_stream,
503+ params.data (), nullptr , static_cast <size_t >(num_args));
504+ }
505+
506+ // Convenience: 1D launch, shared=0, stream optional
507+ extern " C" void
508+ cuda_shim_launch_block_packed (uint64_t module_handle, uint64_t kernel_name_ptr,
509+ uint32_t blockX, uint32_t blockY, uint32_t blockZ,
510+ uint64_t stream, uint64_t arg_data_ptr,
511+ uint64_t arg_sizes_ptr, uint32_t num_args) {
512+ cuda_shim_launch_packed (module_handle, kernel_name_ptr, 1 , 1 , 1 , blockX,
513+ blockY, blockZ, 0 , stream, arg_data_ptr,
514+ arg_sizes_ptr, num_args);
515+ }
516+
517+ // Optional: global sync (avoid in async pipeline; prefer event/stream sync)
518+ extern " C" void cuda_shim_ctx_synchronize (void ) { mgpuCtxSynchronize (); }
519+
520+ // only for debugging
521+ // extern "C" void cuda_debug_dump_float(uint64_t dptr, int n) {
522+ // auto *p = reinterpret_cast<const float *>(static_cast<uintptr_t>(dptr));
523+ // for (uint32_t i = 0; i < n; ++i) {
524+ // fprintf(stderr, "i=%u v=%f\n", i, p[i]);
525+ // }
526+ // }
527+
266528#endif
0 commit comments