diff --git a/ynnpack/kernels/dot/BUILD b/ynnpack/kernels/dot/BUILD index f019509a2c1..1fa18db2309 100644 --- a/ynnpack/kernels/dot/BUILD +++ b/ynnpack/kernels/dot/BUILD @@ -433,6 +433,7 @@ cc_test( malloc = ynn_binary_malloc(), deps = [ ":schedule", + "//ynnpack/base", ] + ynn_test_deps(), ) diff --git a/ynnpack/kernels/dot/schedule.cc b/ynnpack/kernels/dot/schedule.cc index 134d9d9e7f7..3459b44a27a 100644 --- a/ynnpack/kernels/dot/schedule.cc +++ b/ynnpack/kernels/dot/schedule.cc @@ -14,12 +14,18 @@ namespace ynn { span schedule_dot(span cache_sizes, size_t m, size_t n, - size_t k1, size_t k2, size_t k3, size_t block_m, + span ks, size_t block_m, size_t block_n, size_t block_k, size_t a_elem_size, size_t b_elem_size, dot_loop* storage) { dot_loop* begin = storage; dot_loop* loop = begin; + size_t k1 = ks[0]; + size_t k2 = 1; + for (size_t i = 1; i < ks.size(); ++i) { + k2 *= ks[i]; + } + // When we make a loop in a dimension, the extent of that dimension becomes // the step size of that loop. auto make_m_loop = [&](size_t blocks) { @@ -41,7 +47,7 @@ span schedule_dot(span cache_sizes, size_t m, size_t n, for (size_t cache_size : cache_sizes) { // TODO(b/447988052): We can be way smarter about this than we are now. make_k_loop( - floor_div(cache_size, k2 * k3 * block_n * b_elem_size * block_k)); + floor_div(cache_size, k2 * block_n * b_elem_size * block_k)); if (n * b_elem_size <= m * a_elem_size) { // Tiles of B are smaller than tiles of A, we should assume B fits in // cache. diff --git a/ynnpack/kernels/dot/schedule.h b/ynnpack/kernels/dot/schedule.h index 1b8deae73f6..38a61be3689 100644 --- a/ynnpack/kernels/dot/schedule.h +++ b/ynnpack/kernels/dot/schedule.h @@ -30,18 +30,21 @@ struct dot_loop { // maximized for each cache in `cache_sizes`. `storage` must have room for at // most 3 loops per cache size. span schedule_dot(span cache_sizes, size_t m, size_t n, - size_t k1, size_t k2, size_t k3, size_t block_m, + span ks, size_t block_m, size_t block_n, size_t block_k, size_t a_elem_size, size_t b_elem_size, dot_loop* storage); // Block a dot's m dimension, calling f at each block. template -void block_dot_m(ptrdiff_t m, size_t n, size_t k, ptrdiff_t block_m, - size_t a_stride_m, const void* a, const void* b, +void block_dot_m(ptrdiff_t m, size_t n, span ks, + ptrdiff_t block_m, size_t a_stride_m, + span a_k_strides, const void* a, + span b_k_strides, const void* b, size_t init_c_stride_m, const void* init_c, size_t c_stride_m, size_t c_stride_n, void* c, DotFn f) { do { - f(std::min(m, block_m), n, k, a, b, init_c_stride_m, init_c, c); + f(std::min(m, block_m), n, ks, a, a_stride_m, a_k_strides, b, b_k_strides, + init_c_stride_m, init_c, c); m -= block_m; if (init_c) init_c = offset_bytes(init_c, init_c_stride_m * block_m); @@ -52,12 +55,15 @@ void block_dot_m(ptrdiff_t m, size_t n, size_t k, ptrdiff_t block_m, // Block a dot's n dimension, calling f at each block. template -void block_dot_n(size_t m, ptrdiff_t n, size_t k, ptrdiff_t block_n, - const void* a, size_t b_stride_n, const void* b, - size_t init_c_stride_m, const void* init_c, size_t c_stride_m, - size_t c_stride_n, void* c, DotFn f) { +void block_dot_n(size_t m, ptrdiff_t n, span ks, + ptrdiff_t block_n, size_t a_stride_m, + span a_k_strides, const void* a, + size_t b_stride_n, span b_k_strides, + const void* b, size_t init_c_stride_m, const void* init_c, + size_t c_stride_m, size_t c_stride_n, void* c, DotFn f) { do { - f(m, std::min(n, block_n), k, a, b, init_c_stride_m, init_c, c); + f(m, std::min(n, block_n), ks, a, a_stride_m, a_k_strides, b, b_k_strides, + init_c_stride_m, init_c, c); n -= block_n; if (init_c) init_c = offset_bytes(init_c, c_stride_n * block_n); @@ -68,12 +74,18 @@ void block_dot_n(size_t m, ptrdiff_t n, size_t k, ptrdiff_t block_n, // Block a dot's k dimension, calling f at each block. template -void block_dot_k(size_t m, size_t n, ptrdiff_t k, ptrdiff_t block_k, - size_t a_stride_k, const void* a, size_t b_stride_k, - const void* b, size_t init_c_stride_m, const void* init_c, - size_t c_stride_m, size_t c_stride_n, void* c, DotFn f) { +void block_dot_k(size_t m, size_t n, span ks, ptrdiff_t block_k, + size_t a_stride_m, span a_k_strides, + const void* a, span b_k_strides, const void* b, + size_t init_c_stride_m, const void* init_c, size_t c_stride_m, + size_t c_stride_n, void* c, DotFn f) { + ptrdiff_t k = ks[0]; + size_t k_block[3]; + std::copy(ks.begin(), ks.end(), k_block); do { - f(m, n, std::min(k, block_k), a, b, init_c_stride_m, init_c, c); + k_block[0] = std::min(k, block_k); + f(m, n, {k_block, ks.size()}, a, a_stride_m, a_k_strides, b, b_k_strides, + init_c_stride_m, init_c, c); // Splitting k requires care for the initializer. The dot kernels read and // write from a separate buffer, so for each tile that we process in k, the @@ -90,18 +102,18 @@ void block_dot_k(size_t m, size_t n, ptrdiff_t k, ptrdiff_t block_k, init_c_stride_m = c_stride_m; k -= block_k; - b = offset_bytes(b, b_stride_k * block_k); - a = offset_bytes(a, a_stride_k * block_k); + b = offset_bytes(b, b_k_strides[0] * block_k); + a = offset_bytes(a, a_k_strides[0] * block_k); } while (k > 0); } template -void run_dot(span loops, size_t m, size_t n, size_t k, size_t block_m, - size_t block_n, size_t block_k, size_t a_stride_m, - size_t a_stride_k, const void* a, size_t b_stride_k, - size_t b_stride_n, const void* b, size_t init_c_stride_m, - const void* init_c, size_t c_stride_m, size_t c_stride_n, void* c, - DotFn f) { +void run_dot(span loops, size_t m, size_t n, span ks, + size_t block_m, size_t block_n, size_t block_k, size_t a_stride_m, + span a_k_strides, const void* a, + span b_k_strides, size_t b_stride_n, const void* b, + size_t init_c_stride_m, const void* init_c, size_t c_stride_m, + size_t c_stride_n, void* c, DotFn f) { assert(!loops.empty()); const dot_loop loop = loops.front(); loops = loops.subspan(1); @@ -110,40 +122,44 @@ void run_dot(span loops, size_t m, size_t n, size_t k, size_t block_m, // There are no more loops after this one. switch (loop.dim) { case dot_loop::m: - return block_dot_m(m, n, k, block_m * loop.blocks, a_stride_m, a, b, - init_c_stride_m, init_c, c_stride_m, c_stride_n, c, - f); + return block_dot_m(m, n, ks, block_m * loop.blocks, a_stride_m, + a_k_strides, a, b_k_strides, b, init_c_stride_m, + init_c, c_stride_m, c_stride_n, c, f); case dot_loop::n: - return block_dot_n(m, n, k, block_n * loop.blocks, a, b_stride_n, b, + return block_dot_n(m, n, ks, block_n * loop.blocks, a_stride_m, + a_k_strides, a, b_stride_n, b_k_strides, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, f); case dot_loop::k: - return block_dot_k(m, n, k, block_k * loop.blocks, a_stride_k, a, - b_stride_k, b, init_c_stride_m, init_c, c_stride_m, - c_stride_n, c, f); + return block_dot_k(m, n, ks, block_k * loop.blocks, a_stride_m, + a_k_strides, a, b_k_strides, b, init_c_stride_m, + init_c, c_stride_m, c_stride_n, c, f); } } else { // Recursively call `run_dot` with the subsequent loops. - auto recursive_f = [=](size_t m, size_t n, size_t k, const void* a, - const void* b, size_t init_c_stride_m, - const void* init_c, void* c) { - run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, - a, b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, - c_stride_n, c, f); - }; + auto recursive_f = + [=](size_t m, size_t n, span ks, const void* a, + size_t a_stride_m, span a_k_strides, const void* b, + span b_k_strides, size_t init_c_stride_m, + const void* init_c, void* c) { + run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, + a_k_strides, a, b_k_strides, b_stride_n, b, init_c_stride_m, + init_c, c_stride_m, c_stride_n, c, f); + }; switch (loop.dim) { case dot_loop::m: - return block_dot_m(m, n, k, block_m * loop.blocks, a_stride_m, a, b, - init_c_stride_m, init_c, c_stride_m, c_stride_n, c, - recursive_f); + return block_dot_m(m, n, ks, block_m * loop.blocks, a_stride_m, + a_k_strides, a, b_k_strides, b, init_c_stride_m, + init_c, c_stride_m, c_stride_n, c, recursive_f); case dot_loop::n: - return block_dot_n(m, n, k, block_n * loop.blocks, a, b_stride_n, b, + return block_dot_n(m, n, ks, block_n * loop.blocks, a_stride_m, + a_k_strides, a, b_stride_n, b_k_strides, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, recursive_f); case dot_loop::k: - return block_dot_k(m, n, k, block_k * loop.blocks, a_stride_k, a, - b_stride_k, b, init_c_stride_m, init_c, c_stride_m, - c_stride_n, c, recursive_f); + return block_dot_k(m, n, ks, block_k * loop.blocks, a_stride_m, + a_k_strides, a, b_k_strides, b, init_c_stride_m, + init_c, c_stride_m, c_stride_n, c, recursive_f); } } } diff --git a/ynnpack/kernels/dot/schedule_bench.cc b/ynnpack/kernels/dot/schedule_bench.cc index ace792cd7e5..f86ff96591c 100644 --- a/ynnpack/kernels/dot/schedule_bench.cc +++ b/ynnpack/kernels/dot/schedule_bench.cc @@ -147,14 +147,16 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n, a = pack_a ? transpose_a(a, tile_m, tile_k) : a; - auto kernel_wrapper = [&](size_t m, size_t n, size_t k, const void* a_ptr, - const void* b_ptr, size_t init_c_stride_m, - const void* init_c, void* c_ptr) { - kernel.kernel(m, n, 1, 1, k, - a.stride(0) * sizeof(TA) / (pack_a ? tile_k : 1), 0, 0, a_ptr, - 0, 0, b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m, - init_c, c.stride(0) * sizeof(TC), c_ptr); - }; + auto kernel_wrapper = + [&](size_t m, size_t n, span k, const void* a_ptr, + size_t a_stride_m, span a_k_strides, const void* b_ptr, + span b_k_strides, size_t init_c_stride_m, + const void* init_c, void* c_ptr) { + kernel.kernel(m, n, k[2], k[1], k[0], a_stride_m, + a_k_strides[2], a_k_strides[1], a_ptr, b_k_strides[2], + b_k_strides[1], b_k_strides[0], b_ptr, init_c_stride_m, + init_c, c.stride(0) * sizeof(TC), c_ptr); + }; const size_t a_stride_m = pack_a ? kernel.tile_k * sizeof(TA) / a_elem_count : a.stride(0) * sizeof(TA); @@ -165,10 +167,15 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n, const size_t c_stride_m = c.stride(0) * sizeof(TC); const size_t c_stride_n = c.stride(1) * sizeof(TC); + const size_t ks[] = {k, 1, 1}; + const size_t a_k_strides[] = {a_stride_k, 0, 0}; + const size_t b_k_strides[] = {b_stride_k, 0, 0}; + double t = benchmark([&]() { - run_dot(loops, m, n, k, kernel.block_m, kernel.block_n, kernel.block_k, - a_stride_m, a_stride_k, a.base(), b_stride_k, b_stride_n, b.base(), - 0, nullptr, c_stride_m, c_stride_n, c.base(), kernel_wrapper); + run_dot(loops, m, n, ks, kernel.block_m, kernel.block_n, kernel.block_k, + a_stride_m, a_k_strides, a.base(), b_k_strides, b_stride_n, + b.base(), 0, nullptr, c_stride_m, c_stride_n, c.base(), + kernel_wrapper); }); // Check that the kernel didn't compute the wrong thing. We assume the kernel // is correct, but we have some logic here that needs validation too. We diff --git a/ynnpack/kernels/dot/schedule_test.cc b/ynnpack/kernels/dot/schedule_test.cc index 81c6c7c3491..c20aa3de374 100644 --- a/ynnpack/kernels/dot/schedule_test.cc +++ b/ynnpack/kernels/dot/schedule_test.cc @@ -12,6 +12,7 @@ #include #include +#include "ynnpack/base/span.h" using testing::ElementsAre; @@ -84,17 +85,22 @@ dot_call dot_call_at(size_t m, size_t n, size_t k, size_t i, size_t j, }; auto make_record_calls(std::vector& calls) { - return [&](size_t m, size_t n, size_t k, const void* a, const void* b, - size_t init_c_stride_m, const void* init_c, - const void* c) { calls.push_back({m, n, k, a, b, init_c, c}); }; + return [&](size_t m, size_t n, span k, const void* a, + size_t a_stride_m, span a_k_strides, const void* b, + span b_k_strides, size_t init_c_stride_m, + const void* init_c, + const void* c) { calls.push_back({m, n, k[0], a, b, init_c, c}); }; } TEST(run_dot, loop_m) { const dot_loop loops[] = {{dot_loop::m, 1}}; + const size_t ks[] = {k}; + const size_t a_k_strides[] = {a_stride_k}; + const size_t b_k_strides[] = {b_stride_k}; std::vector calls; - run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, a, - b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, + run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, a_k_strides, + a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, make_record_calls(calls)); ASSERT_THAT(calls, ElementsAre(dot_call_at(block_m, n, k, 0 * block_m, 0, 0), @@ -104,10 +110,13 @@ TEST(run_dot, loop_m) { TEST(run_dot, loop_n) { const dot_loop loops[] = {{dot_loop::n, 1}}; + const size_t ks[] = {k}; + const size_t a_k_strides[] = {a_stride_k}; + const size_t b_k_strides[] = {b_stride_k}; std::vector calls; - run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, a, - b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, + run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, a_k_strides, + a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, make_record_calls(calls)); ASSERT_THAT(calls, @@ -120,24 +129,31 @@ TEST(run_dot, loop_n) { TEST(run_dot, loop_n_tail) { const dot_loop loops[] = {{dot_loop::n, 2}}; + const size_t ks[] = {k}; + const size_t a_k_strides[] = {a_stride_k}; + const size_t b_k_strides[] = {b_stride_k}; std::vector calls; - run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, a, - b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, + run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, a_k_strides, + a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, make_record_calls(calls)); - ASSERT_THAT(calls, - ElementsAre(dot_call_at(m, 2 * block_n, k, 0, 0 * block_n, 0), - dot_call_at(m, 2 * block_n, k, 0, 2 * block_n, 0), - dot_call_at(m, block_n, k, 0, 4 * block_n, 0))); + ASSERT_THAT( + calls, + ElementsAre(dot_call_at(m, 2 * block_n, k, 0, 0 * block_n, 0), + dot_call_at(m, 2 * block_n, k, 0, 2 * block_n, 0), + dot_call_at(m, n - 4 * block_n, k, 0, 4 * block_n, 0))); } TEST(run_dot, loop_k) { const dot_loop loops[] = {{dot_loop::k, 1}}; + const size_t ks[] = {k}; + const size_t a_k_strides[] = {a_stride_k}; + const size_t b_k_strides[] = {b_stride_k}; std::vector calls; - run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, a, - b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, + run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, a_k_strides, + a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, make_record_calls(calls)); ASSERT_THAT(calls, diff --git a/ynnpack/subgraph/dot.cc b/ynnpack/subgraph/dot.cc index a119a437efa..e25bdd23755 100644 --- a/ynnpack/subgraph/dot.cc +++ b/ynnpack/subgraph/dot.cc @@ -21,6 +21,7 @@ #include "ynnpack/base/arithmetic.h" #include "ynnpack/base/base.h" #include "ynnpack/base/log.h" +#include "ynnpack/base/span.h" #include "ynnpack/base/type.h" #include "ynnpack/include/ynnpack.h" #include "ynnpack/kernels/dot/pack.h" @@ -207,22 +208,36 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, // hopes of making i bigger, which should improve performance in cases where // block_m does not divide c_m.extent() - const index_t a_stride = transposed_a ? a_stride_k1 : a_stride_m; // The kernels assume that the column dimension of a is stride 1 element. assert(transposed_a ? (a_m.extent() == 1 || a_stride_m == a.elem_size * a_tile_k) : (a_k1o.extent() == 1 || a_stride_k1 == a.elem_size)); - auto call_kernel = [=, kernel = kernel.kernel]( - index_t m, index_t n, index_t k1, const void* a, - const void* b, index_t init_c_stride_m, - const void* init_c, void* c) { - assert(n <= block_n); - assert(m <= block_m); - kernel(m, n, k3, k2, k1, a_stride, a_stride_k3, a_stride_k2, a, - b_stride_k3, b_stride_k2, b_stride_k1, b, init_c_stride_m, init_c, - c_stride_m, c); + std::array k = {static_cast(k1), static_cast(k2), + static_cast(k3)}; + std::array a_k_strides = { + static_cast(a_stride_k1), + static_cast(a_stride_k2), + static_cast(a_stride_k3), }; + std::array b_k_strides = { + static_cast(b_stride_k1), + static_cast(b_stride_k2), + static_cast(b_stride_k3), + }; + + auto call_kernel = + [transposed_a, c_stride_m, kernel = kernel.kernel]( + index_t m, index_t n, span k, const void* a, + size_t a_stride_m, span a_k_strides, const void* b, + span b_k_strides, index_t init_c_stride_m, + const void* init_c, void* c) { + kernel(m, n, k[2], k[1], k[0], + transposed_a ? a_k_strides[0] : a_stride_m, + a_k_strides[2], a_k_strides[1], a, b_k_strides[2], + b_k_strides[1], b_k_strides[0], b, init_c_stride_m, init_c, + c_stride_m, c); + }; const size_t cache_sizes[] = {cache_size_l2}; @@ -230,23 +245,26 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, dot_loop loops_storage[std::size(cache_sizes) * 3]; if (k1) { - auto loops = schedule_dot(cache_sizes, c_m.extent(), c_n.extent(), k1, k2, - k3, block_m, block_n, block_k, a.elem_size, + auto loops = schedule_dot(cache_sizes, c_m.extent(), c_n.extent(), k, + block_m, block_n, block_k, a.elem_size, b.elem_size, loops_storage); slinky::for_each_element( [&](void* c, const void* a, const void* b, const void* init_c) { - run_dot(loops, c_m.extent(), c_n.extent(), k1, block_m, block_n, - block_k, a_stride_m, a_stride_k1, a, b_stride_k1, + run_dot(loops, c_m.extent(), c_n.extent(), k, block_m, block_n, + block_k, a_stride_m, a_k_strides, a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, call_kernel); }, c, a, b, init_c); } if (k1_tail) { - auto loops = schedule_dot(cache_sizes, c_m.extent(), c_n.extent(), - k1_tail, k2, k3, block_m, block_n, block_k, - a.elem_size, b.elem_size, loops_storage); + std::array k_tail = {static_cast(k1_tail), + static_cast(k2), + static_cast(k3)}; + auto loops = schedule_dot(cache_sizes, c_m.extent(), c_n.extent(), k_tail, + block_m, block_n, block_k, a.elem_size, + b.elem_size, loops_storage); // Dot kernels can't handle k1 not aligned to tile_k. We handle that here // by making a padded copy of the unaligned elements and calling the // kernel again. @@ -260,32 +278,38 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const index_t a_padded_stride_m = a.elem_size * tile_k; void* a_padded = YNN_ALLOCA(uint8_t, block_m* a_padded_stride_m); memset(a_padded, 0, a_padded_stride_m * block_m); - auto call_kernel_tail = [&](index_t m, index_t n, index_t k1, - const void* a, const void* b, - index_t init_c_stride_m, const void* init_c, - void* c) { - assert(m <= block_m); - assert(n <= block_n); - assert(k1 < tile_k); - for (index_t K3 = 0; K3 < k3; ++K3) { - for (index_t K2 = 0; K2 < k2; ++K2) { - for (index_t i = 0; i < m; ++i) { - memcpy(offset_bytes(a_padded, i * a_padded_stride_m), - offset_bytes(a, i * a_stride_m + K3 * a_stride_k3 + - K2 * a_stride_k2), - k1 * a_elem_size); + const size_t a_k_strides_tail_v[] = {static_cast(a_stride_k1), + static_cast(a_stride_k2), + static_cast(a_stride_k3)}; + span a_k_strides_tail = a_k_strides_tail_v; + auto call_kernel_tail = + [&](index_t m, index_t n, span k, const void* a, + size_t a_stride_m, span a_k_strides, const void* b, + span b_k_strides, index_t init_c_stride_m, + const void* init_c, void* c) { + assert(m <= block_m); + assert(n <= block_n); + assert(k[0] < tile_k); + for (index_t K3 = 0; K3 < k3; ++K3) { + for (index_t K2 = 0; K2 < k2; ++K2) { + for (index_t i = 0; i < m; ++i) { + memcpy(offset_bytes(a_padded, i * a_padded_stride_m), + offset_bytes(a, i * a_stride_m + K3 * a_k_strides[2] + + K2 * a_k_strides[1]), + k[0] * a_elem_size); + } + kernel.kernel( + m, n, /*k3=*/1, /*k2=*/1, tile_k, a_padded_stride_m, + /*a_stride_k3=*/0, /*a_stride_k2=*/0, a_padded, + /*b_stride_k3=*/0, + /*b_stride_k2=*/0, b_k_strides[0], + offset_bytes(b, K3 * b_k_strides[2] + K2 * b_k_strides[1]), + init_c_stride_m, init_c, c_stride_m, c); + init_c_stride_m = c_stride_m; + init_c = c; + } } - kernel.kernel(m, n, /*k3=*/1, /*k2=*/1, tile_k, a_padded_stride_m, - /*a_stride_k3=*/0, /*a_stride_k2=*/0, a_padded, - /*b_stride_k3=*/0, - /*b_stride_k2=*/0, b_stride_k1, - offset_bytes(b, K3 * b_stride_k3 + K2 * b_stride_k2), - init_c_stride_m, init_c, c_stride_m, c); - init_c_stride_m = c_stride_m; - init_c = c; - } - } - }; + }; slinky::for_each_element( [&](void* c, const void* a, const void* b, const void* init_c) { index_t tail_init_c_stride_m = init_c_stride_m; @@ -295,12 +319,10 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, } a = offset_bytes(a, a_stride_k1 * k1); b = offset_bytes(b, b_stride_k1 * k1); - // We only expect to run one iteration of k here, so the k strides - // are irrelevant. - run_dot(loops, c_m.extent(), c_n.extent(), k1_tail, block_m, - block_n, block_k, a_stride_m, /*a_stride_k1=*/0, a, - /*b_stride_k1=*/0, b_stride_n, b, tail_init_c_stride_m, - init_c, c_stride_m, c_stride_n, c, call_kernel_tail); + run_dot(loops, c_m.extent(), c_n.extent(), k_tail, block_m, block_n, + block_k, a_stride_m, a_k_strides_tail, a, b_k_strides, + b_stride_n, b, tail_init_c_stride_m, init_c, c_stride_m, + c_stride_n, c, call_kernel_tail); }, c, a, b, init_c); }