-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmatmul.comp
More file actions
48 lines (39 loc) · 1.32 KB
/
matmul.comp
File metadata and controls
48 lines (39 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#version 450
// Interface declarations for an integer matrix-multiply compute shader:
// C (M x N) = A (M x K) * B (K x N). main() computes one element of C per
// invocation, using gl_GlobalInvocationID.x as the column and .y as the row.

// Workgroup size (specialization constants).
// Bx/By/Bz use constant IDs 0..2, the same IDs that the local_size_*_id
// layout below refers to. They are not referenced anywhere else in the
// visible shader code.
// NOTE(review): no literal local_size_x/y/z defaults are given in the layout
// below, so the 16/16/1 defaults on Bx/By/Bz presumably do not set the
// workgroup size unless the host specializes IDs 0..2 - confirm against the
// GL_KHR_vulkan_glsl rules and the host-side pipeline creation code.
layout(constant_id = 0) const uint Bx = 16;
layout(constant_id = 1) const uint By = 16;
layout(constant_id = 2) const uint Bz = 1;
layout(local_size_x_id = 0,
local_size_y_id = 1,
local_size_z_id = 2) in;
// Buffers (all in descriptor set 0).
// The matrices are flat int arrays; main() indexes them row-major
// (element [r, c] of a matrix with W columns lives at r * W + c).
layout(set = 0, binding = 0) buffer InA { int a[]; }; // input A, M x K
layout(set = 0, binding = 1) buffer InB { int b[]; }; // input B, K x N
layout(set = 0, binding = 2) buffer OutC { int c[]; }; // output C, M x N
// Matrix dimensions, each passed in its own buffer; main() reads only
// element 0 of each.
layout(set = 0, binding = 3) buffer M { int m[]; }; // m[0] = M (rows of A and C)
layout(set = 0, binding = 4) buffer K { int k[]; }; // k[0] = K (cols of A, rows of B)
layout(set = 0, binding = 5) buffer N { int n[]; }; // n[0] = N (cols of B and C)
// Entry point: computes one element of C = A * B per invocation.
//
// Inputs (storage buffers declared at file scope):
//   a    - matrix A, M x K, row-major
//   b    - matrix B, K x N, row-major
//   m[0] - M, k[0] - K, n[0] - N (signed ints)
// Output:
//   c    - matrix C, M x N, row-major; c[row * N + col] receives the dot
//          product of row `row` of A and column `col` of B.
//
// Invocations whose (row, col) fall outside the M x N output do nothing.
// If K <= 0 the dot product is empty and 0 is written.
void main()
{
    // Dimensions arrive as signed ints in element 0 of their buffers.
    int dimM = m[0];
    int dimK = k[0];
    int dimN = n[0];

    // Validate while still signed: casting a negative dimension to uint
    // would wrap to a huge value, let every invocation through the bounds
    // check below, and lead to out-of-range buffer accesses.
    if (dimM <= 0 || dimN <= 0) {
        return;
    }

    // Convert once so all index arithmetic is uniformly unsigned.
    uint rows  = uint(dimM);
    uint cols  = uint(dimN);
    uint inner = (dimK > 0) ? uint(dimK) : 0u;

    // Global coordinates in C: row = y, col = x.
    uint col = gl_GlobalInvocationID.x; // N dimension
    uint row = gl_GlobalInvocationID.y; // M dimension

    // Bounds check: the dispatch may cover more invocations than C has
    // elements when the matrix size is not a multiple of the workgroup size.
    if (row >= rows || col >= cols) {
        return;
    }

    // Row-major addressing:
    //   A[row, kk] = a[row * K + kk]
    //   B[kk, col] = b[kk * N + col]
    // The start of A's row is loop-invariant, so compute it once.
    uint aRowBase = row * inner;

    int sum = 0;
    for (uint kk = 0u; kk < inner; ++kk) {
        sum += a[aRowBase + kk] * b[kk * cols + col];
    }

    // C[row, col] = sum
    c[row * cols + col] = sum;
}