@@ -58,7 +58,7 @@ def gpu_module():
 set_container_module(ctx.module)
 
 v_len = 16
-M, K, N = 16, 16, 16
+M, K, N = 512, 512, 512
 TILE_SIZE = BK = 16
 dtype = T.f16()
 np_dtype = np.float16
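Scaling the problem from a single 16x16x16 tile to 512x512x512 (with TILE_SIZE = BK = 16 unchanged) means each output tile now accumulates over 32 K-tiles, and the launch grid grows to 32x32 blocks. A rough sketch of the implied sizing; the actual launch configuration lives outside this hunk, so these names are illustrative:

    M = K = N = 512
    TILE_SIZE = BK = 16
    num_k_tiles = K // BK                    # 32 shared-memory tiles per output tile
    grid = (N // TILE_SIZE, M // TILE_SIZE)  # 32 x 32 blocks of 16x16 threads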
@@ -78,23 +78,27 @@ def kernel(
 
     row = block_idx.y * TILE_SIZE + thread_idx.y
     col = block_idx.x * TILE_SIZE + thread_idx.x
+    lane = thread_idx.x % v_len
     # gpu.printf("(%ld, %ld)\n", row, col)
     # vector.print_(source=row)
 
     sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16)
-    for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]):
-        Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
-        As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
 
+    Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + 0]
+    As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + 0]
+
+    for t, sum, _ in scf.range_(BK, N + BK, BK, iter_args=[sum]):
         gpu.barrier()
 
-        lane = thread_idx.x % v_len
         a_frag = As @ vector.load(v16) @ [lane, 0]
         b_frag = Bs @ vector.load(v16) @ [lane, 0]
 
-        # call the WMMA intrinsic
-        false = arith.constant(False, T.bool())
-        sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false])
+        sum = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, sum)
+
+        if arith.index_cast(t, T.i32()) < N:
+            Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
+            As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
+
         sum = yield sum
 
     C[row, col] = sum[2 * (row // 2)]
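The loop restructure is single-stage software pipelining: tile 0 is staged into shared memory before the loop, each iteration computes on the resident tile and only then prefetches the next one (guarded so the last iteration skips the out-of-range load), and the bounds shift from [0, N) to [BK, N + BK) to keep the trip count unchanged. A minimal NumPy sketch of the same pattern, using illustrative names rather than the kernel's API:

    import numpy as np

    def pipelined_matmul(A, B, BK=16):
        M, K = A.shape
        _, N = B.shape
        C = np.zeros((M, N), dtype=np.float32)
        # Stage tile 0 before the loop (the kernel's pre-loop As/Bs stores).
        As, Bs = A[:, 0:BK].copy(), B[0:BK, :].copy()
        for t in range(BK, K + BK, BK):  # bounds shifted by one tile
            # Compute on the resident tile ...
            C += As.astype(np.float32) @ Bs.astype(np.float32)
            # ... then prefetch the next tile; the final iteration skips it.
            if t < K:
                As, Bs = A[:, t:t + BK].copy(), B[t:t + BK, :].copy()
        return C

    A = np.random.rand(64, 64).astype(np.float16)
    B = np.random.rand(64, 64).astype(np.float16)
    assert np.allclose(pipelined_matmul(A, B), A.astype(np.float32) @ B.astype(np.float32))

In the kernel, the payoff is that the global loads for the next tile can be in flight while the WMMA for the current tile executes, instead of serializing load, barrier, and compute on every trip.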
@@ -142,18 +146,25 @@ def gpu_module():
 hip_module = hip_check(hip.hipModuleLoadData(hsaco))
 function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()))
 
-a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
-b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
-# a_h = np.ones((M, K)).astype(dtype=np_dtype)
-# b_h = np.ones((K, N)).astype(dtype=np_dtype)
-c_h = 0 * np.ones((M, N), dtype=np_dtype)
+# a_h = np.random.randint(1, 5, (M, K)).astype(dtype=np_dtype)
+# b_h = np.random.randint(1, 5, (K, N)).astype(dtype=np_dtype)
 
+# a_h = np.random.rand(M, K).astype(np_dtype)
+# b_h = np.random.rand(K, N).astype(np_dtype)
+
+a_h = 3 * np.ones((M, K)).astype(dtype=np_dtype)
+a_h[0 : M // 2, 0 : K // 2] = 0
+a_h[M // 2 : M, K // 2 : K] = 1
+b_h = 2 * np.ones((K, N)).astype(dtype=np_dtype)
+b_h[0 : K // 2, 0 : N // 2] = 2
+b_h[K // 2 : K, N // 2 : N] = 3
+
+c_h = 0 * np.ones((M, N), dtype=np.float32)
 for k in range(K):
-    a = a_h[:, k]
-    b = b_h[k, :]
+    a = a_h.astype(np.float32)[:, k]
+    b = b_h.astype(np.float32)[k, :]
     c_h += np.outer(a, b)
-
-assert np.allclose(a_h @ b_h, c_h)
+assert np.allclose(a_h.astype(np.float32) @ b_h.astype(np.float32), c_h)
 
 c_h = -3 * np.ones((M, N), dtype=np_dtype)
 a_num_bytes = a_h.size * a_h.itemsize
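The host-side reference now accumulates in float32 and compares against a float32 matmul: with K = 512, a float16 accumulator drifts on its own, so a float16 reference would disagree even with a correct kernel. A small self-contained illustration of the effect, on assumed random data rather than the block-constant matrices above:

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.random(512).astype(np.float16)
    b = rng.random(512).astype(np.float16)
    f16_dot = (a * b).sum(dtype=np.float16)                # accumulate in float16
    f32_dot = a.astype(np.float32) @ b.astype(np.float32)  # accumulate in float32
    print(abs(np.float32(f16_dot) - f32_dot))              # typically far above float16 resolution

The block-constant a_h and b_h look deliberate: each quadrant of the product is a single constant (1536, 2304, 2048, 2304), all exactly representable in float16, so a mis-addressed tile shows up as a solid wrong block rather than scattered noise.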
@@ -210,10 +221,12 @@ def gpu_module():
 
 if not np.allclose(c_h, correct):
     with np.printoptions(threshold=np.inf, linewidth=np.inf):
-        print("correct\n", correct)
-        print("c_h\n", c_h)
+        # print("correct\n", correct)
+        # print("c_h\n", c_h)
         print("off by atol", np.max(np.abs(correct - c_h)))
         print("off by rtol", np.max(np.abs(correct - c_h) / correct))
+        print("num incorrect", np.sum(np.abs(correct - c_h) != 0))
+        print("fraction incorrect", np.sum(np.abs(correct - c_h) != 0) / (M * N))
 
 
 hip_check(hip.hipFree(a_d))
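With the full-matrix dumps commented out, the added counters characterize the failure mode: a handful of nonzero differences suggests rounding, while a large fraction concentrated in one region suggests a mis-addressed tile. A possible follow-up, not part of the commit, that localizes the first bad tile by reusing correct, c_h, and TILE_SIZE from the script:

    bad = np.argwhere(np.abs(correct - c_h) != 0)
    if bad.size:
        print("first bad element", bad[0], "in tile", bad[0] // TILE_SIZE)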