Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ynnpack/kernels/dot/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ cc_binary(
malloc = ynn_binary_malloc(),
deps = [
":dot",
":pack_test_tensor",
":schedule",
"//ynnpack/base",
"//ynnpack/base/test:buffer",
Expand Down
4 changes: 2 additions & 2 deletions ynnpack/kernels/dot/arm64_sme.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ __arm_new("za") __arm_locally_streaming void sme_dot(

k1 -= dot_factor;
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor);
}
k2 -= 1;
B_k2 = offset_bytes(B_k2, B_stride_k2);
Expand Down Expand Up @@ -208,7 +208,7 @@ __arm_new("za") __arm_locally_streaming void sme_dot(

k1 -= dot_factor;
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor);
}
k2 -= 1;
B_k2 = offset_bytes(B_k2, B_stride_k2);
Expand Down
4 changes: 2 additions & 2 deletions ynnpack/kernels/dot/arm64_sme2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ __arm_new("za") __arm_locally_streaming void sme2_dot(

k1 -= dot_factor;
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor);
}
k2 -= 1;
B_k2 = offset_bytes(B_k2, B_stride_k2);
Expand Down Expand Up @@ -183,7 +183,7 @@ __arm_new("za") __arm_locally_streaming void sme2_dot(

k1 -= dot_factor;
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
A_k1 = offset_bytes(A_k1, A_stride_m * dot_factor);
}
k2 -= 1;
B_k2 = offset_bytes(B_k2, B_stride_k2);
Expand Down
8 changes: 5 additions & 3 deletions ynnpack/kernels/dot/bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@ void dot(benchmark::State& state, uint64_t arch_flags, dot_kernel_fn kernel,
for (size_t i = 0; i < m; i += block_m) {
size_t m_i = std::min(block_m, m - i);
const void* a_i = transpose_a ? &a(0, i * tile_k) : &a(i, 0);
kernel(m_i, n, 1, 1, k, a.stride(0) * sizeof(TA), 0, 0, a_i, 0, 0,
b.stride(0) * sizeof(TB), b.base(), /*init_c_stride_m=*/0, nullptr,
c.stride(0) * sizeof(TC), &c(i, 0));
kernel(m_i, n, 1, 1, k,
a.stride(0) * sizeof(TA) / (transpose_a ? tile_k : 1), 0, 0, a_i,
0, 0, b.stride(0) * sizeof(TB), b.base(),
/*init_c_stride_m=*/0, nullptr, c.stride(0) * sizeof(TC),
&c(i, 0));
}
}

Expand Down
17 changes: 8 additions & 9 deletions ynnpack/kernels/dot/consistent_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,14 @@ void TestMatMul(AT, BT, CT, size_t k) {
// dot kernels require B's k and n dimensions to be aligned to tile_k,
// tile_n. The kernel might also require b to be packed (tile_k > 1).
Tensor<BT> packed_b = pack_b(b, tile_k, tile_n);
Tensor<AT> packed_a = (kernel.flags & dot_flag::transpose_a)
? transpose_a(a, tile_m, tile_k)
: a;

kernel.kernel(m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT), 0, 0,
packed_a.base(), 0, 0,
packed_b.stride(0) * sizeof(BT) / tile_k, packed_b.base(),
kernel_c.stride(0) * sizeof(CT), kernel_c.base(),
kernel_c.stride(0) * sizeof(CT), kernel_c.base());
const bool pack_a = kernel.flags & dot_flag::transpose_a;
Tensor<AT> packed_a = pack_a ? transpose_a(a, tile_m, tile_k) : a;

kernel.kernel(
m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1),
0, 0, packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k,
packed_b.base(), kernel_c.stride(0) * sizeof(CT), kernel_c.base(),
kernel_c.stride(0) * sizeof(CT), kernel_c.base());

if (c.base()) {
int finite = 0;
Expand Down
18 changes: 10 additions & 8 deletions ynnpack/kernels/dot/generator/dot_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def a_ptr(self, i, k1, ty=None):
# When we clamp, we need to align down to the nearest tile.
i = f"min({i}, (M - 1) & ~{self.tile_shape[0] - 1})" if i != 0 else i
if "dot_flag::transpose_a" in self.flags:
k1 //= self.tile_shape[2]
i = f"{i} * {self.tile_shape[2]}"
i, k1 = k1, i
offset = f"({i} * A_stride_m) + ({k1} * sizeof({self.a_type}))"
Expand Down Expand Up @@ -384,6 +383,12 @@ def loop_k1(self):
block_body = self.generate_block(self.block_shape[2])
tile_body = self.generate_block(self.tile_shape[2])

a_step = (
"A_stride_m"
if "dot_flag::transpose_a" in self.flags
else f"sizeof({self.a_type})"
)

if block_body == tile_body:
tile_body = None

Expand All @@ -397,11 +402,9 @@ def loop_k1(self):
f"B_k1_{j} = offset_bytes(B_k1_{j}, {self.block_shape[2]} *"
" B_stride_k1);\n"
)
if "dot_flag::transpose_a" in self.flags:
a_step = f"{self.block_shape[2]//self.tile_shape[2]} * A_stride_m"
else:
a_step = f"{self.block_shape[2]} * sizeof({self.a_type})"
block_body += f"A_k1 = offset_bytes(A_k1, {a_step});\n"
block_body += (
f"A_k1 = offset_bytes(A_k1, {self.block_shape[2]} * {a_step});\n"
)
result += indent(block_body, " ") + "\n"
if tile_body:
result += "}\n"
Expand All @@ -421,8 +424,7 @@ def loop_k1(self):
" B_stride_k1);\n"
)
tile_body += (
f"A_k1 = offset_bytes(A_k1, {self.tile_shape[2]} *"
f" sizeof({self.a_type}));\n"
f"A_k1 = offset_bytes(A_k1, {self.tile_shape[2]} * {a_step});\n"
)
result += indent(tile_body, " ") + "\n"
result += "}\n"
Expand Down
25 changes: 11 additions & 14 deletions ynnpack/kernels/dot/schedule_bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "ynnpack/base/test/util.h"
#include "ynnpack/base/type.h"
#include "ynnpack/kernels/dot/dot.h"
#include "ynnpack/kernels/dot/pack_test_tensor.h"
#include "ynnpack/kernels/dot/schedule.h"

namespace ynn {
Expand Down Expand Up @@ -119,7 +120,7 @@ void fill(T* data, size_t n, int value) {
template <typename TA, typename TB, typename TC>
double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n,
size_t k, span<dot_loop> loops) {
const bool transpose_a = kernel.flags & dot_flag::transpose_a;
const bool pack_a = kernel.flags & dot_flag::transpose_a;

const size_t tile_m = kernel.tile_m;
const size_t tile_n = kernel.tile_n;
Expand All @@ -144,25 +145,21 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n,
c.fill(0);
b = b.crop_padding({0, 0}, {b.extent(0) - k, b.extent(1) - n});

if (transpose_a) {
// This mangles the data, but we don't care here.
a = a.reshape({k / tile_k, m * tile_k});
}
a = pack_a ? transpose_a(a, tile_m, tile_k) : a;

auto kernel_wrapper = [&](size_t m, size_t n, size_t k, const void* a_ptr,
const void* b_ptr, size_t init_c_stride_m,
const void* init_c, void* c_ptr) {
kernel.kernel(m, n, 1, 1, k, a.stride(0) * sizeof(TA), 0, 0, a_ptr, 0, 0,
b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m, init_c,
c.stride(0) * sizeof(TC), c_ptr);
kernel.kernel(m, n, 1, 1, k,
a.stride(0) * sizeof(TA) / (pack_a ? tile_k : 1), 0, 0, a_ptr,
0, 0, b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m,
init_c, c.stride(0) * sizeof(TC), c_ptr);
};

const size_t a_stride_m = transpose_a
? kernel.tile_k * sizeof(TA) / a_elem_count
: a.stride(0) * sizeof(TA);
const size_t a_stride_k = transpose_a
? a.stride(0) * sizeof(TA) / kernel.tile_k
: a.stride(1) * sizeof(TA) / a_elem_count;
const size_t a_stride_m = pack_a ? kernel.tile_k * sizeof(TA) / a_elem_count
: a.stride(0) * sizeof(TA);
const size_t a_stride_k = pack_a ? a.stride(0) * sizeof(TA) / kernel.tile_k
: a.stride(1) * sizeof(TA) / a_elem_count;
const size_t b_stride_k = b.stride(0) * sizeof(TB);
const size_t b_stride_n = b.stride(1) * sizeof(TB) / b_elem_count;
const size_t c_stride_m = c.stride(0) * sizeof(TC);
Expand Down
13 changes: 7 additions & 6 deletions ynnpack/kernels/dot/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ void TestMatMul(AT, BT, CT, const DotShape& shape, const KernelInfo& kernel,
Tensor<BT> packed_b = unpacked_b ? b : pack_b(b, tile_k, tile_n);
Tensor<AT> packed_a = pack_a ? transpose_a(a, tile_m, tile_k) : a;

kernel.kernel(m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT), 0, 0,
packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k,
packed_b.base(), c.stride(0) * sizeof(CT),
init_zero ? nullptr : c.base(), c.stride(0) * sizeof(CT),
c.base());
kernel.kernel(
m, n, 1, 1, k, packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1), 0,
0, packed_a.base(), 0, 0, packed_b.stride(0) * sizeof(BT) / tile_k,
packed_b.base(), c.stride(0) * sizeof(CT), init_zero ? nullptr : c.base(),
c.stride(0) * sizeof(CT), c.base());

// Verify results.
Reference(a, b, expected);
Expand Down Expand Up @@ -226,7 +226,8 @@ void TestConv2D(AT, BT, CT, const KernelInfo& kernel) {
{0, 0, 0, b.extent(3) - co / B_info::element_count()});

kernel.kernel(
w, co, kh, kw, ci, packed_a.stride(0) * sizeof(AT),
w, co, kh, kw, ci,
packed_a.stride(0) * sizeof(AT) / (pack_a ? tile_k : 1),
packed_a.stride(1) * sizeof(AT), packed_a.stride(2) * sizeof(AT),
packed_a.base(), packed_b.stride(0) * sizeof(BT),
packed_b.stride(1) * sizeof(BT),
Expand Down
21 changes: 12 additions & 9 deletions ynnpack/subgraph/dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a,

const int b_type_element_count = type_element_count(type.b);
const index_t tile_k = b_k1i.extent() * b_type_element_count;
const index_t a_stride_m = a_m.stride();
const index_t a_stride_k3 = a_k3.stride();
const index_t a_stride_k2 = a_k2.stride();
const index_t a_stride_k1 = a_k1.stride();
// If a is transposed, then the k dimension has been reshaped to have tile_k
// values in each element.
const index_t a_tile_k = transposed_a ? tile_k : 1;
const index_t a_stride_m = a_m.stride();
const index_t a_stride_k3 = a_k3.stride();
const index_t a_stride_k2 = a_k2.stride();
const index_t a_stride_k1 = a_k1.stride() / a_tile_k;
const index_t k1 = (a_k1.extent() * a_tile_k) & ~(tile_k - 1);
const index_t k1_tail = (a_k1.extent() * a_tile_k) & (tile_k - 1);
const index_t k2 = a_k2.extent();
Expand Down Expand Up @@ -227,7 +227,7 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a,
slinky::for_each_element(
[&](void* c, const void* a, const void* b, const void* init_c) {
run_dot(loops, c_m.extent(), c_n.extent(), k1, block_m, block_n,
block_k, a_stride_m, a_stride_k1 / a_tile_k, a, b_stride_k1,
block_k, a_stride_m, a_stride_k1, a, b_stride_k1,
b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, call_kernel);
},
Expand Down Expand Up @@ -265,7 +265,8 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a,
K2 * a_stride_k2),
k1 * a_elem_size);
}
kernel.kernel(m, n, /*k3=*/1, /*k2=*/1, tile_k, a_padded_stride_m,
kernel.kernel(m, n, /*k3=*/1, /*k2=*/1, tile_k,
transposed_a ? a_elem_size : a_padded_stride_m,
/*a_stride_k3=*/0, /*a_stride_k2=*/0, a_padded,
/*b_stride_k3=*/0,
/*b_stride_k2=*/0, b_stride_k1,
Expand All @@ -285,10 +286,12 @@ auto make_dot_impl(dot_type type, bool consistent_arithmetic, bool transposed_a,
}
a = offset_bytes(a, a_stride_k1 * k1);
b = offset_bytes(b, b_stride_k1 * k1);
// We only expect to run one iteration of k here, so the k strides
// are irrelevant.
run_dot(loops, c_m.extent(), c_n.extent(), k1_tail, block_m,
block_n, block_k, a_stride_m, a_stride_k1, a, b_stride_k1,
b_stride_n, b, tail_init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, call_kernel_tail);
block_n, block_k, a_stride_m, /*a_stride_k1=*/0, a,
/*b_stride_k1=*/0, b_stride_n, b, tail_init_c_stride_m,
init_c, c_stride_m, c_stride_n, c, call_kernel_tail);
},
c, a, b, init_c);
}
Expand Down
Loading