diff --git a/ynnpack/kernels/dot/BUILD b/ynnpack/kernels/dot/BUILD index 0126f69350b..f019509a2c1 100644 --- a/ynnpack/kernels/dot/BUILD +++ b/ynnpack/kernels/dot/BUILD @@ -482,6 +482,7 @@ cc_binary( malloc = ynn_binary_malloc(), deps = [ ":dot", + ":pack_test_tensor", ":schedule", "//ynnpack/base", "//ynnpack/base/test:buffer", diff --git a/ynnpack/kernels/dot/arm64_sme.cc b/ynnpack/kernels/dot/arm64_sme.cc index d59b772f753..8415174bbe0 100644 --- a/ynnpack/kernels/dot/arm64_sme.cc +++ b/ynnpack/kernels/dot/arm64_sme.cc @@ -115,7 +115,7 @@ __arm_new("za") __arm_locally_streaming void sme_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); @@ -208,7 +208,7 @@ __arm_new("za") __arm_locally_streaming void sme_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); diff --git a/ynnpack/kernels/dot/arm64_sme2.cc b/ynnpack/kernels/dot/arm64_sme2.cc index 8bc2629b819..6ee0b3e355f 100644 --- a/ynnpack/kernels/dot/arm64_sme2.cc +++ b/ynnpack/kernels/dot/arm64_sme2.cc @@ -90,7 +90,7 @@ __arm_new("za") __arm_locally_streaming void sme2_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); @@ -183,7 +183,7 @@ __arm_new("za") __arm_locally_streaming void sme2_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); diff --git a/ynnpack/kernels/dot/bench.cc b/ynnpack/kernels/dot/bench.cc index 1eabcf03019..2ce9dee9096 100644 --- a/ynnpack/kernels/dot/bench.cc +++ b/ynnpack/kernels/dot/bench.cc @@ -91,9 +91,11 @@ void dot(benchmark::State& state, uint64_t arch_flags, dot_kernel_fn kernel, for (size_t i = 0; i < m; i += block_m) { size_t m_i = std::min(block_m, m - i); const void* a_i = transpose_a ? &a(0, i * tile_k) : &a(i, 0); - kernel(m_i, n, 1, 1, k, a.stride(0) * sizeof(TA), 0, 0, a_i, 0, 0, - b.stride(0) * sizeof(TB), b.base(), /*init_c_stride_m=*/0, nullptr, - c.stride(0) * sizeof(TC), &c(i, 0)); + kernel(m_i, n, 1, 1, k, + a.stride(0) * sizeof(TA) / (transpose_a ? tile_k : 1), 0, 0, a_i, + 0, 0, b.stride(0) * sizeof(TB), b.base(), + /*init_c_stride_m=*/0, nullptr, c.stride(0) * sizeof(TC), + &c(i, 0)); } } diff --git a/ynnpack/kernels/dot/consistent_arithmetic_test.cc b/ynnpack/kernels/dot/consistent_arithmetic_test.cc index 98ccacd22d7..6f030c3bf51 100644 --- a/ynnpack/kernels/dot/consistent_arithmetic_test.cc +++ b/ynnpack/kernels/dot/consistent_arithmetic_test.cc @@ -129,15 +129,14 @@ void TestMatMul(AT, BT, CT, size_t k) { // dot kernels require B's k and n dimensions to be aligned to tile_k, // tile_n. The kernel might also require b to be packed (tile_k > 1). Tensor packed_b = pack_b(b, tile_k, tile_n); - Tensor packed_a = (kernel.flags & dot_flag::transpose_a) - ? transpose_a(a, tile_m, tile_k) - : a; - - kernel.kernel(m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT), 0, 0, - packed_a.base(), 0, 0, - packed_b.stride(0) * sizeof(BT) / tile_k, packed_b.base(), - kernel_c.stride(0) * sizeof(CT), kernel_c.base(), - kernel_c.stride(0) * sizeof(CT), kernel_c.base()); + const bool pack_a = kernel.flags & dot_flag::transpose_a; + Tensor packed_a = pack_a ? transpose_a(a, tile_m, tile_k) : a; + + kernel.kernel( + m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), + 0, 0, packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, + packed_b.base(), kernel_c.stride(0) * sizeof(CT), kernel_c.base(), + kernel_c.stride(0) * sizeof(CT), kernel_c.base()); if (c.base()) { int finite = 0; diff --git a/ynnpack/kernels/dot/generator/dot_base.py b/ynnpack/kernels/dot/generator/dot_base.py index 66fba1b64e7..45d87c6195a 100644 --- a/ynnpack/kernels/dot/generator/dot_base.py +++ b/ynnpack/kernels/dot/generator/dot_base.py @@ -155,7 +155,6 @@ def a_ptr(self, i, k1, ty=None): # When we clamp, we need to align down to the nearest tile. i = f"min({i}, (M - 1) & ~{self.tile_shape[0] - 1})" if i != 0 else i if "dot_flag::transpose_a" in self.flags: - k1 //= self.tile_shape[2] i = f"{i} * {self.tile_shape[2]}" i, k1 = k1, i offset = f"({i} * A_stride_m) + ({k1} * sizeof({self.a_type}))" @@ -384,6 +383,12 @@ def loop_k1(self): block_body = self.generate_block(self.block_shape[2]) tile_body = self.generate_block(self.tile_shape[2]) + a_step = ( + "A_stride_m" + if "dot_flag::transpose_a" in self.flags + else f"sizeof({self.a_type})" + ) + if block_body == tile_body: tile_body = None @@ -397,11 +402,9 @@ def loop_k1(self): f"B_k1_{j} = offset_bytes(B_k1_{j}, {self.block_shape[2]} *" " B_stride_k1);\n" ) - if "dot_flag::transpose_a" in self.flags: - a_step = f"{self.block_shape[2]//self.tile_shape[2]} * A_stride_m" - else: - a_step = f"{self.block_shape[2]} * sizeof({self.a_type})" - block_body += f"A_k1 = offset_bytes(A_k1, {a_step});\n" + block_body += ( + f"A_k1 = offset_bytes(A_k1, {self.block_shape[2]} * {a_step});\n" + ) result += indent(block_body, " ") + "\n" if tile_body: result += "}\n" @@ -421,8 +424,7 @@ def loop_k1(self): " B_stride_k1);\n" ) tile_body += ( - f"A_k1 = offset_bytes(A_k1, {self.tile_shape[2]} *" - f" sizeof({self.a_type}));\n" + f"A_k1 = offset_bytes(A_k1, {self.tile_shape[2]} * {a_step});\n" ) result += indent(tile_body, " ") + "\n" result += "}\n" diff --git a/ynnpack/kernels/dot/schedule_bench.cc b/ynnpack/kernels/dot/schedule_bench.cc index 7745f2b890f..ace792cd7e5 100644 --- a/ynnpack/kernels/dot/schedule_bench.cc +++ b/ynnpack/kernels/dot/schedule_bench.cc @@ -21,6 +21,7 @@ #include "ynnpack/base/test/util.h" #include "ynnpack/base/type.h" #include "ynnpack/kernels/dot/dot.h" +#include "ynnpack/kernels/dot/pack_test_tensor.h" #include "ynnpack/kernels/dot/schedule.h" namespace ynn { @@ -119,7 +120,7 @@ void fill(T* data, size_t n, int value) { template double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n, size_t k, span loops) { - const bool transpose_a = kernel.flags & dot_flag::transpose_a; + const bool pack_a = kernel.flags & dot_flag::transpose_a; const size_t tile_m = kernel.tile_m; const size_t tile_n = kernel.tile_n; @@ -144,25 +145,21 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n, c.fill(0); b = b.crop_padding({0, 0}, {b.extent(0) - k, b.extent(1) - n}); - if (transpose_a) { - // This mangles the data, but we don't care here. - a = a.reshape({k / tile_k, m * tile_k}); - } + a = pack_a ? transpose_a(a, tile_m, tile_k) : a; auto kernel_wrapper = [&](size_t m, size_t n, size_t k, const void* a_ptr, const void* b_ptr, size_t init_c_stride_m, const void* init_c, void* c_ptr) { - kernel.kernel(m, n, 1, 1, k, a.stride(0) * sizeof(TA), 0, 0, a_ptr, 0, 0, - b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m, init_c, - c.stride(0) * sizeof(TC), c_ptr); + kernel.kernel(m, n, 1, 1, k, + a.stride(0) * sizeof(TA) / (pack_a ? tile_k : 1), 0, 0, a_ptr, + 0, 0, b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m, + init_c, c.stride(0) * sizeof(TC), c_ptr); }; - const size_t a_stride_m = transpose_a - ? kernel.tile_k * sizeof(TA) / a_elem_count - : a.stride(0) * sizeof(TA); - const size_t a_stride_k = transpose_a - ? a.stride(0) * sizeof(TA) / kernel.tile_k - : a.stride(1) * sizeof(TA) / a_elem_count; + const size_t a_stride_m = pack_a ? kernel.tile_k * sizeof(TA) / a_elem_count + : a.stride(0) * sizeof(TA); + const size_t a_stride_k = pack_a ? a.stride(0) * sizeof(TA) / kernel.tile_k + : a.stride(1) * sizeof(TA) / a_elem_count; const size_t b_stride_k = b.stride(0) * sizeof(TB); const size_t b_stride_n = b.stride(1) * sizeof(TB) / b_elem_count; const size_t c_stride_m = c.stride(0) * sizeof(TC); diff --git a/ynnpack/kernels/dot/test.cc b/ynnpack/kernels/dot/test.cc index 70d38cf62b3..712740e8052 100644 --- a/ynnpack/kernels/dot/test.cc +++ b/ynnpack/kernels/dot/test.cc @@ -122,11 +122,11 @@ void TestMatMul(AT, BT, CT, const DotShape& shape, const KernelInfo& kernel, Tensor packed_b = unpacked_b ? b : pack_b(b, tile_k, tile_n); Tensor packed_a = pack_a ? transpose_a(a, tile_m, tile_k) : a; - kernel.kernel(m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT), 0, 0, - packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, - packed_b.base(), c.stride(0) * sizeof(CT), - init_zero ? nullptr : c.base(), c.stride(0) * sizeof(CT), - c.base()); + kernel.kernel( + m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), 0, + 0, packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, + packed_b.base(), c.stride(0) * sizeof(CT), init_zero ? nullptr : c.base(), + c.stride(0) * sizeof(CT), c.base()); // Verify results. Reference(a, b, expected); @@ -226,7 +226,8 @@ void TestConv2D(AT, BT, CT, const KernelInfo& kernel) { {0, 0, 0, b.extent(3) - co / B_info::element_count()}); kernel.kernel( - w, co, kh, kw, ci, packed_a.stride(0) * sizeof(AT), + w, co, kh, kw, ci, + packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), packed_a.stride(1) * sizeof(AT), packed_a.stride(2) * sizeof(AT), packed_a.base(), packed_b.stride(0) * sizeof(BT), packed_b.stride(1) * sizeof(BT), diff --git a/ynnpack/subgraph/dot.cc b/ynnpack/subgraph/dot.cc index ba6190a1ad9..8afe5931bde 100644 --- a/ynnpack/subgraph/dot.cc +++ b/ynnpack/subgraph/dot.cc @@ -81,13 +81,13 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const int b_type_element_count = type_element_count(type.b); const index_t tile_k = b_k1i.extent() * b_type_element_count; - const index_t a_stride_m = a_m.stride(); - const index_t a_stride_k3 = a_k3.stride(); - const index_t a_stride_k2 = a_k2.stride(); - const index_t a_stride_k1 = a_k1.stride(); // If a is transposed, then the k dimension has been reshaped to have tile_k // values in each element. const index_t a_tile_k = transposed_a ? tile_k : 1; + const index_t a_stride_m = a_m.stride(); + const index_t a_stride_k3 = a_k3.stride(); + const index_t a_stride_k2 = a_k2.stride(); + const index_t a_stride_k1 = a_k1.stride() / a_tile_k; const index_t k1 = (a_k1.extent() * a_tile_k) & ~(tile_k - 1); const index_t k1_tail = (a_k1.extent() * a_tile_k) & (tile_k - 1); const index_t k2 = a_k2.extent(); @@ -227,7 +227,7 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, slinky::for_each_element( [&](void* c, const void* a, const void* b, const void* init_c) { run_dot(loops, c_m.extent(), c_n.extent(), k1, block_m, block_n, - block_k, a_stride_m, a_stride_k1 / a_tile_k, a, b_stride_k1, + block_k, a_stride_m, a_stride_k1, a, b_stride_k1, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, call_kernel); }, @@ -265,7 +265,8 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, K2 * a_stride_k2), k1 * a_elem_size); } - kernel.kernel(m, n, /*k3=*/1, /*k2=*/1, tile_k, a_padded_stride_m, + kernel.kernel(m, n, /*k3=*/1, /*k2=*/1, tile_k, + transposed_a ? a_elem_size : a_padded_stride_m, /*a_stride_k3=*/0, /*a_stride_k2=*/0, a_padded, /*b_stride_k3=*/0, /*b_stride_k2=*/0, b_stride_k1, @@ -285,10 +286,12 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, } a = offset_bytes(a, a_stride_k1 * k1); b = offset_bytes(b, b_stride_k1 * k1); + // We only expect to run one iteration of k here, so the k strides + // are irrelevant. run_dot(loops, c_m.extent(), c_n.extent(), k1_tail, block_m, - block_n, block_k, a_stride_m, a_stride_k1, a, b_stride_k1, - b_stride_n, b, tail_init_c_stride_m, init_c, c_stride_m, - c_stride_n, c, call_kernel_tail); + block_n, block_k, a_stride_m, /*a_stride_k1=*/0, a, + /*b_stride_k1=*/0, b_stride_n, b, tail_init_c_stride_m, + init_c, c_stride_m, c_stride_n, c, call_kernel_tail); }, c, a, b, init_c); }