From c2a70e7164b9eba33e44d1a896cae142b1d72e0b Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Tue, 31 Mar 2026 21:37:47 -0700 Subject: [PATCH] Fix discrepancy in a_stride_m/a_stride_k for transposed dot kernels Currently, the stride of k for transposed kernels (passed as a_stride_m because this is the row dimension when A is transposed) is the stride of tile_k values of k. This is inconsistent, because the stride is not for one value of k, which we assume in several places. This leads to multiplying or dividing strides to make them consistent. In particular, `run_dot` multiplies the stride by k, while kernels do not, which means we can't use the same stride for both `run_dot` and a kernel. This discrepancy is preventing refactoring `run_dot` to capture the strides to pass to the kernels easily, which I think is a necessary step towards addressing some issues (packing A/B in the loops of `run_dot`). PiperOrigin-RevId: 892702737 --- ynnpack/kernels/dot/BUILD | 1 + ynnpack/kernels/dot/arm64_sme.cc | 4 +-- ynnpack/kernels/dot/arm64_sme2.cc | 4 +-- ynnpack/kernels/dot/bench.cc | 8 +++--- .../kernels/dot/consistent_arithmetic_test.cc | 17 ++++++------- ynnpack/kernels/dot/generator/dot_base.py | 18 +++++++------ ynnpack/kernels/dot/schedule_bench.cc | 25 ++++++++----------- ynnpack/kernels/dot/test.cc | 13 +++++----- ynnpack/subgraph/dot.cc | 21 +++++++++------- 9 files changed, 58 insertions(+), 53 deletions(-) diff --git a/ynnpack/kernels/dot/BUILD b/ynnpack/kernels/dot/BUILD index 0126f69350b..f019509a2c1 100644 --- a/ynnpack/kernels/dot/BUILD +++ b/ynnpack/kernels/dot/BUILD @@ -482,6 +482,7 @@ cc_binary( malloc = ynn_binary_malloc(), deps = [ ":dot", + ":pack_test_tensor", ":schedule", "//ynnpack/base", "//ynnpack/base/test:buffer", diff --git a/ynnpack/kernels/dot/arm64_sme.cc b/ynnpack/kernels/dot/arm64_sme.cc index d59b772f753..8415174bbe0 100644 --- a/ynnpack/kernels/dot/arm64_sme.cc +++ b/ynnpack/kernels/dot/arm64_sme.cc @@ -115,7 +115,7 @@ __arm_new("za") __arm_locally_streaming void sme_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); @@ -208,7 +208,7 @@ __arm_new("za") __arm_locally_streaming void sme_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); diff --git a/ynnpack/kernels/dot/arm64_sme2.cc b/ynnpack/kernels/dot/arm64_sme2.cc index 8bc2629b819..6ee0b3e355f 100644 --- a/ynnpack/kernels/dot/arm64_sme2.cc +++ b/ynnpack/kernels/dot/arm64_sme2.cc @@ -90,7 +90,7 @@ __arm_new("za") __arm_locally_streaming void sme2_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); @@ -183,7 +183,7 @@ __arm_new("za") __arm_locally_streaming void sme2_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); diff --git a/ynnpack/kernels/dot/bench.cc b/ynnpack/kernels/dot/bench.cc index 1eabcf03019..2ce9dee9096 100644 --- a/ynnpack/kernels/dot/bench.cc +++ b/ynnpack/kernels/dot/bench.cc @@ -91,9 +91,11 @@ void dot(benchmark::State& state, uint64_t arch_flags, dot_kernel_fn kernel, for (size_t i = 0; i < m; i += block_m) { size_t m_i = std::min(block_m, m - i); const void* a_i = transpose_a ? &a(0, i * tile_k) : &a(i, 0); - kernel(m_i, n, 1, 1, k, a.stride(0) * sizeof(TA), 0, 0, a_i, 0, 0, - b.stride(0) * sizeof(TB), b.base(), /*init_c_stride_m=*/0, nullptr, - c.stride(0) * sizeof(TC), &c(i, 0)); + kernel(m_i, n, 1, 1, k, + a.stride(0) * sizeof(TA) / (transpose_a ? tile_k : 1), 0, 0, a_i, + 0, 0, b.stride(0) * sizeof(TB), b.base(), + /*init_c_stride_m=*/0, nullptr, c.stride(0) * sizeof(TC), + &c(i, 0)); } } diff --git a/ynnpack/kernels/dot/consistent_arithmetic_test.cc b/ynnpack/kernels/dot/consistent_arithmetic_test.cc index 98ccacd22d7..6f030c3bf51 100644 --- a/ynnpack/kernels/dot/consistent_arithmetic_test.cc +++ b/ynnpack/kernels/dot/consistent_arithmetic_test.cc @@ -129,15 +129,14 @@ void TestMatMul(AT, BT, CT, size_t k) { // dot kernels require B's k and n dimensions to be aligned to tile_k, // tile_n. The kernel might also require b to be packed (tile_k > 1). Tensor packed_b = pack_b(b, tile_k, tile_n); - Tensor packed_a = (kernel.flags & dot_flag::transpose_a) - ? transpose_a(a, tile_m, tile_k) - : a; - - kernel.kernel(m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT), 0, 0, - packed_a.base(), 0, 0, - packed_b.stride(0) * sizeof(BT) / tile_k, packed_b.base(), - kernel_c.stride(0) * sizeof(CT), kernel_c.base(), - kernel_c.stride(0) * sizeof(CT), kernel_c.base()); + const bool pack_a = kernel.flags & dot_flag::transpose_a; + Tensor packed_a = pack_a ? transpose_a(a, tile_m, tile_k) : a; + + kernel.kernel( + m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), + 0, 0, packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, + packed_b.base(), kernel_c.stride(0) * sizeof(CT), kernel_c.base(), + kernel_c.stride(0) * sizeof(CT), kernel_c.base()); if (c.base()) { int finite = 0; diff --git a/ynnpack/kernels/dot/generator/dot_base.py b/ynnpack/kernels/dot/generator/dot_base.py index 66fba1b64e7..45d87c6195a 100644 --- a/ynnpack/kernels/dot/generator/dot_base.py +++ b/ynnpack/kernels/dot/generator/dot_base.py @@ -155,7 +155,6 @@ def a_ptr(self, i, k1, ty=None): # When we clamp, we need to align down to the nearest tile. i = f"min({i}, (M - 1) & ~{self.tile_shape[0] - 1})" if i != 0 else i if "dot_flag::transpose_a" in self.flags: - k1 //= self.tile_shape[2] i = f"{i} * {self.tile_shape[2]}" i, k1 = k1, i offset = f"({i} * A_stride_m) + ({k1} * sizeof({self.a_type}))" @@ -384,6 +383,12 @@ def loop_k1(self): block_body = self.generate_block(self.block_shape[2]) tile_body = self.generate_block(self.tile_shape[2]) + a_step = ( + "A_stride_m" + if "dot_flag::transpose_a" in self.flags + else f"sizeof({self.a_type})" + ) + if block_body == tile_body: tile_body = None @@ -397,11 +402,9 @@ def loop_k1(self): f"B_k1_{j} = offset_bytes(B_k1_{j}, {self.block_shape[2]} *" " B_stride_k1);\n" ) - if "dot_flag::transpose_a" in self.flags: - a_step = f"{self.block_shape[2]//self.tile_shape[2]} * A_stride_m" - else: - a_step = f"{self.block_shape[2]} * sizeof({self.a_type})" - block_body += f"A_k1 = offset_bytes(A_k1, {a_step});\n" + block_body += ( + f"A_k1 = offset_bytes(A_k1, {self.block_shape[2]} * {a_step});\n" + ) result += indent(block_body, " ") + "\n" if tile_body: result += "}\n" @@ -421,8 +424,7 @@ def loop_k1(self): " B_stride_k1);\n" ) tile_body += ( - f"A_k1 = offset_bytes(A_k1, {self.tile_shape[2]} *" - f" sizeof({self.a_type}));\n" + f"A_k1 = offset_bytes(A_k1, {self.tile_shape[2]} * {a_step});\n" ) result += indent(tile_body, " ") + "\n" result += "}\n" diff --git a/ynnpack/kernels/dot/schedule_bench.cc b/ynnpack/kernels/dot/schedule_bench.cc index 7745f2b890f..ace792cd7e5 100644 --- a/ynnpack/kernels/dot/schedule_bench.cc +++ b/ynnpack/kernels/dot/schedule_bench.cc @@ -21,6 +21,7 @@ #include "ynnpack/base/test/util.h" #include "ynnpack/base/type.h" #include "ynnpack/kernels/dot/dot.h" +#include "ynnpack/kernels/dot/pack_test_tensor.h" #include "ynnpack/kernels/dot/schedule.h" namespace ynn { @@ -119,7 +120,7 @@ void fill(T* data, size_t n, int value) { template double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n, size_t k, span loops) { - const bool transpose_a = kernel.flags & dot_flag::transpose_a; + const bool pack_a = kernel.flags & dot_flag::transpose_a; const size_t tile_m = kernel.tile_m; const size_t tile_n = kernel.tile_n; @@ -144,25 +145,21 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n, c.fill(0); b = b.crop_padding({0, 0}, {b.extent(0) - k, b.extent(1) - n}); - if (transpose_a) { - // This mangles the data, but we don't care here. - a = a.reshape({k / tile_k, m * tile_k}); - } + a = pack_a ? transpose_a(a, tile_m, tile_k) : a; auto kernel_wrapper = [&](size_t m, size_t n, size_t k, const void* a_ptr, const void* b_ptr, size_t init_c_stride_m, const void* init_c, void* c_ptr) { - kernel.kernel(m, n, 1, 1, k, a.stride(0) * sizeof(TA), 0, 0, a_ptr, 0, 0, - b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m, init_c, - c.stride(0) * sizeof(TC), c_ptr); + kernel.kernel(m, n, 1, 1, k, + a.stride(0) * sizeof(TA) / (pack_a ? tile_k : 1), 0, 0, a_ptr, + 0, 0, b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m, + init_c, c.stride(0) * sizeof(TC), c_ptr); }; - const size_t a_stride_m = transpose_a - ? kernel.tile_k * sizeof(TA) / a_elem_count - : a.stride(0) * sizeof(TA); - const size_t a_stride_k = transpose_a - ? a.stride(0) * sizeof(TA) / kernel.tile_k - : a.stride(1) * sizeof(TA) / a_elem_count; + const size_t a_stride_m = pack_a ? kernel.tile_k * sizeof(TA) / a_elem_count + : a.stride(0) * sizeof(TA); + const size_t a_stride_k = pack_a ? a.stride(0) * sizeof(TA) / kernel.tile_k + : a.stride(1) * sizeof(TA) / a_elem_count; const size_t b_stride_k = b.stride(0) * sizeof(TB); const size_t b_stride_n = b.stride(1) * sizeof(TB) / b_elem_count; const size_t c_stride_m = c.stride(0) * sizeof(TC); diff --git a/ynnpack/kernels/dot/test.cc b/ynnpack/kernels/dot/test.cc index 70d38cf62b3..712740e8052 100644 --- a/ynnpack/kernels/dot/test.cc +++ b/ynnpack/kernels/dot/test.cc @@ -122,11 +122,11 @@ void TestMatMul(AT, BT, CT, const DotShape& shape, const KernelInfo& kernel, Tensor packed_b = unpacked_b ? b : pack_b(b, tile_k, tile_n); Tensor packed_a = pack_a ? transpose_a(a, tile_m, tile_k) : a; - kernel.kernel(m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT), 0, 0, - packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, - packed_b.base(), c.stride(0) * sizeof(CT), - init_zero ? nullptr : c.base(), c.stride(0) * sizeof(CT), - c.base()); + kernel.kernel( + m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), 0, + 0, packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, + packed_b.base(), c.stride(0) * sizeof(CT), init_zero ? nullptr : c.base(), + c.stride(0) * sizeof(CT), c.base()); // Verify results. Reference(a, b, expected); @@ -226,7 +226,8 @@ void TestConv2D(AT, BT, CT, const KernelInfo& kernel) { {0, 0, 0, b.extent(3) - co / B_info::element_count()}); kernel.kernel( - w, co, kh, kw, ci, packed_a.stride(0) * sizeof(AT), + w, co, kh, kw, ci, + packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), packed_a.stride(1) * sizeof(AT), packed_a.stride(2) * sizeof(AT), packed_a.base(), packed_b.stride(0) * sizeof(BT), packed_b.stride(1) * sizeof(BT), diff --git a/ynnpack/subgraph/dot.cc b/ynnpack/subgraph/dot.cc index ba6190a1ad9..8afe5931bde 100644 --- a/ynnpack/subgraph/dot.cc +++ b/ynnpack/subgraph/dot.cc @@ -81,13 +81,13 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const int b_type_element_count = type_element_count(type.b); const index_t tile_k = b_k1i.extent() * b_type_element_count; - const index_t a_stride_m = a_m.stride(); - const index_t a_stride_k3 = a_k3.stride(); - const index_t a_stride_k2 = a_k2.stride(); - const index_t a_stride_k1 = a_k1.stride(); // If a is transposed, then the k dimension has been reshaped to have tile_k // values in each element. const index_t a_tile_k = transposed_a ? tile_k : 1; + const index_t a_stride_m = a_m.stride(); + const index_t a_stride_k3 = a_k3.stride(); + const index_t a_stride_k2 = a_k2.stride(); + const index_t a_stride_k1 = a_k1.stride() / a_tile_k; const index_t k1 = (a_k1.extent() * a_tile_k) & ~(tile_k - 1); const index_t k1_tail = (a_k1.extent() * a_tile_k) & (tile_k - 1); const index_t k2 = a_k2.extent(); @@ -227,7 +227,7 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, slinky::for_each_element( [&](void* c, const void* a, const void* b, const void* init_c) { run_dot(loops, c_m.extent(), c_n.extent(), k1, block_m, block_n, - block_k, a_stride_m, a_stride_k1 / a_tile_k, a, b_stride_k1, + block_k, a_stride_m, a_stride_k1, a, b_stride_k1, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, call_kernel); }, @@ -265,7 +265,8 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, K2 * a_stride_k2), k1 * a_elem_size); } - kernel.kernel(m, n, /*k3=*/1, /*k2=*/1, tile_k, a_padded_stride_m, + kernel.kernel(m, n, /*k3=*/1, /*k2=*/1, tile_k, + transposed_a ? a_elem_size : a_padded_stride_m, /*a_stride_k3=*/0, /*a_stride_k2=*/0, a_padded, /*b_stride_k3=*/0, /*b_stride_k2=*/0, b_stride_k1, @@ -285,10 +286,12 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, } a = offset_bytes(a, a_stride_k1 * k1); b = offset_bytes(b, b_stride_k1 * k1); + // We only expect to run one iteration of k here, so the k strides + // are irrelevant. run_dot(loops, c_m.extent(), c_n.extent(), k1_tail, block_m, - block_n, block_k, a_stride_m, a_stride_k1, a, b_stride_k1, - b_stride_n, b, tail_init_c_stride_m, init_c, c_stride_m, - c_stride_n, c, call_kernel_tail); + block_n, block_k, a_stride_m, /*a_stride_k1=*/0, a, + /*b_stride_k1=*/0, b_stride_n, b, tail_init_c_stride_m, + init_c, c_stride_m, c_stride_n, c, call_kernel_tail); }, c, a, b, init_c); }