Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ynnpack/kernels/binary/bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void bench(benchmark::State& state, uint64_t arch_flags,
for (auto _ : state) {
kernel(m, n, a.stride(0) * sizeof(A), a.stride(1) * sizeof(A), a.base(),
b.stride(0) * sizeof(B), b.stride(1) * sizeof(B), b.base(),
x.stride(0) * sizeof(X), x.base());
x.stride(0) * sizeof(X), x.base(), nullptr);
}

const size_t ops = m * n;
Expand Down
3 changes: 2 additions & 1 deletion ynnpack/kernels/binary/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ namespace {
template <typename T, typename Operator>
void binary_impl(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n,
const void* va, size_t stride_b_m, size_t stride_b_n,
const void* vb, size_t stride_x_m, void* vx) {
const void* vb, size_t stride_x_m, void* vx,
const binary_params*) {
auto a = reinterpret_cast<const T*>(va);
auto b = reinterpret_cast<const T*>(vb);
auto x = reinterpret_cast<T*>(vx);
Expand Down
10 changes: 8 additions & 2 deletions ynnpack/kernels/binary/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,21 @@

namespace ynn {

union binary_params {};

// The stride of dimension `n` for any operand must be 0 or the size of one
// element.
typedef void (*binary_kernel_fn)(size_t m, size_t n, size_t stride_a_m,
size_t stride_a_n, const void* a,
size_t stride_b_m, size_t stride_b_n,
const void* b, size_t stride_x_m, void* x);
const void* b, size_t stride_x_m, void* x,
const binary_params* params);

#define YNN_ELEMENTWISE_KERNEL(arch, name, op, type_a, type_b, type_c) \
void name(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, \
const void* a, size_t stride_b_m, size_t stride_b_n, \
const void* b, size_t stride_x_m, void* x);
const void* b, size_t stride_x_m, void* x, \
const binary_params* params);
#include "ynnpack/kernels/binary/kernels.inc"
#undef YNN_ELEMENTWISE_KERNEL

Expand All @@ -35,6 +39,8 @@ binary_kernel_fn get_binary_kernel(
ynn_binary_operator op, ynn_type type_a, ynn_type type_b, ynn_type type_x,
uint64_t supported_arch_flags = get_supported_arch_flags());

binary_params get_binary_params(ynn_binary_operator op);

} // namespace ynn

#endif // XNNPACK_YNNPACK_KERNELS_BINARY_H_
2 changes: 1 addition & 1 deletion ynnpack/kernels/binary/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void TestImpl(const KernelInfo& kernel_info, const OpInfo& op_info, size_t m,

kernel(m, n, a.stride(0) * sizeof(A), a.stride(1) * sizeof(A), a.base(),
b.stride(0) * sizeof(B), b.stride(1) * sizeof(B), b.base(),
x.stride(0) * sizeof(X), x.base());
x.stride(0) * sizeof(X), x.base(), nullptr);

check_results(op_info, a, b, x);
}
Expand Down
63 changes: 54 additions & 9 deletions ynnpack/kernels/elementwise/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,13 @@ class BroadcastMode(enum.Enum):
AUTO = 5


class Scalar:

def __init__(self, name, ty):
self.name = name
self.ty = ty


class Buffer:

