Skip to content

Commit 034cc87

Browse files
committed
upload file
1 parent 80e2354 commit 034cc87

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#include "mlir/Dialect/Func/IR/FuncOps.h"
2+
#include "mlir/IR/BuiltinOps.h"
3+
#include "mlir/IR/Operation.h"
4+
#include "mlir/IR/PatternMatch.h"
5+
#include "mlir/IR/TypeRange.h"
6+
#include "mlir/IR/Value.h"
7+
#include "llvm/ADT/DenseMap.h"
8+
9+
/// Identifies one entry point of the CUDA shim.
///
/// CudaShimRegistry resolves an enumerator to a symbol name and function type
/// and uses the enumerator's numeric value as its cache key. Enumerators that
/// are commented out are kept as placeholders for entry points that are
/// currently disabled.
enum class CudaShimFn {
  // ----- Module -----
  LoadModuleFromImage,
  LoadModuleFromFile,
  UnloadModule,

  // ----- Memory -----
  Malloc,
  Free,
  // Memset32,
  // Memset16,
  MemcpyH2D,
  MemcpyD2H,

  // ----- Stream -----
  StreamCreate,
  StreamDestroy,
  StreamSynchronize,
  // StreamWaitEvent,

  // ----- Event -----
  // EventCreate,
  // EventDestroy,
  // EventRecord,
  // EventSynchronize,

  // ----- Kernel Launch -----
  LaunchPacked,
  LaunchBlockPacked,

  // ----- Context -----
  CtxSynchronize
};
42+
43+
class CudaShimRegistry {
44+
public:
45+
explicit CudaShimRegistry(mlir::ModuleOp module) : module(module) {}
46+
47+
mlir::func::FuncOp getOrInsert(mlir::PatternRewriter &rewriter,
48+
mlir::Operation *anchor, CudaShimFn which) {
49+
auto key = static_cast<unsigned>(which);
50+
if (auto it = cache.find(key); it != cache.end())
51+
return it->second;
52+
53+
auto spec = specOf(which, rewriter);
54+
auto existing = module.lookupSymbol<mlir::func::FuncOp>(spec.name);
55+
if (existing) {
56+
cache[key] = existing;
57+
return existing;
58+
}
59+
60+
mlir::OpBuilder::InsertionGuard guard(rewriter);
61+
rewriter.setInsertionPointToStart(module.getBody());
62+
63+
auto f = mlir::func::FuncOp::create(rewriter, anchor->getLoc(), spec.name,
64+
spec.ty);
65+
f.setPrivate();
66+
cache[key] = f;
67+
return f;
68+
}
69+
70+
mlir::func::CallOp call(mlir::PatternRewriter &rewriter,
71+
mlir::Operation *anchor, CudaShimFn which,
72+
mlir::ValueRange operands = {}) {
73+
auto f = getOrInsert(rewriter, anchor, which);
74+
75+
return mlir::func::CallOp::create(rewriter, anchor->getLoc(), f.getName(),
76+
f.getFunctionType().getResults(),
77+
operands);
78+
}
79+
80+
private:
81+
struct Spec {
82+
mlir::StringRef name;
83+
mlir::FunctionType ty;
84+
};
85+
86+
static Spec specOf(CudaShimFn which, mlir::PatternRewriter &rewriter) {
87+
auto i64 = rewriter.getI64Type();
88+
auto i32 = rewriter.getI32Type();
89+
auto i1 = rewriter.getI1Type();
90+
91+
switch (which) {
92+
93+
// ===== Module =====
94+
case CudaShimFn::LoadModuleFromImage:
95+
return {"cuda_shim_load_module_from_image",
96+
rewriter.getFunctionType({i64, i64}, {i64})};
97+
98+
case CudaShimFn::LoadModuleFromFile:
99+
return {"cuda_shim_load_module_from_file",
100+
rewriter.getFunctionType({i64, i64}, {i64})};
101+
102+
case CudaShimFn::UnloadModule:
103+
return {"cuda_shim_unload_module", rewriter.getFunctionType({i64}, {})};
104+
105+
// ===== Memory =====
106+
case CudaShimFn::Malloc:
107+
return {"cuda_shim_malloc",
108+
rewriter.getFunctionType({i64, i64, i1}, {i64})};
109+
110+
case CudaShimFn::Free:
111+
return {"cuda_shim_free", rewriter.getFunctionType({i64, i64}, {})};
112+
113+
// case CudaShimFn::Memset32:
114+
// return {"cuda_shim_memset32",
115+
// rewriter.getFunctionType({i64, i32, i64, i64}, {})};
116+
117+
// case CudaShimFn::Memset16:
118+
// return {"cuda_shim_memset16",
119+
// rewriter.getFunctionType({i64, i32, i64, i64}, {})};
120+
121+
case CudaShimFn::MemcpyH2D:
122+
return {"cuda_shim_memcpy_h2d",
123+
rewriter.getFunctionType({i64, i64, i64}, {})};
124+
125+
case CudaShimFn::MemcpyD2H:
126+
return {"cuda_shim_memcpy_d2h",
127+
rewriter.getFunctionType({i64, i64, i64}, {})};
128+
129+
// ===== Stream =====
130+
case CudaShimFn::StreamCreate:
131+
return {"cuda_shim_stream_create", rewriter.getFunctionType({}, {i64})};
132+
133+
case CudaShimFn::StreamDestroy:
134+
return {"cuda_shim_stream_destroy", rewriter.getFunctionType({i64}, {})};
135+
136+
case CudaShimFn::StreamSynchronize:
137+
return {"cuda_shim_stream_synchronize",
138+
rewriter.getFunctionType({i64}, {})};
139+
140+
// case CudaShimFn::StreamWaitEvent:
141+
// return {"cuda_shim_stream_wait_event",
142+
// rewriter.getFunctionType({i64, i64}, {})};
143+
144+
// ===== Event =====
145+
// case CudaShimFn::EventCreate:
146+
// return {"cuda_shim_event_create", rewriter.getFunctionType({}, {i64})};
147+
148+
// case CudaShimFn::EventDestroy:
149+
// return {"cuda_shim_event_destroy", rewriter.getFunctionType({i64},
150+
// {})};
151+
152+
// case CudaShimFn::EventRecord:
153+
// return {"cuda_shim_event_record",
154+
// rewriter.getFunctionType({i64, i64}, {})};
155+
156+
// case CudaShimFn::EventSynchronize:
157+
// return {"cuda_shim_event_synchronize",
158+
// rewriter.getFunctionType({i64}, {})};
159+
160+
// ===== Launch =====
161+
case CudaShimFn::LaunchPacked:
162+
return {"cuda_shim_launch_packed",
163+
rewriter.getFunctionType(
164+
{
165+
i64, // module_handle
166+
i64, // kernel_name_ptr
167+
i32, i32, i32, // grid
168+
i32, i32, i32, // block
169+
i32, // sharedMemBytes
170+
i64, // stream
171+
i64, // arg_data_ptr
172+
i64, // arg_sizes_ptr
173+
i32 // num_args
174+
},
175+
{})};
176+
177+
// case CudaShimFn::LaunchBlockPacked:
178+
// return {"cuda_shim_launch_block_packed",
179+
// rewriter.getFunctionType(
180+
// {
181+
// i64, // module_handle
182+
// i64, // kernel_name_ptr
183+
// i32, i32, i32, // block
184+
// i64, // stream
185+
// i64, // arg_data_ptr
186+
// i64, // arg_sizes_ptr
187+
// i32 // num_args
188+
// },
189+
// {})};
190+
191+
// ===== Context =====
192+
case CudaShimFn::CtxSynchronize:
193+
return {"cuda_shim_ctx_synchronize", rewriter.getFunctionType({}, {})};
194+
}
195+
196+
llvm_unreachable("Unhandled CudaShimFn");
197+
}
198+
199+
mlir::ModuleOp module;
200+
llvm::DenseMap<unsigned, mlir::func::FuncOp> cache;
201+
};

0 commit comments

Comments
 (0)