From 38e9a17d6fee3e4501d0a04062962ff3e933b14b Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Sun, 22 Mar 2026 19:54:18 -0700 Subject: [PATCH] Add explicit dimensions for transposed A This is a minor refactor to eliminate a hack. Currently when we transpose A, we "explain" the transpose to downstream ops by changing the elem_size to represent `tile_k` values. This changes that to add an extra dimension with extent `tile_k` instead. PiperOrigin-RevId: 887832727 --- ynnpack/kernels/dot/BUILD | 1 + ynnpack/kernels/dot/arm64_sme.cc | 4 +- ynnpack/kernels/dot/arm64_sme2.cc | 4 +- ynnpack/kernels/dot/bench.cc | 8 +- .../kernels/dot/consistent_arithmetic_test.cc | 17 ++- ynnpack/kernels/dot/generator/dot_base.py | 18 ++-- ynnpack/kernels/dot/schedule_bench.cc | 25 ++--- ynnpack/kernels/dot/test.cc | 13 +-- ynnpack/subgraph/dot.cc | 102 ++++++++++-------- 9 files changed, 108 insertions(+), 84 deletions(-) diff --git a/ynnpack/kernels/dot/BUILD b/ynnpack/kernels/dot/BUILD index 0126f69350b..f019509a2c1 100644 --- a/ynnpack/kernels/dot/BUILD +++ b/ynnpack/kernels/dot/BUILD @@ -482,6 +482,7 @@ cc_binary( malloc = ynn_binary_malloc(), deps = [ ":dot", + ":pack_test_tensor", ":schedule", "//ynnpack/base", "//ynnpack/base/test:buffer", diff --git a/ynnpack/kernels/dot/arm64_sme.cc b/ynnpack/kernels/dot/arm64_sme.cc index d59b772f753..8415174bbe0 100644 --- a/ynnpack/kernels/dot/arm64_sme.cc +++ b/ynnpack/kernels/dot/arm64_sme.cc @@ -115,7 +115,7 @@ __arm_new("za") __arm_locally_streaming void sme_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); @@ -208,7 +208,7 @@ __arm_new("za") __arm_locally_streaming void sme_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); diff --git a/ynnpack/kernels/dot/arm64_sme2.cc b/ynnpack/kernels/dot/arm64_sme2.cc index 8bc2629b819..6ee0b3e355f 100644 --- a/ynnpack/kernels/dot/arm64_sme2.cc +++ b/ynnpack/kernels/dot/arm64_sme2.cc @@ -90,7 +90,7 @@ __arm_new("za") __arm_locally_streaming void sme2_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); @@ -183,7 +183,7 @@ __arm_new("za") __arm_locally_streaming void sme2_dot( k1 -= dot_factor; B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor); - A_k1 = offset_bytes(A_k1, A_stride_m); + A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor); } k2 -= 1; B_k2 = offset_bytes(B_k2, B_stride_k2); diff --git a/ynnpack/kernels/dot/bench.cc b/ynnpack/kernels/dot/bench.cc index 1eabcf03019..2ce9dee9096 100644 --- a/ynnpack/kernels/dot/bench.cc +++ b/ynnpack/kernels/dot/bench.cc @@ -91,9 +91,11 @@ void dot(benchmark::State& state, uint64_t arch_flags, dot_kernel_fn kernel, for (size_t i = 0; i < m; i += block_m) { size_t m_i = std::min(block_m, m - i); const void* a_i = transpose_a ? &a(0, i * tile_k) : &a(i, 0); - kernel(m_i, n, 1, 1, k, a.stride(0) * sizeof(TA), 0, 0, a_i, 0, 0, - b.stride(0) * sizeof(TB), b.base(), /*init_c_stride_m=*/0, nullptr, - c.stride(0) * sizeof(TC), &c(i, 0)); + kernel(m_i, n, 1, 1, k, + a.stride(0) * sizeof(TA) / (transpose_a ? tile_k : 1), 0, 0, a_i, + 0, 0, b.stride(0) * sizeof(TB), b.base(), + /*init_c_stride_m=*/0, nullptr, c.stride(0) * sizeof(TC), + &c(i, 0)); } } diff --git a/ynnpack/kernels/dot/consistent_arithmetic_test.cc b/ynnpack/kernels/dot/consistent_arithmetic_test.cc index 98ccacd22d7..6f030c3bf51 100644 --- a/ynnpack/kernels/dot/consistent_arithmetic_test.cc +++ b/ynnpack/kernels/dot/consistent_arithmetic_test.cc @@ -129,15 +129,14 @@ void TestMatMul(AT, BT, CT, size_t k) { // dot kernels require B's k and n dimensions to be aligned to tile_k, // tile_n. The kernel might also require b to be packed (tile_k > 1). Tensor packed_b = pack_b(b, tile_k, tile_n); - Tensor packed_a = (kernel.flags & dot_flag::transpose_a) - ? transpose_a(a, tile_m, tile_k) - : a; - - kernel.kernel(m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT), 0, 0, - packed_a.base(), 0, 0, - packed_b.stride(0) * sizeof(BT) / tile_k, packed_b.base(), - kernel_c.stride(0) * sizeof(CT), kernel_c.base(), - kernel_c.stride(0) * sizeof(CT), kernel_c.base()); + const bool pack_a = kernel.flags & dot_flag::transpose_a; + Tensor packed_a = pack_a ? transpose_a(a, tile_m, tile_k) : a; + + kernel.kernel( + m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), + 0, 0, packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, + packed_b.base(), kernel_c.stride(0) * sizeof(CT), kernel_c.base(), + kernel_c.stride(0) * sizeof(CT), kernel_c.base()); if (c.base()) { int finite = 0; diff --git a/ynnpack/kernels/dot/generator/dot_base.py b/ynnpack/kernels/dot/generator/dot_base.py index 66fba1b64e7..46cf9968999 100644 --- a/ynnpack/kernels/dot/generator/dot_base.py +++ b/ynnpack/kernels/dot/generator/dot_base.py @@ -158,7 +158,12 @@ def a_ptr(self, i, k1, ty=None): k1 //= self.tile_shape[2] i = f"{i} * {self.tile_shape[2]}" i, k1 = k1, i - offset = f"({i} * A_stride_m) + ({k1} * sizeof({self.a_type}))" + offset = ( + f"({i * self.tile_shape[2]} * A_stride_m) + ({k1} *" + f" sizeof({self.a_type}))" + ) + else: + offset = f"({i} * A_stride_m) + ({k1} * sizeof({self.a_type}))" return f"reinterpret_cast(offset_bytes(A_k1, {offset}))" def b_ptr(self, k1, j, ty=None): @@ -398,7 +403,7 @@ def loop_k1(self): " B_stride_k1);\n" ) if "dot_flag::transpose_a" in self.flags: - a_step = f"{self.block_shape[2]//self.tile_shape[2]} * A_stride_m" + a_step = f"{self.block_shape[2]} * A_stride_m" else: a_step = f"{self.block_shape[2]} * sizeof({self.a_type})" block_body += f"A_k1 = offset_bytes(A_k1, {a_step});\n" @@ -420,10 +425,11 @@ def loop_k1(self): f"B_k1_{j} = offset_bytes(B_k1_{j}, {self.tile_shape[2]} *" " B_stride_k1);\n" ) - tile_body += ( - f"A_k1 = offset_bytes(A_k1, {self.tile_shape[2]} *" - f" sizeof({self.a_type}));\n" - ) + if "dot_flag::transpose_a" in self.flags: + a_step = f"{self.tile_shape[2]} * A_stride_m" + else: + a_step = f"{self.tile_shape[2]} * sizeof({self.a_type})" + tile_body += f"A_k1 = offset_bytes(A_k1, {a_step});\n" result += indent(tile_body, " ") + "\n" result += "}\n" diff --git a/ynnpack/kernels/dot/schedule_bench.cc b/ynnpack/kernels/dot/schedule_bench.cc index 7745f2b890f..ace792cd7e5 100644 --- a/ynnpack/kernels/dot/schedule_bench.cc +++ b/ynnpack/kernels/dot/schedule_bench.cc @@ -21,6 +21,7 @@ #include "ynnpack/base/test/util.h" #include "ynnpack/base/type.h" #include "ynnpack/kernels/dot/dot.h" +#include "ynnpack/kernels/dot/pack_test_tensor.h" #include "ynnpack/kernels/dot/schedule.h" namespace ynn { @@ -119,7 +120,7 @@ void fill(T* data, size_t n, int value) { template double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n, size_t k, span loops) { - const bool transpose_a = kernel.flags & dot_flag::transpose_a; + const bool pack_a = kernel.flags & dot_flag::transpose_a; const size_t tile_m = kernel.tile_m; const size_t tile_n = kernel.tile_n; @@ -144,25 +145,21 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n, c.fill(0); b = b.crop_padding({0, 0}, {b.extent(0) - k, b.extent(1) - n}); - if (transpose_a) { - // This mangles the data, but we don't care here. - a = a.reshape({k / tile_k, m * tile_k}); - } + a = pack_a ? transpose_a(a, tile_m, tile_k) : a; auto kernel_wrapper = [&](size_t m, size_t n, size_t k, const void* a_ptr, const void* b_ptr, size_t init_c_stride_m, const void* init_c, void* c_ptr) { - kernel.kernel(m, n, 1, 1, k, a.stride(0) * sizeof(TA), 0, 0, a_ptr, 0, 0, - b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m, init_c, - c.stride(0) * sizeof(TC), c_ptr); + kernel.kernel(m, n, 1, 1, k, + a.stride(0) * sizeof(TA) / (pack_a ? tile_k : 1), 0, 0, a_ptr, + 0, 0, b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m, + init_c, c.stride(0) * sizeof(TC), c_ptr); }; - const size_t a_stride_m = transpose_a - ? kernel.tile_k * sizeof(TA) / a_elem_count - : a.stride(0) * sizeof(TA); - const size_t a_stride_k = transpose_a - ? a.stride(0) * sizeof(TA) / kernel.tile_k - : a.stride(1) * sizeof(TA) / a_elem_count; + const size_t a_stride_m = pack_a ? kernel.tile_k * sizeof(TA) / a_elem_count + : a.stride(0) * sizeof(TA); + const size_t a_stride_k = pack_a ? a.stride(0) * sizeof(TA) / kernel.tile_k + : a.stride(1) * sizeof(TA) / a_elem_count; const size_t b_stride_k = b.stride(0) * sizeof(TB); const size_t b_stride_n = b.stride(1) * sizeof(TB) / b_elem_count; const size_t c_stride_m = c.stride(0) * sizeof(TC); diff --git a/ynnpack/kernels/dot/test.cc b/ynnpack/kernels/dot/test.cc index 70d38cf62b3..712740e8052 100644 --- a/ynnpack/kernels/dot/test.cc +++ b/ynnpack/kernels/dot/test.cc @@ -122,11 +122,11 @@ void TestMatMul(AT, BT, CT, const DotShape& shape, const KernelInfo& kernel, Tensor packed_b = unpacked_b ? b : pack_b(b, tile_k, tile_n); Tensor packed_a = pack_a ? transpose_a(a, tile_m, tile_k) : a; - kernel.kernel(m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT), 0, 0, - packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, - packed_b.base(), c.stride(0) * sizeof(CT), - init_zero ? nullptr : c.base(), c.stride(0) * sizeof(CT), - c.base()); + kernel.kernel( + m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), 0, + 0, packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k, + packed_b.base(), c.stride(0) * sizeof(CT), init_zero ? nullptr : c.base(), + c.stride(0) * sizeof(CT), c.base()); // Verify results. Reference(a, b, expected); @@ -226,7 +226,8 @@ void TestConv2D(AT, BT, CT, const KernelInfo& kernel) { {0, 0, 0, b.extent(3) - co / B_info::element_count()}); kernel.kernel( - w, co, kh, kw, ci, packed_a.stride(0) * sizeof(AT), + w, co, kh, kw, ci, + packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), packed_a.stride(1) * sizeof(AT), packed_a.stride(2) * sizeof(AT), packed_a.base(), packed_b.stride(0) * sizeof(BT), packed_b.stride(1) * sizeof(BT), diff --git a/ynnpack/subgraph/dot.cc b/ynnpack/subgraph/dot.cc index ba6190a1ad9..8349790dac8 100644 --- a/ynnpack/subgraph/dot.cc +++ b/ynnpack/subgraph/dot.cc @@ -64,14 +64,16 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const slinky::dim& dummy_dim = slinky::dim::broadcast(); // Learn what we need to know about m, n, k1, k2, k3 before slicing them. + const int a_k1_dim = transposed_a ? 1 : 0; const slinky::dim& init_c_m = init_c.dim(1); const slinky::dim& init_c_n = init_c.dim(0); const slinky::dim& c_m = c.dim(1); const slinky::dim& c_n = c.dim(0); - const slinky::dim& a_k1 = a.dim(0); - const slinky::dim& a_k2 = num_k_dims >= 2 ? a.dim(1) : dummy_dim; - const slinky::dim& a_k3 = num_k_dims >= 3 ? a.dim(2) : dummy_dim; - const slinky::dim& a_m = a.dim(num_k_dims); + const slinky::dim& a_k1i = transposed_a ? a.dim(0) : dummy_dim; + const slinky::dim& a_k1o = a.dim(a_k1_dim); + const slinky::dim& a_k2 = num_k_dims >= 2 ? a.dim(a_k1_dim + 1) : dummy_dim; + const slinky::dim& a_k3 = num_k_dims >= 3 ? a.dim(a_k1_dim + 2) : dummy_dim; + const slinky::dim& a_m = a.dim(a_k1_dim + num_k_dims); const slinky::dim& b_k1i = b.dim(0); const slinky::dim& b_ni = b.dim(1); const slinky::dim& b_k1o = b.dim(2); @@ -81,15 +83,15 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const int b_type_element_count = type_element_count(type.b); const index_t tile_k = b_k1i.extent() * b_type_element_count; + // If a is transposed, then the k dimension has been reshaped to have tile_k + // values in each element. + const index_t a_tile_k = a_k1i.extent(); const index_t a_stride_m = a_m.stride(); const index_t a_stride_k3 = a_k3.stride(); const index_t a_stride_k2 = a_k2.stride(); - const index_t a_stride_k1 = a_k1.stride(); - // If a is transposed, then the k dimension has been reshaped to have tile_k - // values in each element. - const index_t a_tile_k = transposed_a ? tile_k : 1; - const index_t k1 = (a_k1.extent() * a_tile_k) & ~(tile_k - 1); - const index_t k1_tail = (a_k1.extent() * a_tile_k) & (tile_k - 1); + const index_t a_stride_k1 = a_k1o.stride() / a_tile_k; + const index_t k1 = (a_k1o.extent() * a_tile_k) & ~(tile_k - 1); + const index_t k1_tail = (a_k1o.extent() * a_tile_k) & (tile_k - 1); const index_t k2 = a_k2.extent(); const index_t k3 = a_k3.extent(); const index_t block_n = pack_b ? b_ni.extent() : c_n.extent(); @@ -124,16 +126,22 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, dot_packed_shape packed_shape; packed_shape.block_n = block_n; packed_shape.tile_k = tile_k; + std::optional require_transpose_a = std::make_optional(transposed_a); + if (a_stride_m == a_stride_k1) { + // If the stride of m and k1 is the same (i.e. A is a vector of tile_k + // values), then we don't care if the kernel is transposed or not. + require_transpose_a = std::nullopt; + } dot_kernel kernel = get_dot_kernel( - type, shape, &packed_shape, consistent_arithmetic, - a_stride_m == a_stride_k1 ? std::nullopt - : std::make_optional(transposed_a)); + type, shape, &packed_shape, consistent_arithmetic, require_transpose_a); assert(kernel.kernel); assert(tile_k == kernel.tile_k); const index_t block_m = kernel.block_m; const index_t block_k = kernel.block_k; - assert(a_k1.min() == 0); + assert(a_k1i.min() == 0); + assert(a_k1o.min() == 0); + assert(a_tile_k == 1 || a_k1i.stride() == a.elem_size); assert(a_k2.min() == 0); assert(a_k3.min() == 0); assert(b_k1i.min() == 0); @@ -150,7 +158,8 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, assert(!c_m.is_folded()); assert(!c_n.is_folded()); assert(!a_m.is_folded(c_m.min(), c_m.max())); - assert(!a_k1.is_folded()); + assert(!a_k1i.is_folded()); + assert(!a_k1o.is_folded()); assert(!a_k2.is_folded()); assert(!a_k3.is_folded()); assert(!b_k1o.is_folded()); @@ -174,7 +183,7 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, // `for_each_element` below handles the batch dimensions, we handle the loop // over m, and the kernel handles the rest (n, k1, k2, k3). We need to slice // off these dimensions so we can handle them. - for (size_t i = 0; i < num_k_dims; ++i) { + for (size_t i = 0; i < a_k1_dim + num_k_dims; ++i) { a.slice(0); } a.slice(0, c_m.min()); @@ -200,8 +209,9 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, const index_t a_stride = transposed_a ? a_stride_k1 : a_stride_m; // The kernels assume that the column dimension of a is stride 1 element. - assert(transposed_a ? (a_m.extent() == 1 || a_stride_m == a.elem_size) - : (a_k1.extent() == 1 || a_stride_k1 == a.elem_size)); + assert(transposed_a + ? (a_m.extent() == 1 || a_stride_m == a.elem_size * a_tile_k) + : (a_k1o.extent() == 1 || a_stride_k1 == a.elem_size)); auto call_kernel = [=, kernel = kernel.kernel]( index_t m, index_t n, index_t k1, const void* a, @@ -227,7 +237,7 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, slinky::for_each_element( [&](void* c, const void* a, const void* b, const void* init_c) { run_dot(loops, c_m.extent(), c_n.extent(), k1, block_m, block_n, - block_k, a_stride_m, a_stride_k1 / a_tile_k, a, b_stride_k1, + block_k, a_stride_m, a_stride_k1, a, b_stride_k1, b_stride_n, b, init_c_stride_m, init_c, c_stride_m, c_stride_n, c, call_kernel); }, @@ -286,9 +296,9 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a, a = offset_bytes(a, a_stride_k1 * k1); b = offset_bytes(b, b_stride_k1 * k1); run_dot(loops, c_m.extent(), c_n.extent(), k1_tail, block_m, - block_n, block_k, a_stride_m, a_stride_k1, a, b_stride_k1, - b_stride_n, b, tail_init_c_stride_m, init_c, c_stride_m, - c_stride_n, c, call_kernel_tail); + block_n, block_k, a_stride_m, /*a_stride_k1=*/0, a, + /*b_stride_k1=*/0, b_stride_n, b, tail_init_c_stride_m, + init_c, c_stride_m, c_stride_n, c, call_kernel_tail); }, c, a, b, init_c); } @@ -474,15 +484,17 @@ uint32_t define_pack_b(ynn_subgraph_t subgraph, const dot_type& type, // interleaving `tile_k` rows at a time. // TODO(b/454146513): We should try to combine both pack_b and transpose_a into // a `split_transpose` op that can handle padding, split, and transpose. -auto make_transpose_a_impl(index_t tile_k, int m_dim) { +auto make_transpose_a_impl(int m_dim) { constexpr size_t max_rank = YNN_MAX_TENSOR_RANK + ynn_internal_extra_dims; - return [tile_k, m_dim](slinky::buffer input, - slinky::buffer output) -> index_t { + return [m_dim](slinky::buffer input, + slinky::buffer output) -> index_t { const slinky::dim& input_k = input.dim(0); const slinky::dim& input_m = input.dim(m_dim); - const slinky::dim& output_ko = output.dim(0); - const slinky::dim& output_m = output.dim(m_dim); + const slinky::dim& output_ki = output.dim(0); + const slinky::dim& output_ko = output.dim(1); + const slinky::dim& output_m = output.dim(m_dim + 1); + const index_t tile_k = output_ki.extent(); const index_t elem_size = input.elem_size; assert(output_m.extent() == 1 || output_m.stride() == elem_size * tile_k); (void)output_m; @@ -504,7 +516,7 @@ auto make_transpose_a_impl(index_t tile_k, int m_dim) { input.slice(0); input.slice(m_dim - 1, output_m.min()); - output.slice({0, static_cast(m_dim)}); + output.slice({0, 1, static_cast(m_dim + 1)}); slinky::for_each_element( [&](void* output, const void* input) { @@ -530,6 +542,7 @@ void define_transpose_a(ynn_subgraph& subgraph, ynn_node& node, index_t tile_k, output.extents = a.extents; output.extents[0] = slinky::simplify(slinky::ceil_div(k, tile_k)); + output.extents.insert(output.extents.begin(), tile_k); node.inputs = {input_a_id}; node.outputs = {output.id}; @@ -541,26 +554,29 @@ void define_transpose_a(ynn_subgraph& subgraph, ynn_node& node, index_t tile_k, const ynn_runtime_value& input = runtime.value(node.inputs[0]); ynn_runtime_value& output = runtime.value(node.outputs[0]); - slinky::expr elem_size = input.buffer->elem_size() * tile_k; + slinky::expr elem_size = input.buffer->elem_size(); output.make_buffer(runtime, elem_size); - output.buffer->dim(m_dim).stride = elem_size; - output.buffer->dim(0).stride = - elem_size * output.buffer->dim(m_dim).extent(); + output.buffer->dim(0).stride = elem_size; + output.buffer->dim(m_dim + 1).stride = elem_size * tile_k; + output.buffer->dim(1).stride = + elem_size * tile_k * output.buffer->dim(m_dim + 1).extent(); + // Don't allow folding of dimensions we transpose. - output.buffer->dim(m_dim).fold_factor = slinky::dim::unfolded; output.buffer->dim(0).fold_factor = slinky::dim::unfolded; + output.buffer->dim(m_dim + 1).fold_factor = slinky::dim::unfolded; + output.buffer->dim(1).fold_factor = slinky::dim::unfolded; // Split + Transpose std::vector dims = runtime.globals.make_dims(output.buffer->rank()); - slinky::expr ko = dims[0]; + slinky::expr ko = dims[1]; slinky::func::input func_input = {input.buffer}; func_input.bounds = { slinky::min_extent(ko * tile_k, tile_k), }; - for (size_t i = 1; i < dims.size(); ++i) { + for (size_t i = 2; i < dims.size(); ++i) { func_input.bounds.push_back(slinky::point(dims[i])); } // This transpose handles padding the input up to tile_k. @@ -570,7 +586,7 @@ void define_transpose_a(ynn_subgraph& subgraph, ynn_node& node, index_t tile_k, slinky::call_stmt::attributes attrs; attrs.name = "transpose_a"; - auto func = slinky::func::make(make_transpose_a_impl(tile_k, m_dim), + auto func = slinky::func::make(make_transpose_a_impl(m_dim), {std::move(func_input)}, {{output.buffer, dims}}, std::move(attrs)); @@ -958,22 +974,24 @@ ynn_status define_dot(ynn_subgraph& subgraph, size_t num_k_dims, slinky::var j = dims[0]; // A: We need all of the k dims, i is elementwise. - slinky::box_expr a_bounds(std::min(input_a.rank(), num_k_dims)); + const int num_a_k_dims = num_k_dims + (transpose_a ? 1 : 0); + slinky::box_expr a_bounds(std::min(input_a.rank(), num_a_k_dims)); for (size_t i = 0; i < a_bounds.size(); ++i) { a_bounds[i] = all_bounds(input_a.extent(i)); } // B: We need all of the k dims, j is elementwise. j has been split into // two dimensions. - slinky::box_expr b_bounds(num_k_dims + 3); + const int num_b_k_dims = num_k_dims + 3; + slinky::box_expr b_bounds(num_b_k_dims); b_bounds[0] = all_bounds(packed_b.extent(0)); // ki b_bounds[1] = all_bounds(packed_b.extent(1)); // ni b_bounds[2] = all_bounds(packed_b.extent(2)); // ko // When we split a packed dimension, the inner part of the split remains // packed, but the outer part is not. b_bounds[3] = slinky::point(j) / packed_b.extent(1); - for (size_t i = 1; i < num_k_dims; ++i) { - b_bounds[i + 3] = all_bounds(packed_b.extent(i + 3)); + for (size_t i = 4; i < num_b_k_dims; ++i) { + b_bounds[i] = all_bounds(packed_b.extent(i)); } // C: Elementwise @@ -984,9 +1002,9 @@ ynn_status define_dot(ynn_subgraph& subgraph, size_t num_k_dims, // Batch dims are elementwise too. for (size_t i = 1; i < dims.size(); ++i) { - if (i + num_k_dims - 1 < input_a.rank()) { + if (i + num_a_k_dims - 1 < input_a.rank()) { a_bounds.push_back( - elementwise_bounds(dims[i], input_a.extent(i + num_k_dims - 1))); + elementwise_bounds(dims[i], input_a.extent(i + num_a_k_dims - 1))); } if (i >= 2 && i + 2 + num_k_dims - 1 < packed_b.rank()) { b_bounds.push_back(elementwise_bounds(