def __init__(
Expand All @@ -713,6 +720,8 @@ def __init__(


buffer_args = []
scalar_args = []

op_name = "unknown"
code = ""

Expand Down Expand Up @@ -807,12 +816,16 @@ def __init__(
"""


def scalar(name, ty):
def params(*ps):
"""Decorator to add scalar parameters to the function."""
def actual_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
assert len(ps) > 0
scalar_args.extend(ps)
# fn_args.append((name, ty, 0))
args += (Var(name, ty),)
for p in ps:
args += (Var(p.name, p.ty),)
return func(*args, **kwargs)

return wrapper
Expand Down Expand Up @@ -983,6 +996,12 @@ def as_buffer(self, arg, buffers):
b = next((buf for buf in buffers if buf.name == arg.name), None)
return b

def is_scalar_arg(self, arg):
s = None
if isinstance(arg, Var):
s = next((s for s in scalar_args if s.name == arg.name), None)
return s

def compute_all_features(self, features, implied_features, all_features):
for feature in features:
if feature in all_features:
Expand Down Expand Up @@ -1158,6 +1177,8 @@ def begin_function(self, name, args):
f" base_{args[-1].name}"
)

arity = self.get_arity_string(args)
args_str.append(f"{self.indent()}const {arity}_params* params")
self.result += ",\n".join(args_str)
self.result += ") {\n"

Expand Down Expand Up @@ -1289,6 +1310,20 @@ def emit_constants(self, constants):
f" {self.legalize_op(v)}({v.args[0]});\n"
)

def emit_scalar_arguments(self, scalars, tile_width):
"""Emits scalar arguments."""

if scalars:
self.result += f"{self.indent()}assert(params != nullptr);\n"

for s in scalars:
self.result += (
f"{self.indent()}const"
f" {self.legalize_type(s.ty.with_lanes(tile_width))} {s.name} ="
f" {self.legalize_op(broadcast(Constant(s.ty, 0), tile_width))}(reinterpret_cast<const"
f" {op_name}_params*>(params)->{s.name});\n"
)

def emit_op(self, i, j, is_rem_width, buffers, constants, tile_width):
"""Emits a single operation."""
op = i[1]
Expand Down Expand Up @@ -1345,7 +1380,9 @@ def emit_op(self, i, j, is_rem_width, buffers, constants, tile_width):
else:
if isinstance(arg, Constant):
str_args.append(f"{arg}")
elif isinstance(arg, Var) and arg in constants:
elif isinstance(arg, Var) and (
arg in constants or (self.is_scalar_arg(arg) is not None)
):
str_args.append(f"{arg}{self.simd_suffix(op)}")
else:
str_args.append(f"{arg}_{j}{self.simd_suffix(op)}")
Expand Down Expand Up @@ -1570,6 +1607,7 @@ def compile(self, name, buffers, func, tile_shapes):
self.emit_asserts(buffers)

self.emit_constants(constants)
self.emit_scalar_arguments(scalar_args, tile_width)

self.handle_specialize(
ops,
Expand All @@ -1590,9 +1628,20 @@ def arch_flags(self):
def arch_string(self):
return "x86_" + "_".join([i.lower() for i in self.features])

def get_arity_string(self, buffers):
if len(buffers) == 4:
return "ternary"
elif len(buffers) == 3:
return "binary"
elif len(buffers) == 2:
return "unary"
else:
assert False, "Unsupported number of buffers."

def compile_function(self, name, fn, tile_shapes):
self.result = ""
buffer_args.clear()
scalar_args.clear()
global op_name
op_name = "unknown"
result = fn()
Expand All @@ -1613,12 +1662,8 @@ def compile_function(self, name, fn, tile_shapes):
)

src = '#include "ynnpack/kernels/'
if len(buffer_args) == 4:
src += "ternary/ternary.h"
elif len(buffer_args) == 3:
src += "binary/binary.h"
elif len(buffer_args) == 2:
src += "unary/unary.h"
arity = self.get_arity_string(buffer_args)
src += f"{arity}/{arity}.h"
src += '"\n'
src += "namespace ynn {\n"
src += self.compile(func_name, buffer_args, result, tile_shapes)
Expand Down
2 changes: 1 addition & 1 deletion ynnpack/kernels/ternary/bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void bench(benchmark::State& state, uint64_t arch_flags,
kernel(m, n, a.stride(0) * sizeof(A), a.stride(1) * sizeof(A), a.base(),
b.stride(0) * sizeof(B), b.stride(1) * sizeof(B), b.base(),
c.stride(0) * sizeof(C), c.stride(1) * sizeof(C), c.base(),
x.stride(0) * sizeof(X), x.base());
x.stride(0) * sizeof(X), x.base(), nullptr);
}

const size_t ops = m * n;
Expand Down
30 changes: 18 additions & 12 deletions ynnpack/kernels/ternary/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ template <typename A, typename X>
void quantize(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n,
const A* a, size_t stride_b_m, size_t stride_b_n, const float* b,
size_t stride_c_m, size_t stride_c_n, const int32_t* c,
size_t stride_x_m, X* x) {
size_t stride_x_m, X* x, const ternary_params* params) {
for (size_t i = 0; i < m; ++i) {
// There are 8 cases of broadcasting. Here, we only specialize for
// broadcasting b, because it permits lifting a division out of the loop.
Expand Down Expand Up @@ -53,7 +53,8 @@ template <typename A, typename X>
void dequantize(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n,
const A* a, size_t stride_b_m, size_t stride_b_n,
const int32_t* b, size_t stride_c_m, size_t stride_c_n,
const float* c, size_t stride_x_m, X* x) {
const float* c, size_t stride_x_m, X* x,
const ternary_params* params) {
for (size_t i = 0; i < m; ++i) {
for (size_t j = 0; j < n; ++j) {
const A a_j = *offset_bytes(a, j * stride_a_n);
Expand All @@ -74,58 +75,61 @@ void quantize_fp32_to_int8(size_t m, size_t n, size_t stride_a_m,
size_t stride_a_n, const void* a, size_t stride_b_m,
size_t stride_b_n, const void* b, size_t stride_c_m,
size_t stride_c_n, const void* c, size_t stride_x_m,
void* x) {
void* x, const ternary_params* params) {
quantize(m, n, stride_a_m, stride_a_n, reinterpret_cast<const float*>(a),
stride_b_m, stride_b_n, reinterpret_cast<const float*>(b),
stride_c_m, stride_c_n, reinterpret_cast<const int32_t*>(c),
stride_x_m, reinterpret_cast<int8_t*>(x));
stride_x_m, reinterpret_cast<int8_t*>(x), params);
}

void quantize_fp32_to_uint8(size_t m, size_t n, size_t stride_a_m,
size_t stride_a_n, const void* a, size_t stride_b_m,
size_t stride_b_n, const void* b, size_t stride_c_m,
size_t stride_c_n, const void* c, size_t stride_x_m,
void* x) {
void* x, const ternary_params* params) {
quantize(m, n, stride_a_m, stride_a_n, reinterpret_cast<const float*>(a),
stride_b_m, stride_b_n, reinterpret_cast<const float*>(b),
stride_c_m, stride_c_n, reinterpret_cast<const int32_t*>(c),
stride_x_m, reinterpret_cast<uint8_t*>(x));
stride_x_m, reinterpret_cast<uint8_t*>(x), params);
}

void dequantize_int8_to_fp32(size_t m, size_t n, size_t stride_a_m,
size_t stride_a_n, const void* a,
size_t stride_b_m, size_t stride_b_n,
const void* b, size_t stride_c_m,
size_t stride_c_n, const void* c,
size_t stride_x_m, void* x) {
size_t stride_x_m, void* x,
const ternary_params* params) {
dequantize(m, n, stride_a_m, stride_a_n, reinterpret_cast<const int8_t*>(a),
stride_b_m, stride_b_n, reinterpret_cast<const int32_t*>(b),
stride_c_m, stride_c_n, reinterpret_cast<const float*>(c),
stride_x_m, reinterpret_cast<float*>(x));
stride_x_m, reinterpret_cast<float*>(x), params);
}

void dequantize_uint8_to_fp32(size_t m, size_t n, size_t stride_a_m,
size_t stride_a_n, const void* a,
size_t stride_b_m, size_t stride_b_n,
const void* b, size_t stride_c_m,
size_t stride_c_n, const void* c,
size_t stride_x_m, void* x) {
size_t stride_x_m, void* x,
const ternary_params* params) {
dequantize(m, n, stride_a_m, stride_a_n, reinterpret_cast<const uint8_t*>(a),
stride_b_m, stride_b_n, reinterpret_cast<const int32_t*>(b),
stride_c_m, stride_c_n, reinterpret_cast<const float*>(c),
stride_x_m, reinterpret_cast<float*>(x));
stride_x_m, reinterpret_cast<float*>(x), params);
}

void dequantize_int32_to_fp32(size_t m, size_t n, size_t stride_a_m,
size_t stride_a_n, const void* a,
size_t stride_b_m, size_t stride_b_n,
const void* b, size_t stride_c_m,
size_t stride_c_n, const void* c,
size_t stride_x_m, void* x) {
size_t stride_x_m, void* x,
const ternary_params* params) {
dequantize(m, n, stride_a_m, stride_a_n, reinterpret_cast<const int32_t*>(a),
stride_b_m, stride_b_n, reinterpret_cast<const int32_t*>(b),
stride_c_m, stride_c_n, reinterpret_cast<const float*>(c),
stride_x_m, reinterpret_cast<float*>(x));
stride_x_m, reinterpret_cast<float*>(x), params);
}

ternary_kernel_fn get_ternary_kernel(ternary_op op, ynn_type type_a,
Expand Down Expand Up @@ -164,4 +168,6 @@ const char* to_string(ternary_op op) {
return "unknown";
}

ternary_params get_ternary_params(ternary_op op) { return ternary_params{}; }

} // namespace ynn
10 changes: 8 additions & 2 deletions ynnpack/kernels/ternary/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,24 @@

namespace ynn {

union ternary_params {};

// The stride of dimension `n` for any operand must be 0 or the size of one
// element.
typedef void (*ternary_kernel_fn)(size_t m, size_t n, size_t stride_a_m,
size_t stride_a_n, const void* a,
size_t stride_b_m, size_t stride_b_n,
const void* b, size_t stride_c_m,
size_t stride_c_n, const void* c,
size_t stride_x_m, void* x);
size_t stride_x_m, void* x,
const ternary_params* params);

#define YNN_ELEMENTWISE_KERNEL(arch, name, op, type_a, type_b, type_c, type_x) \
void name(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, \
const void* a, size_t stride_b_m, size_t stride_b_n, \
const void* b, size_t stride_c_m, size_t stride_c_n, \
const void* c, size_t stride_x_m, void* x);
const void* c, size_t stride_x_m, void* x, \
const ternary_params* params);
#include "ynnpack/kernels/ternary/kernels.inc"
#undef YNN_ELEMENTWISE_KERNEL

Expand All @@ -51,6 +55,8 @@ ternary_kernel_fn get_ternary_kernel(ternary_op op, ynn_type type_a,
ynn_type type_b, ynn_type type_c,
ynn_type type_x);

ternary_params get_ternary_params(ternary_op op);

} // namespace ynn

#endif // XNNPACK_YNNPACK_KERNELS_TERNARY_H_
2 changes: 1 addition & 1 deletion ynnpack/kernels/ternary/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void TestImpl(const KernelInfo& kernel_info, const OpInfo& op_info, size_t m,
kernel(m, n, a.stride(0) * sizeof(A), a.stride(1) * sizeof(A), a.base(),
b.stride(0) * sizeof(B), b.stride(1) * sizeof(B), b.base(),
c.stride(0) * sizeof(C), c.stride(1) * sizeof(C), c.base(),
x.stride(0) * sizeof(X), x.base());
x.stride(0) * sizeof(X), x.base(), nullptr);

check_results(op_info, a, b, c, x, a_quantization, b_quantization,
c_quantization, x_quantization);
Expand Down
1 change: 1 addition & 0 deletions ynnpack/kernels/unary/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ cc_library(
hdrs = ["reference.h"],
visibility = ["//ynnpack:__subpackages__"],
deps = [
":unary",
"//ynnpack:ynnpack_h",
"//ynnpack/base",
"//ynnpack/base/test:tensor",
Expand Down
Loading
Loading