From 116042955b0e377a17ba6fda70e50dac8831c3f4 Mon Sep 17 00:00:00 2001 From: Francesco Derme Date: Sat, 20 Dec 2025 11:04:39 +0100 Subject: [PATCH 1/3] Added warmup and emptied cache --- run.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/run.py b/run.py index 02f3469..5ee2af6 100644 --- a/run.py +++ b/run.py @@ -63,7 +63,7 @@ def get_op_name(executable_path): if "topk" in executable_path: return "topk" return "unknown" -def run_pytorch_benchmark(op_name, N, num_repeats=1, warmup=0): +def run_pytorch_benchmark(op_name, N, num_repeats=10, warmup=10): """ Runs the equivalent operation in PyTorch and measures execution time. Compatible with both CPU and GPU. @@ -210,6 +210,9 @@ def main(): else: cpp_k_ms, cpp_t_ms = None, None + if HAS_TORCH and torch.cuda.is_available(): + torch.cuda.empty_cache() + # 2. Run PyTorch Benchmark torch_ms = run_pytorch_benchmark(op_name, N) if HAS_TORCH else None From 5bca3a20c3f2ac694ef4005dc9dac956799b1b95 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Dec 2025 10:09:32 +0000 Subject: [PATCH 2/3] style: pre-commit fixes --- run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/run.py b/run.py index 74b23fd..66841bd 100644 --- a/run.py +++ b/run.py @@ -111,6 +111,7 @@ def get_op_name(executable_path): return "topk" return "unknown" + def run_pytorch_benchmark(op_name, N, num_repeats=10, warmup=10): """ Runs the equivalent operation in PyTorch and measures execution time. From 0899155832e2768e8f7274bae7eba825fcca68c6 Mon Sep 17 00:00:00 2001 From: Francesco Derme Date: Sat, 20 Dec 2025 11:15:48 +0100 Subject: [PATCH 3/3] Fixed caching --- run.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/run.py b/run.py index 66841bd..95ec09d 100644 --- a/run.py +++ b/run.py @@ -281,6 +281,9 @@ def main(gpu=False): print("-" * len(header)) for N in BENCHMARK_SIZES: + if HAS_TORCH and torch.cuda.is_available(): + torch.cuda.empty_cache() + # 1. Run C++ Benchmark cpp_res = run_cpp_benchmark(EXECUTABLE_PATH, [N]) if cpp_res: