Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ynnpack/kernels/dot/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ cc_test(
malloc = ynn_binary_malloc(),
deps = [
":schedule",
"//ynnpack/base",
] + ynn_test_deps(),
)

Expand Down
10 changes: 8 additions & 2 deletions ynnpack/kernels/dot/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@
namespace ynn {

span<dot_loop> schedule_dot(span<const size_t> cache_sizes, size_t m, size_t n,
size_t k1, size_t k2, size_t k3, size_t block_m,
span<const size_t> ks, size_t block_m,
size_t block_n, size_t block_k, size_t a_elem_size,
size_t b_elem_size, dot_loop* storage) {
dot_loop* begin = storage;
dot_loop* loop = begin;

size_t k1 = ks[0];
size_t k2 = 1;
for (size_t i = 1; i < ks.size(); ++i) {
k2 *= ks[i];
}

// When we make a loop in a dimension, the extent of that dimension becomes
// the step size of that loop.
auto make_m_loop = [&](size_t blocks) {
Expand All @@ -41,7 +47,7 @@ span<dot_loop> schedule_dot(span<const size_t> cache_sizes, size_t m, size_t n,
for (size_t cache_size : cache_sizes) {
// TODO(b/447988052): We can be way smarter about this than we are now.
make_k_loop(
floor_div(cache_size, k2 * k3 * block_n * b_elem_size * block_k));
floor_div(cache_size, k2 * block_n * b_elem_size * block_k));
if (n * b_elem_size <= m * a_elem_size) {
// Tiles of B are smaller than tiles of A, we should assume B fits in
// cache.
Expand Down
102 changes: 59 additions & 43 deletions ynnpack/kernels/dot/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,21 @@ struct dot_loop {
// maximized for each cache in `cache_sizes`. `storage` must have room for at
// most 3 loops per cache size.
span<dot_loop> schedule_dot(span<const size_t> cache_sizes, size_t m, size_t n,
size_t k1, size_t k2, size_t k3, size_t block_m,
span<const size_t> ks, size_t block_m,
size_t block_n, size_t block_k, size_t a_elem_size,
size_t b_elem_size, dot_loop* storage);

// Block a dot's m dimension, calling f at each block.
template <typename DotFn>
void block_dot_m(ptrdiff_t m, size_t n, size_t k, ptrdiff_t block_m,
size_t a_stride_m, const void* a, const void* b,
void block_dot_m(ptrdiff_t m, size_t n, span<const size_t> ks,
ptrdiff_t block_m, size_t a_stride_m,
span<const size_t> a_k_strides, const void* a,
span<const size_t> b_k_strides, const void* b,
size_t init_c_stride_m, const void* init_c, size_t c_stride_m,
size_t c_stride_n, void* c, DotFn f) {
do {
f(std::min(m, block_m), n, k, a, b, init_c_stride_m, init_c, c);
f(std::min(m, block_m), n, ks, a, a_stride_m, a_k_strides, b, b_k_strides,
init_c_stride_m, init_c, c);

m -= block_m;
if (init_c) init_c = offset_bytes(init_c, init_c_stride_m * block_m);
Expand All @@ -52,12 +55,15 @@ void block_dot_m(ptrdiff_t m, size_t n, size_t k, ptrdiff_t block_m,

// Block a dot's n dimension, calling f at each block.
template <typename DotFn>
void block_dot_n(size_t m, ptrdiff_t n, size_t k, ptrdiff_t block_n,
const void* a, size_t b_stride_n, const void* b,
size_t init_c_stride_m, const void* init_c, size_t c_stride_m,
size_t c_stride_n, void* c, DotFn f) {
void block_dot_n(size_t m, ptrdiff_t n, span<const size_t> ks,
ptrdiff_t block_n, size_t a_stride_m,
span<const size_t> a_k_strides, const void* a,
size_t b_stride_n, span<const size_t> b_k_strides,
const void* b, size_t init_c_stride_m, const void* init_c,
size_t c_stride_m, size_t c_stride_n, void* c, DotFn f) {
do {
f(m, std::min(n, block_n), k, a, b, init_c_stride_m, init_c, c);
f(m, std::min(n, block_n), ks, a, a_stride_m, a_k_strides, b, b_k_strides,
init_c_stride_m, init_c, c);

n -= block_n;
if (init_c) init_c = offset_bytes(init_c, c_stride_n * block_n);
Expand All @@ -68,12 +74,18 @@ void block_dot_n(size_t m, ptrdiff_t n, size_t k, ptrdiff_t block_n,

// Block a dot's k dimension, calling f at each block.
template <typename DotFn>
void block_dot_k(size_t m, size_t n, ptrdiff_t k, ptrdiff_t block_k,
size_t a_stride_k, const void* a, size_t b_stride_k,
const void* b, size_t init_c_stride_m, const void* init_c,
size_t c_stride_m, size_t c_stride_n, void* c, DotFn f) {
void block_dot_k(size_t m, size_t n, span<const size_t> ks, ptrdiff_t block_k,
size_t a_stride_m, span<const size_t> a_k_strides,
const void* a, span<const size_t> b_k_strides, const void* b,
size_t init_c_stride_m, const void* init_c, size_t c_stride_m,
size_t c_stride_n, void* c, DotFn f) {
ptrdiff_t k = ks[0];
size_t k_block[3];
std::copy(ks.begin(), ks.end(), k_block);
do {
f(m, n, std::min(k, block_k), a, b, init_c_stride_m, init_c, c);
k_block[0] = std::min(k, block_k);
f(m, n, {k_block, ks.size()}, a, a_stride_m, a_k_strides, b, b_k_strides,
init_c_stride_m, init_c, c);

// Splitting k requires care for the initializer. The dot kernels read and
// write from a separate buffer, so for each tile that we process in k, the
Expand All @@ -90,18 +102,18 @@ void block_dot_k(size_t m, size_t n, ptrdiff_t k, ptrdiff_t block_k,
init_c_stride_m = c_stride_m;

k -= block_k;
b = offset_bytes(b, b_stride_k * block_k);
a = offset_bytes(a, a_stride_k * block_k);
b = offset_bytes(b, b_k_strides[0] * block_k);
a = offset_bytes(a, a_k_strides[0] * block_k);
} while (k > 0);
}

template <typename DotFn>
void run_dot(span<dot_loop> loops, size_t m, size_t n, size_t k, size_t block_m,
size_t block_n, size_t block_k, size_t a_stride_m,
size_t a_stride_k, const void* a, size_t b_stride_k,
size_t b_stride_n, const void* b, size_t init_c_stride_m,
const void* init_c, size_t c_stride_m, size_t c_stride_n, void* c,
DotFn f) {
void run_dot(span<dot_loop> loops, size_t m, size_t n, span<const size_t> ks,
size_t block_m, size_t block_n, size_t block_k, size_t a_stride_m,
span<const size_t> a_k_strides, const void* a,
span<const size_t> b_k_strides, size_t b_stride_n, const void* b,
size_t init_c_stride_m, const void* init_c, size_t c_stride_m,
size_t c_stride_n, void* c, DotFn f) {
assert(!loops.empty());
const dot_loop loop = loops.front();
loops = loops.subspan(1);
Expand All @@ -110,40 +122,44 @@ void run_dot(span<dot_loop> loops, size_t m, size_t n, size_t k, size_t block_m,
// There are no more loops after this one.
switch (loop.dim) {
case dot_loop::m:
return block_dot_m(m, n, k, block_m * loop.blocks, a_stride_m, a, b,
init_c_stride_m, init_c, c_stride_m, c_stride_n, c,
f);
return block_dot_m(m, n, ks, block_m * loop.blocks, a_stride_m,
a_k_strides, a, b_k_strides, b, init_c_stride_m,
init_c, c_stride_m, c_stride_n, c, f);
case dot_loop::n:
return block_dot_n(m, n, k, block_n * loop.blocks, a, b_stride_n, b,
return block_dot_n(m, n, ks, block_n * loop.blocks, a_stride_m,
a_k_strides, a, b_stride_n, b_k_strides, b,
init_c_stride_m, init_c, c_stride_m, c_stride_n, c,
f);
case dot_loop::k:
return block_dot_k(m, n, k, block_k * loop.blocks, a_stride_k, a,
b_stride_k, b, init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, f);
return block_dot_k(m, n, ks, block_k * loop.blocks, a_stride_m,
a_k_strides, a, b_k_strides, b, init_c_stride_m,
init_c, c_stride_m, c_stride_n, c, f);
}
} else {
// Recursively call `run_dot` with the subsequent loops.
auto recursive_f = [=](size_t m, size_t n, size_t k, const void* a,
const void* b, size_t init_c_stride_m,
const void* init_c, void* c) {
run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k,
a, b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, f);
};
auto recursive_f =
[=](size_t m, size_t n, span<const size_t> ks, const void* a,
size_t a_stride_m, span<const size_t> a_k_strides, const void* b,
span<const size_t> b_k_strides, size_t init_c_stride_m,
const void* init_c, void* c) {
run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m,
a_k_strides, a, b_k_strides, b_stride_n, b, init_c_stride_m,
init_c, c_stride_m, c_stride_n, c, f);
};
switch (loop.dim) {
case dot_loop::m:
return block_dot_m(m, n, k, block_m * loop.blocks, a_stride_m, a, b,
init_c_stride_m, init_c, c_stride_m, c_stride_n, c,
recursive_f);
return block_dot_m(m, n, ks, block_m * loop.blocks, a_stride_m,
a_k_strides, a, b_k_strides, b, init_c_stride_m,
init_c, c_stride_m, c_stride_n, c, recursive_f);
case dot_loop::n:
return block_dot_n(m, n, k, block_n * loop.blocks, a, b_stride_n, b,
return block_dot_n(m, n, ks, block_n * loop.blocks, a_stride_m,
a_k_strides, a, b_stride_n, b_k_strides, b,
init_c_stride_m, init_c, c_stride_m, c_stride_n, c,
recursive_f);
case dot_loop::k:
return block_dot_k(m, n, k, block_k * loop.blocks, a_stride_k, a,
b_stride_k, b, init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, recursive_f);
return block_dot_k(m, n, ks, block_k * loop.blocks, a_stride_m,
a_k_strides, a, b_k_strides, b, init_c_stride_m,
init_c, c_stride_m, c_stride_n, c, recursive_f);
}
}
}
Expand Down
29 changes: 18 additions & 11 deletions ynnpack/kernels/dot/schedule_bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,16 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n,

a = pack_a ? transpose_a(a, tile_m, tile_k) : a;

auto kernel_wrapper = [&](size_t m, size_t n, size_t k, const void* a_ptr,
const void* b_ptr, size_t init_c_stride_m,
const void* init_c, void* c_ptr) {
kernel.kernel(m, n, 1, 1, k,
a.stride(0) * sizeof(TA) / (pack_a ? tile_k : 1), 0, 0, a_ptr,
0, 0, b.stride(0) * sizeof(TB), b_ptr, init_c_stride_m,
init_c, c.stride(0) * sizeof(TC), c_ptr);
};
auto kernel_wrapper =
[&](size_t m, size_t n, span<const size_t> k, const void* a_ptr,
size_t a_stride_m, span<const size_t> a_k_strides, const void* b_ptr,
span<const size_t> b_k_strides, size_t init_c_stride_m,
const void* init_c, void* c_ptr) {
kernel.kernel(m, n, k[2], k[1], k[0], a_stride_m,
a_k_strides[2], a_k_strides[1], a_ptr, b_k_strides[2],
b_k_strides[1], b_k_strides[0], b_ptr, init_c_stride_m,
init_c, c.stride(0) * sizeof(TC), c_ptr);
};

const size_t a_stride_m = pack_a ? kernel.tile_k * sizeof(TA) / a_elem_count
: a.stride(0) * sizeof(TA);
Expand All @@ -165,10 +167,15 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n,
const size_t c_stride_m = c.stride(0) * sizeof(TC);
const size_t c_stride_n = c.stride(1) * sizeof(TC);

const size_t ks[] = {k, 1, 1};
const size_t a_k_strides[] = {a_stride_k, 0, 0};
const size_t b_k_strides[] = {b_stride_k, 0, 0};

double t = benchmark([&]() {
run_dot(loops, m, n, k, kernel.block_m, kernel.block_n, kernel.block_k,
a_stride_m, a_stride_k, a.base(), b_stride_k, b_stride_n, b.base(),
0, nullptr, c_stride_m, c_stride_n, c.base(), kernel_wrapper);
run_dot(loops, m, n, ks, kernel.block_m, kernel.block_n, kernel.block_k,
a_stride_m, a_k_strides, a.base(), b_k_strides, b_stride_n,
b.base(), 0, nullptr, c_stride_m, c_stride_n, c.base(),
kernel_wrapper);
});
// Check that the kernel didn't compute the wrong thing. We assume the kernel
// is correct, but we have some logic here that needs validation too. We
Expand Down
46 changes: 31 additions & 15 deletions ynnpack/kernels/dot/schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "ynnpack/base/span.h"

using testing::ElementsAre;

Expand Down Expand Up @@ -84,17 +85,22 @@ dot_call dot_call_at(size_t m, size_t n, size_t k, size_t i, size_t j,
};

auto make_record_calls(std::vector<dot_call>& calls) {
return [&](size_t m, size_t n, size_t k, const void* a, const void* b,
size_t init_c_stride_m, const void* init_c,
const void* c) { calls.push_back({m, n, k, a, b, init_c, c}); };
return [&](size_t m, size_t n, span<const size_t> k, const void* a,
size_t a_stride_m, span<const size_t> a_k_strides, const void* b,
span<const size_t> b_k_strides, size_t init_c_stride_m,
const void* init_c,
const void* c) { calls.push_back({m, n, k[0], a, b, init_c, c}); };
}

TEST(run_dot, loop_m) {
const dot_loop loops[] = {{dot_loop::m, 1}};
const size_t ks[] = {k};
const size_t a_k_strides[] = {a_stride_k};
const size_t b_k_strides[] = {b_stride_k};

std::vector<dot_call> calls;
run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, a,
b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, a_k_strides,
a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, make_record_calls(calls));
ASSERT_THAT(calls,
ElementsAre(dot_call_at(block_m, n, k, 0 * block_m, 0, 0),
Expand All @@ -104,10 +110,13 @@ TEST(run_dot, loop_m) {

TEST(run_dot, loop_n) {
const dot_loop loops[] = {{dot_loop::n, 1}};
const size_t ks[] = {k};
const size_t a_k_strides[] = {a_stride_k};
const size_t b_k_strides[] = {b_stride_k};

std::vector<dot_call> calls;
run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, a,
b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, a_k_strides,
a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, make_record_calls(calls));

ASSERT_THAT(calls,
Expand All @@ -120,24 +129,31 @@ TEST(run_dot, loop_n) {

TEST(run_dot, loop_n_tail) {
const dot_loop loops[] = {{dot_loop::n, 2}};
const size_t ks[] = {k};
const size_t a_k_strides[] = {a_stride_k};
const size_t b_k_strides[] = {b_stride_k};

std::vector<dot_call> calls;
run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, a,
b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, a_k_strides,
a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, make_record_calls(calls));

ASSERT_THAT(calls,
ElementsAre(dot_call_at(m, 2 * block_n, k, 0, 0 * block_n, 0),
dot_call_at(m, 2 * block_n, k, 0, 2 * block_n, 0),
dot_call_at(m, block_n, k, 0, 4 * block_n, 0)));
ASSERT_THAT(
calls,
ElementsAre(dot_call_at(m, 2 * block_n, k, 0, 0 * block_n, 0),
dot_call_at(m, 2 * block_n, k, 0, 2 * block_n, 0),
dot_call_at(m, n - 4 * block_n, k, 0, 4 * block_n, 0)));
}

TEST(run_dot, loop_k) {
const dot_loop loops[] = {{dot_loop::k, 1}};
const size_t ks[] = {k};
const size_t a_k_strides[] = {a_stride_k};
const size_t b_k_strides[] = {b_stride_k};

std::vector<dot_call> calls;
run_dot(loops, m, n, k, block_m, block_n, block_k, a_stride_m, a_stride_k, a,
b_stride_k, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
run_dot(loops, m, n, ks, block_m, block_n, block_k, a_stride_m, a_k_strides,
a, b_k_strides, b_stride_n, b, init_c_stride_m, init_c, c_stride_m,
c_stride_n, c, make_record_calls(calls));

ASSERT_THAT(calls,
Expand Down
Loading
Loading