diff --git a/run.py b/run.py index 4cc9a31..95ec09d 100644 --- a/run.py +++ b/run.py @@ -112,7 +112,7 @@ def get_op_name(executable_path): return "unknown" -def run_pytorch_benchmark(op_name, N, num_repeats=1, warmup=0): +def run_pytorch_benchmark(op_name, N, num_repeats=10, warmup=10): """ Runs the equivalent operation in PyTorch and measures execution time. Compatible with both CPU and GPU. @@ -281,6 +281,9 @@ def main(gpu=False): print("-" * len(header)) for N in BENCHMARK_SIZES: + if HAS_TORCH and torch.cuda.is_available(): + torch.cuda.empty_cache() + # 1. Run C++ Benchmark cpp_res = run_cpp_benchmark(EXECUTABLE_PATH, [N]) if cpp_res: