Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions example/15_grouped_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)

add_example_executable(example_grouped_gemm_wmma_fixed_nk_fp16 grouped_gemm_wmma_fixed_nk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_fp16)


list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
Expand Down
382 changes: 382 additions & 0 deletions example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
}

std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
return pass;
}

Expand Down Expand Up @@ -329,9 +330,9 @@ int main(int argc, char* argv[])

for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(128 + rand() % 128);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(1024);
problem_size.Ms.push_back(256);
problem_size.Ns.push_back(256);
problem_size.Ks.push_back(256);

problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
Expand Down
2 changes: 2 additions & 0 deletions example/15_grouped_gemm/run_grouped_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
ComputeDataType>(c_device_tensors[i], c_host_tensors[i]);
#endif
}

std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}

if(config.time_kernel)
Expand Down
Loading
Loading