From dfa0f4b05fd83f99c15beee5b4abf2de9b11f0cb Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Thu, 2 Apr 2026 17:26:56 -0700 Subject: [PATCH] Add support for passing scalar parameters to elementwise kernels. PiperOrigin-RevId: 893786966 --- ynnpack/kernels/binary/bench.cc | 2 +- ynnpack/kernels/binary/binary.cc | 3 +- ynnpack/kernels/binary/binary.h | 10 +++- ynnpack/kernels/binary/test.cc | 2 +- ynnpack/kernels/elementwise/compiler.py | 63 +++++++++++++++++++---- ynnpack/kernels/ternary/bench.cc | 2 +- ynnpack/kernels/ternary/ternary.cc | 30 ++++++----- ynnpack/kernels/ternary/ternary.h | 10 +++- ynnpack/kernels/ternary/test.cc | 2 +- ynnpack/kernels/unary/BUILD | 1 + ynnpack/kernels/unary/bench.cc | 25 +++++---- ynnpack/kernels/unary/exp.py | 12 +++-- ynnpack/kernels/unary/reference.cc | 67 +++++++++---------------- ynnpack/kernels/unary/reference.h | 33 +++++++++++- ynnpack/kernels/unary/test.cc | 57 ++++++++++++++------- ynnpack/kernels/unary/unary.cc | 43 ++++++++++++++-- ynnpack/kernels/unary/unary.h | 15 +++++- ynnpack/subgraph/elementwise.cc | 63 ++++++++++++----------- ynnpack/subgraph/test/unary.cc | 5 +- 19 files changed, 298 insertions(+), 147 deletions(-) diff --git a/ynnpack/kernels/binary/bench.cc b/ynnpack/kernels/binary/bench.cc index ac0539f8cd0..2da391396a7 100644 --- a/ynnpack/kernels/binary/bench.cc +++ b/ynnpack/kernels/binary/bench.cc @@ -41,7 +41,7 @@ void bench(benchmark::State& state, uint64_t arch_flags, for (auto _ : state) { kernel(m, n, a.stride(0) * sizeof(A), a.stride(1) * sizeof(A), a.base(), b.stride(0) * sizeof(B), b.stride(1) * sizeof(B), b.base(), - x.stride(0) * sizeof(X), x.base()); + x.stride(0) * sizeof(X), x.base(), nullptr); } const size_t ops = m * n; diff --git a/ynnpack/kernels/binary/binary.cc b/ynnpack/kernels/binary/binary.cc index ef7e04d2999..487d616ea0f 100644 --- a/ynnpack/kernels/binary/binary.cc +++ b/ynnpack/kernels/binary/binary.cc @@ -29,7 +29,8 @@ namespace { template void binary_impl(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, const void* va, size_t stride_b_m, size_t stride_b_n, - const void* vb, size_t stride_x_m, void* vx) { + const void* vb, size_t stride_x_m, void* vx, + const binary_params*) { auto a = reinterpret_cast(va); auto b = reinterpret_cast(vb); auto x = reinterpret_cast(vx); diff --git a/ynnpack/kernels/binary/binary.h b/ynnpack/kernels/binary/binary.h index 4feab29e91c..9c96b8845b4 100644 --- a/ynnpack/kernels/binary/binary.h +++ b/ynnpack/kernels/binary/binary.h @@ -14,17 +14,21 @@ namespace ynn { +union binary_params {}; + // The stride of dimension `n` for any operand must be 0 or the size of one // element. typedef void (*binary_kernel_fn)(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, const void* a, size_t stride_b_m, size_t stride_b_n, - const void* b, size_t stride_x_m, void* x); + const void* b, size_t stride_x_m, void* x, + const binary_params* params); #define YNN_ELEMENTWISE_KERNEL(arch, name, op, type_a, type_b, type_c) \ void name(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, \ const void* a, size_t stride_b_m, size_t stride_b_n, \ - const void* b, size_t stride_x_m, void* x); + const void* b, size_t stride_x_m, void* x, \ + const binary_params* params); #include "ynnpack/kernels/binary/kernels.inc" #undef YNN_ELEMENTWISE_KERNEL @@ -35,6 +39,8 @@ binary_kernel_fn get_binary_kernel( ynn_binary_operator op, ynn_type type_a, ynn_type type_b, ynn_type type_x, uint64_t supported_arch_flags = get_supported_arch_flags()); +binary_params get_binary_params(ynn_binary_operator op); + } // namespace ynn #endif // XNNPACK_YNNPACK_KERNELS_BINARY_H_ diff --git a/ynnpack/kernels/binary/test.cc b/ynnpack/kernels/binary/test.cc index 78b181bdf58..7a6c465dd4c 100644 --- a/ynnpack/kernels/binary/test.cc +++ b/ynnpack/kernels/binary/test.cc @@ -96,7 +96,7 @@ void TestImpl(const KernelInfo& kernel_info, const OpInfo& op_info, size_t m, kernel(m, n, a.stride(0) * sizeof(A), a.stride(1) * sizeof(A), a.base(), b.stride(0) * sizeof(B), b.stride(1) * sizeof(B), b.base(), - x.stride(0) * sizeof(X), x.base()); + x.stride(0) * sizeof(X), x.base(), nullptr); check_results(op_info, a, b, x); } diff --git a/ynnpack/kernels/elementwise/compiler.py b/ynnpack/kernels/elementwise/compiler.py index afd47a8621c..ebfebd6d9bc 100644 --- a/ynnpack/kernels/elementwise/compiler.py +++ b/ynnpack/kernels/elementwise/compiler.py @@ -701,6 +701,13 @@ class BroadcastMode(enum.Enum): AUTO = 5 +class Scalar: + + def __init__(self, name, ty): + self.name = name + self.ty = ty + + class Buffer: def __init__( @@ -713,6 +720,8 @@ def __init__( buffer_args = [] +scalar_args = [] + op_name = "unknown" code = "" @@ -807,12 +816,16 @@ def __init__( """ -def scalar(name, ty): +def params(*ps): + """Decorator to add scalar parameters to the function.""" def actual_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): + assert len(ps) > 0 + scalar_args.extend(ps) # fn_args.append((name, ty, 0)) - args += (Var(name, ty),) + for p in ps: + args += (Var(p.name, p.ty),) return func(*args, **kwargs) return wrapper @@ -983,6 +996,12 @@ def as_buffer(self, arg, buffers): b = next((buf for buf in buffers if buf.name == arg.name), None) return b + def is_scalar_arg(self, arg): + s = None + if isinstance(arg, Var): + s = next((s for s in scalar_args if s.name == arg.name), None) + return s + def compute_all_features(self, features, implied_features, all_features): for feature in features: if feature in all_features: @@ -1158,6 +1177,8 @@ def begin_function(self, name, args): f" base_{args[-1].name}" ) + arity = self.get_arity_string(args) + args_str.append(f"{self.indent()}const {arity}_params* params") self.result += ",\n".join(args_str) self.result += ") {\n" @@ -1289,6 +1310,20 @@ def emit_constants(self, constants): f" {self.legalize_op(v)}({v.args[0]});\n" ) + def emit_scalar_arguments(self, scalars, tile_width): + """Emits scalar arguments.""" + + if scalars: + self.result += f"{self.indent()}assert(params != nullptr);\n" + + for s in scalars: + self.result += ( + f"{self.indent()}const" + f" {self.legalize_type(s.ty.with_lanes(tile_width))} {s.name} =" + f" {self.legalize_op(broadcast(Constant(s.ty, 0), tile_width))}(reinterpret_cast(params)->{s.name});\n" + ) + def emit_op(self, i, j, is_rem_width, buffers, constants, tile_width): """Emits a single operation.""" op = i[1] @@ -1345,7 +1380,9 @@ def emit_op(self, i, j, is_rem_width, buffers, constants, tile_width): else: if isinstance(arg, Constant): str_args.append(f"{arg}") - elif isinstance(arg, Var) and arg in constants: + elif isinstance(arg, Var) and ( + arg in constants or (self.is_scalar_arg(arg) is not None) + ): str_args.append(f"{arg}{self.simd_suffix(op)}") else: str_args.append(f"{arg}_{j}{self.simd_suffix(op)}") @@ -1570,6 +1607,7 @@ def compile(self, name, buffers, func, tile_shapes): self.emit_asserts(buffers) self.emit_constants(constants) + self.emit_scalar_arguments(scalar_args, tile_width) self.handle_specialize( ops, @@ -1590,9 +1628,20 @@ def arch_flags(self): def arch_string(self): return "x86_" + "_".join([i.lower() for i in self.features]) + def get_arity_string(self, buffers): + if len(buffers) == 4: + return "ternary" + elif len(buffers) == 3: + return "binary" + elif len(buffers) == 2: + return "unary" + else: + assert False, "Unsupported number of buffers." + def compile_function(self, name, fn, tile_shapes): self.result = "" buffer_args.clear() + scalar_args.clear() global op_name op_name = "unknown" result = fn() @@ -1613,12 +1662,8 @@ def compile_function(self, name, fn, tile_shapes): ) src = '#include "ynnpack/kernels/' - if len(buffer_args) == 4: - src += "ternary/ternary.h" - elif len(buffer_args) == 3: - src += "binary/binary.h" - elif len(buffer_args) == 2: - src += "unary/unary.h" + arity = self.get_arity_string(buffer_args) + src += f"{arity}/{arity}.h" src += '"\n' src += "namespace ynn {\n" src += self.compile(func_name, buffer_args, result, tile_shapes) diff --git a/ynnpack/kernels/ternary/bench.cc b/ynnpack/kernels/ternary/bench.cc index da4a2d780fd..7a1dcd53ac8 100644 --- a/ynnpack/kernels/ternary/bench.cc +++ b/ynnpack/kernels/ternary/bench.cc @@ -46,7 +46,7 @@ void bench(benchmark::State& state, uint64_t arch_flags, kernel(m, n, a.stride(0) * sizeof(A), a.stride(1) * sizeof(A), a.base(), b.stride(0) * sizeof(B), b.stride(1) * sizeof(B), b.base(), c.stride(0) * sizeof(C), c.stride(1) * sizeof(C), c.base(), - x.stride(0) * sizeof(X), x.base()); + x.stride(0) * sizeof(X), x.base(), nullptr); } const size_t ops = m * n; diff --git a/ynnpack/kernels/ternary/ternary.cc b/ynnpack/kernels/ternary/ternary.cc index 43741c37992..3bf3f30e5d7 100644 --- a/ynnpack/kernels/ternary/ternary.cc +++ b/ynnpack/kernels/ternary/ternary.cc @@ -24,7 +24,7 @@ template void quantize(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, const A* a, size_t stride_b_m, size_t stride_b_n, const float* b, size_t stride_c_m, size_t stride_c_n, const int32_t* c, - size_t stride_x_m, X* x) { + size_t stride_x_m, X* x, const ternary_params* params) { for (size_t i = 0; i < m; ++i) { // There are 8 cases of broadcasting. Here, we only specialize for // broadcasting b, because it permits lifting a division out of the loop. @@ -53,7 +53,8 @@ template void dequantize(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, const A* a, size_t stride_b_m, size_t stride_b_n, const int32_t* b, size_t stride_c_m, size_t stride_c_n, - const float* c, size_t stride_x_m, X* x) { + const float* c, size_t stride_x_m, X* x, + const ternary_params* params) { for (size_t i = 0; i < m; ++i) { for (size_t j = 0; j < n; ++j) { const A a_j = *offset_bytes(a, j * stride_a_n); @@ -74,22 +75,22 @@ void quantize_fp32_to_int8(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, const void* a, size_t stride_b_m, size_t stride_b_n, const void* b, size_t stride_c_m, size_t stride_c_n, const void* c, size_t stride_x_m, - void* x) { + void* x, const ternary_params* params) { quantize(m, n, stride_a_m, stride_a_n, reinterpret_cast(a), stride_b_m, stride_b_n, reinterpret_cast(b), stride_c_m, stride_c_n, reinterpret_cast(c), - stride_x_m, reinterpret_cast(x)); + stride_x_m, reinterpret_cast(x), params); } void quantize_fp32_to_uint8(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, const void* a, size_t stride_b_m, size_t stride_b_n, const void* b, size_t stride_c_m, size_t stride_c_n, const void* c, size_t stride_x_m, - void* x) { + void* x, const ternary_params* params) { quantize(m, n, stride_a_m, stride_a_n, reinterpret_cast(a), stride_b_m, stride_b_n, reinterpret_cast(b), stride_c_m, stride_c_n, reinterpret_cast(c), - stride_x_m, reinterpret_cast(x)); + stride_x_m, reinterpret_cast(x), params); } void dequantize_int8_to_fp32(size_t m, size_t n, size_t stride_a_m, @@ -97,11 +98,12 @@ void dequantize_int8_to_fp32(size_t m, size_t n, size_t stride_a_m, size_t stride_b_m, size_t stride_b_n, const void* b, size_t stride_c_m, size_t stride_c_n, const void* c, - size_t stride_x_m, void* x) { + size_t stride_x_m, void* x, + const ternary_params* params) { dequantize(m, n, stride_a_m, stride_a_n, reinterpret_cast(a), stride_b_m, stride_b_n, reinterpret_cast(b), stride_c_m, stride_c_n, reinterpret_cast(c), - stride_x_m, reinterpret_cast(x)); + stride_x_m, reinterpret_cast(x), params); } void dequantize_uint8_to_fp32(size_t m, size_t n, size_t stride_a_m, @@ -109,11 +111,12 @@ void dequantize_uint8_to_fp32(size_t m, size_t n, size_t stride_a_m, size_t stride_b_m, size_t stride_b_n, const void* b, size_t stride_c_m, size_t stride_c_n, const void* c, - size_t stride_x_m, void* x) { + size_t stride_x_m, void* x, + const ternary_params* params) { dequantize(m, n, stride_a_m, stride_a_n, reinterpret_cast(a), stride_b_m, stride_b_n, reinterpret_cast(b), stride_c_m, stride_c_n, reinterpret_cast(c), - stride_x_m, reinterpret_cast(x)); + stride_x_m, reinterpret_cast(x), params); } void dequantize_int32_to_fp32(size_t m, size_t n, size_t stride_a_m, @@ -121,11 +124,12 @@ void dequantize_int32_to_fp32(size_t m, size_t n, size_t stride_a_m, size_t stride_b_m, size_t stride_b_n, const void* b, size_t stride_c_m, size_t stride_c_n, const void* c, - size_t stride_x_m, void* x) { + size_t stride_x_m, void* x, + const ternary_params* params) { dequantize(m, n, stride_a_m, stride_a_n, reinterpret_cast(a), stride_b_m, stride_b_n, reinterpret_cast(b), stride_c_m, stride_c_n, reinterpret_cast(c), - stride_x_m, reinterpret_cast(x)); + stride_x_m, reinterpret_cast(x), params); } ternary_kernel_fn get_ternary_kernel(ternary_op op, ynn_type type_a, @@ -164,4 +168,6 @@ const char* to_string(ternary_op op) { return "unknown"; } +ternary_params get_ternary_params(ternary_op op) { return ternary_params{}; } + } // namespace ynn diff --git a/ynnpack/kernels/ternary/ternary.h b/ynnpack/kernels/ternary/ternary.h index 57e80880ff2..2b4376fa85c 100644 --- a/ynnpack/kernels/ternary/ternary.h +++ b/ynnpack/kernels/ternary/ternary.h @@ -14,6 +14,8 @@ namespace ynn { +union ternary_params {}; + // The stride of dimension `n` for any operand must be 0 or the size of one // element. typedef void (*ternary_kernel_fn)(size_t m, size_t n, size_t stride_a_m, @@ -21,13 +23,15 @@ typedef void (*ternary_kernel_fn)(size_t m, size_t n, size_t stride_a_m, size_t stride_b_m, size_t stride_b_n, const void* b, size_t stride_c_m, size_t stride_c_n, const void* c, - size_t stride_x_m, void* x); + size_t stride_x_m, void* x, + const ternary_params* params); #define YNN_ELEMENTWISE_KERNEL(arch, name, op, type_a, type_b, type_c, type_x) \ void name(size_t m, size_t n, size_t stride_a_m, size_t stride_a_n, \ const void* a, size_t stride_b_m, size_t stride_b_n, \ const void* b, size_t stride_c_m, size_t stride_c_n, \ - const void* c, size_t stride_x_m, void* x); + const void* c, size_t stride_x_m, void* x, \ + const ternary_params* params); #include "ynnpack/kernels/ternary/kernels.inc" #undef YNN_ELEMENTWISE_KERNEL @@ -51,6 +55,8 @@ ternary_kernel_fn get_ternary_kernel(ternary_op op, ynn_type type_a, ynn_type type_b, ynn_type type_c, ynn_type type_x); +ternary_params get_ternary_params(ternary_op op); + } // namespace ynn #endif // XNNPACK_YNNPACK_KERNELS_TERNARY_H_ diff --git a/ynnpack/kernels/ternary/test.cc b/ynnpack/kernels/ternary/test.cc index adb224fad7a..c427db9be6b 100644 --- a/ynnpack/kernels/ternary/test.cc +++ b/ynnpack/kernels/ternary/test.cc @@ -99,7 +99,7 @@ void TestImpl(const KernelInfo& kernel_info, const OpInfo& op_info, size_t m, kernel(m, n, a.stride(0) * sizeof(A), a.stride(1) * sizeof(A), a.base(), b.stride(0) * sizeof(B), b.stride(1) * sizeof(B), b.base(), c.stride(0) * sizeof(C), c.stride(1) * sizeof(C), c.base(), - x.stride(0) * sizeof(X), x.base()); + x.stride(0) * sizeof(X), x.base(), nullptr); check_results(op_info, a, b, c, x, a_quantization, b_quantization, c_quantization, x_quantization); diff --git a/ynnpack/kernels/unary/BUILD b/ynnpack/kernels/unary/BUILD index 944049a754e..069fad27072 100644 --- a/ynnpack/kernels/unary/BUILD +++ b/ynnpack/kernels/unary/BUILD @@ -77,6 +77,7 @@ cc_library( hdrs = ["reference.h"], visibility = ["//ynnpack:__subpackages__"], deps = [ + ":unary", "//ynnpack:ynnpack_h", "//ynnpack/base", "//ynnpack/base/test:tensor", diff --git a/ynnpack/kernels/unary/bench.cc b/ynnpack/kernels/unary/bench.cc index 65e977f642d..2f158747740 100644 --- a/ynnpack/kernels/unary/bench.cc +++ b/ynnpack/kernels/unary/bench.cc @@ -19,7 +19,7 @@ namespace ynn { template void bench(benchmark::State& state, uint64_t arch_flags, unary_kernel_fn kernel, - TA, TX) { + const unary_params& params, TA, TX) { if (!is_arch_supported(arch_flags)) { state.SkipWithMessage("Unsupported hardware"); return; @@ -37,7 +37,7 @@ void bench(benchmark::State& state, uint64_t arch_flags, unary_kernel_fn kernel, for (auto _ : state) { kernel(m, n, a.stride(0) * sizeof(TA), a.base(), x.stride(0) * sizeof(TX), - x.base()); + x.base(), ¶ms); } const size_t ops = m * n; @@ -49,14 +49,15 @@ void bench(benchmark::State& state, uint64_t arch_flags, unary_kernel_fn kernel, benchmark::Counter::kIsRate); } -void bench_reference(benchmark::State& state, unary_kernel_fn kernel) { - return bench(state, arch_flag::none, kernel, float{}, float{}); +void bench_reference(benchmark::State& state, unary_kernel_fn kernel, + const unary_params& params) { + return bench(state, arch_flag::none, kernel, params, float{}, float{}); } template void bench_reference_convert(benchmark::State& state, unary_kernel_fn kernel, - TA, TX) { - return bench(state, arch_flag::none, kernel, TA{}, TX{}); + const unary_params& params, TA, TX) { + return bench(state, arch_flag::none, kernel, params, TA{}, TX{}); } template @@ -70,7 +71,8 @@ void Params(benchmark::Benchmark* b) { #define BENCHMARK_REFERENCE(op, type) \ BENCHMARK_CAPTURE( \ bench_reference, op##_##type, \ - get_unary_reference_kernel(ynn_unary_##op, type_of())) \ + get_unary_reference_kernel(ynn_unary_##op, type_of()), \ + get_unary_params(ynn_unary_##op)) \ ->Apply(Params) \ ->UseRealTime(); @@ -102,7 +104,7 @@ BENCHMARK_REFERENCE(sign, int32_t); BENCHMARK_CAPTURE( \ bench_reference_convert, type_a##_##type_x, \ get_convert_reference_kernel(type_of(), type_of()), \ - type_a(), type_x()) \ + get_unary_params(ynn_unary_convert), type_a(), type_x()) \ ->Apply(Params) \ ->UseRealTime(); @@ -121,9 +123,10 @@ BENCHMARK_REFERENCE_CONVERT_FROM(int8_t); BENCHMARK_REFERENCE_CONVERT_FROM(uint8_t); BENCHMARK_REFERENCE_CONVERT_FROM(int32_t); -#define YNN_ELEMENTWISE_KERNEL(arch_flags, kernel, op, type_a, type_x) \ - BENCHMARK_CAPTURE(bench, kernel, arch_flags, kernel, type_a(), type_x()) \ - ->Apply(Params) \ +#define YNN_ELEMENTWISE_KERNEL(arch_flags, kernel, op, type_a, type_x) \ + BENCHMARK_CAPTURE(bench, kernel, arch_flags, kernel, \ + get_unary_params(ynn_unary_##op), type_a(), type_x()) \ + ->Apply(Params) \ ->UseRealTime(); #include "ynnpack/kernels/unary/kernels.inc" #undef YNN_ELEMENTWISE_KERNEL diff --git a/ynnpack/kernels/unary/exp.py b/ynnpack/kernels/unary/exp.py index c28e1040199..e6f78051a22 100644 --- a/ynnpack/kernels/unary/exp.py +++ b/ynnpack/kernels/unary/exp.py @@ -1,8 +1,7 @@ """Definition of exp kernel.""" -import math - # pylint: disable=undefined-variable +# pylint: disable=missing-function-docstring from ynnpack.kernels.elementwise.compiler import * # pylint: disable=wildcard-import @@ -30,8 +29,11 @@ def qd_round_f32(a): @const_buffer("a", Float(32)) @buffer("x", Float(32)) +@params( + Scalar("input_multiplier", Float(32)), +) @operator_name("exp") -def exp_fp32(a, x): +def exp_fp32(a, x, input_multiplier): # The monomial coefficients of the numerator polynomial (`valpha_0` = 1.0). valpha_1 = 4.1594290733e-01 valpha_2 = 7.2068706155e-02 @@ -41,9 +43,9 @@ def exp_fp32(a, x): vbeta_1 = -2.7720427513e-01 vbeta_2 = 2.3986088112e-02 - va = load(a) + va = load(a) * input_multiplier # Clamp `vz_prime = x * log2(e)` to the maximum exponents [-127, 128]. - vz_prime = min(max(va * f32(math.log2(math.e)), -127.0), 128.0) + vz_prime = min(max(va, -127.0), 128.0) # Decompose x * log2e into `z` (integer part) and `r` (remainder). vz = qd_round_f32(vz_prime) diff --git a/ynnpack/kernels/unary/reference.cc b/ynnpack/kernels/unary/reference.cc index f73567a8a02..5940ffee2de 100644 --- a/ynnpack/kernels/unary/reference.cc +++ b/ynnpack/kernels/unary/reference.cc @@ -9,72 +9,51 @@ namespace ynn { -const unary_op_info* get_unary_op_info(ynn_unary_operator op) { - static abs abs; - static convert convert; - static exp exp; - static expm1 expm1; - static erf erf; - static log log; - static log1p log1p; - static negate negate; - static reciprocal_square_root reciprocal_square_root; - static sigmoid sigmoid; - static square square; - static square_root square_root; - static tanh tanh; - static round round; - static ceil ceil; - static floor floor; - static cube_root cube_root; - static sign sign; - static sine sine; - static cosine cosine; - static hardswish hardswish; - +std::unique_ptr get_unary_op_info(ynn_unary_operator op, + const unary_params& params) { switch (op) { case ynn_unary_abs: - return &abs; + return std::make_unique(params); case ynn_unary_round: - return &round; + return std::make_unique(params); case ynn_unary_ceil: - return &ceil; + return std::make_unique(params); case ynn_unary_convert: - return &convert; + return std::make_unique(params); case ynn_unary_exp: - return &exp; + return std::make_unique(params); case ynn_unary_expm1: - return &expm1; + return std::make_unique(params); case ynn_unary_erf: - return &erf; + return std::make_unique(params); case ynn_unary_floor: - return &floor; + return std::make_unique(params); case ynn_unary_log: - return &log; + return std::make_unique(params); case ynn_unary_log1p: - return &log1p; + return std::make_unique(params); case ynn_unary_negate: - return &negate; + return std::make_unique(params); case ynn_unary_reciprocal_square_root: - return &reciprocal_square_root; + return std::make_unique(params); case ynn_unary_square: - return □ + return std::make_unique(params); case ynn_unary_square_root: - return &square_root; + return std::make_unique(params); case ynn_unary_tanh: - return &tanh; + return std::make_unique(params); case ynn_unary_cube_root: - return &cube_root; + return std::make_unique(params); case ynn_unary_sign: - return &sign; + return std::make_unique(params); case ynn_unary_sine: - return &sine; + return std::make_unique(params); case ynn_unary_cosine: - return &cosine; + return std::make_unique(params); case ynn_unary_sigmoid: - return &sigmoid; + return std::make_unique(params); case ynn_unary_hardswish: - return &hardswish; + return std::make_unique(params); case ynn_unary_invalid: return nullptr; } diff --git a/ynnpack/kernels/unary/reference.h b/ynnpack/kernels/unary/reference.h index 963b9a1bd80..218669d588e 100644 --- a/ynnpack/kernels/unary/reference.h +++ b/ynnpack/kernels/unary/reference.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -20,6 +21,7 @@ #include "ynnpack/base/test/tensor.h" #include "ynnpack/base/type.h" #include "ynnpack/include/ynnpack.h" +#include "ynnpack/kernels/unary/unary.h" namespace ynn { @@ -118,6 +120,7 @@ struct unary_op_info { }; struct convert : public unary_op_info { + explicit convert(const unary_params& = {}) {} float operator()(float x) const override { return x; } int32_t operator()(int32_t x) const override { return x; } @@ -133,28 +136,34 @@ struct convert : public unary_op_info { }; struct abs : public unary_op_info { + explicit abs(const unary_params& = {}) {} float operator()(float x) const override { return std::abs(x); } int32_t operator()(int32_t x) const override { return std::abs(x); } }; struct negate : public unary_op_info { + explicit negate(const unary_params& = {}) {} float operator()(float x) const override { return -x; } int32_t operator()(int32_t x) const override { return -x; } }; struct round : public unary_op_info { + explicit round(const unary_params& = {}) {} float operator()(float x) const override { return std::nearbyint(x); } }; struct ceil : public unary_op_info { + explicit ceil(const unary_params& = {}) {} float operator()(float x) const override { return std::ceil(x); } }; struct floor : public unary_op_info { + explicit floor(const unary_params& = {}) {} float operator()(float x) const override { return std::floor(x); } }; struct sigmoid : public unary_op_info { + explicit sigmoid(const unary_params& = {}) {} float operator()(float x) const override { return 1.0 / (1.0 + std::exp(static_cast(-x))); } @@ -174,6 +183,7 @@ struct sigmoid : public unary_op_info { }; struct square : public unary_op_info { + explicit square(const unary_params& = {}) {} float operator()(float x) const override { return x * x; } int32_t operator()(int32_t x) const override { return static_cast(static_cast(x) * @@ -199,6 +209,7 @@ struct square : public unary_op_info { }; struct square_root : public unary_op_info { + explicit square_root(const unary_params& = {}) {} float operator()(float x) const override { return std::sqrt(x); } float tolerance(float y_ref, ynn_type type) const override { @@ -240,6 +251,7 @@ struct square_root : public unary_op_info { }; struct tanh : public unary_op_info { + explicit tanh(const unary_params& = {}) {} float operator()(float x) const override { return std::tanh(x); } float tolerance(float y_ref, ynn_type type) const override { @@ -265,6 +277,7 @@ struct tanh : public unary_op_info { }; struct reciprocal_square_root : public unary_op_info { + explicit reciprocal_square_root(const unary_params& = {}) {} float operator()(float x) const override { return 1.0 / std::sqrt(x); } float tolerance(float y_ref, ynn_type type) const override { @@ -306,6 +319,7 @@ struct reciprocal_square_root : public unary_op_info { }; struct log : public unary_op_info { + explicit log(const unary_params& = {}) {} float operator()(float x) const override { return std::log(x); } float tolerance(float y_ref, ynn_type type) const override { @@ -318,7 +332,12 @@ struct log : public unary_op_info { }; struct exp : public unary_op_info { - float operator()(float x) const override { return std::exp(x); } + unary_params params; + + explicit exp(const unary_params& params) : params(params) {} + float operator()(float x) const override { + return std::exp2(params.exp.input_multiplier * x); + } float tolerance(float y_ref, ynn_type type) const override { return tol_mixed(y_ref, 2 * epsilon(type), 6 * epsilon(type)); @@ -328,6 +347,7 @@ struct exp : public unary_op_info { }; struct log1p : public unary_op_info { + explicit log1p(const unary_params& = {}) {} float operator()(float x) const override { return std::log1p(x); } float tolerance(float y_ref, ynn_type type) const override { @@ -344,6 +364,7 @@ struct log1p : public unary_op_info { }; struct expm1 : public unary_op_info { + explicit expm1(const unary_params& = {}) {} float operator()(float x) const override { return std::expm1(x); } float tolerance(float y_ref, ynn_type type) const override { @@ -357,6 +378,7 @@ struct expm1 : public unary_op_info { }; struct erf : public unary_op_info { + explicit erf(const unary_params& = {}) {} float operator()(float x) const override { return std::erf(x); } float tolerance(float y_ref, ynn_type type) const override { @@ -369,6 +391,7 @@ struct erf : public unary_op_info { }; struct cube_root : public unary_op_info { + explicit cube_root(const unary_params& = {}) {} float operator()(float x) const override { return std::cbrt(x); } float tolerance(float y_ref, ynn_type type) const override { @@ -377,6 +400,7 @@ struct cube_root : public unary_op_info { }; struct sign : public unary_op_info { + explicit sign(const unary_params& = {}) {} float operator()(float x) const override { return x < 0 ? -1.0f : (x > 0 ? 1.0f : 0.0f); } @@ -386,6 +410,7 @@ struct sign : public unary_op_info { }; struct trig : public unary_op_info { + explicit trig(const unary_params& = {}) {} float tolerance(float y_ref, ynn_type type) const override { switch (type) { case ynn_type_fp32: @@ -404,14 +429,17 @@ struct trig : public unary_op_info { }; struct sine : public trig { + explicit sine(const unary_params& = {}) {} float operator()(float x) const override { return std::sin(x); } }; struct cosine : public trig { + explicit cosine(const unary_params& = {}) {} float operator()(float x) const override { return std::cos(x); } }; struct hardswish : public unary_op_info { + explicit hardswish(const unary_params& = {}) {} float operator()(float x) const override { return (x / 6.0) * std::max(std::min(x + 3.0, 6.0), 0.0); } @@ -435,7 +463,8 @@ struct hardswish : public unary_op_info { interval domain(ynn_type) const override { return {-4.0f, 4.0f}; } }; -const unary_op_info* get_unary_op_info(ynn_unary_operator op); +std::unique_ptr get_unary_op_info( + ynn_unary_operator op, const unary_params& params = {}); // Check that op(a) == x, within tolerances described by `op`. template diff --git a/ynnpack/kernels/unary/test.cc b/ynnpack/kernels/unary/test.cc index b018a820739..b835a90c8e9 100644 --- a/ynnpack/kernels/unary/test.cc +++ b/ynnpack/kernels/unary/test.cc @@ -65,22 +65,25 @@ std::string to_string(const Shape& shape) { struct KernelInfo { uint64_t arch_flags = 0; unary_kernel_fn kernel; + unary_params params; // Constructor for a reference kernel. - KernelInfo(ynn_unary_operator op, ynn_type type) { + KernelInfo(ynn_unary_operator op, ynn_type type, unary_params p) { kernel = get_unary_reference_kernel(op, type); + params = p; assert(kernel); } // Constructor for a reference convert op. - KernelInfo(ynn_type a_type, ynn_type x_type) { + KernelInfo(ynn_type a_type, ynn_type x_type, unary_params p) { kernel = get_convert_reference_kernel(a_type, x_type); + params = p; assert(kernel); } // Constructor for a kernel function. - KernelInfo(uint64_t arch_flags, unary_kernel_fn kernel) - : arch_flags(arch_flags), kernel(kernel) { + KernelInfo(uint64_t arch_flags, unary_kernel_fn kernel, unary_params p) + : arch_flags(arch_flags), kernel(kernel), params(p) { assert(kernel); } }; @@ -105,7 +108,7 @@ void TestImpl(A, X, const KernelInfo& kernel_info, const OpInfo& op_info, x = x.crop_padding({0, 0}, {0, shape.padding_x}); kernel(shape.m, shape.n, a.stride(0) * sizeof(A), a.base(), - x.stride(0) * sizeof(X), x.base()); + x.stride(0) * sizeof(X), x.base(), &kernel_info.params); check_results(op_info, a, x); } @@ -137,10 +140,10 @@ TEST_P(Reference, op) { ynn_type type = std::get<0>(GetParam()); ynn_unary_operator op = std::get<1>(GetParam()); const Shape& shape = std::get<2>(GetParam()); - const unary_op_info& op_info = *get_unary_op_info(op); - KernelInfo kernel_info(op, type); + auto op_info = get_unary_op_info(op, get_unary_params(op)); + KernelInfo kernel_info(op, type, get_unary_params(op)); SwitchType(type, [&](auto type) { - TestImpl(type, type, kernel_info, op_info, shape); + TestImpl(type, type, kernel_info, *op_info, shape); }); } @@ -150,7 +153,7 @@ class ReferenceConvert TEST_P(ReferenceConvert, op) { ynn_type a = std::get<0>(GetParam()); ynn_type x = std::get<1>(GetParam()); - KernelInfo kernel_info(a, x); + KernelInfo kernel_info(a, x, get_unary_params(ynn_unary_convert)); const Shape& shape = std::get<2>(GetParam()); SwitchType(x, [&](auto x) { SwitchType(a, @@ -209,6 +212,23 @@ const Shape reference_shapes[] = { {256, 4, 0, padding}, }; +TEST(Exp, CustomParams) { + unary_params params; + params.exp.input_multiplier = 0.7f; + const Shape shape = {1, 32, 0, 0}; + const ynn_type type = ynn_type_fp32; + const ynn_unary_operator op = ynn_unary_exp; + exp op_info(params); + + unary_kernel_fn kernel = get_unary_kernel(op, type, type); + if (kernel == nullptr) { + GTEST_SKIP() << "No exp kernel found"; + } + + KernelInfo kernel_info(get_supported_arch_flags(), kernel, params); + TestImpl(float(), float(), kernel_info, op_info, shape); +} + INSTANTIATE_TEST_SUITE_P(RealOps, Reference, Combine(Values(ynn_type_fp32), ValuesIn(all_real_ops), ValuesIn(reference_shapes)), @@ -242,14 +262,17 @@ const std::vector all_shapes = []() { return shapes; }(); -#define YNN_ELEMENTWISE_KERNEL(arch_flags, kernel, op, type_a, type_x) \ - class kernel##_test : public testing::TestWithParam {}; \ - TEST_P(kernel##_test, no_broadcast) { \ - KernelInfo kernel_info(arch_flags, kernel); \ - TestImpl(type_a{}, type_x{}, kernel_info, op{}, GetParam()); \ - } \ - INSTANTIATE_TEST_SUITE_P(test, kernel##_test, ValuesIn(all_shapes), \ - [](const auto& i) { return to_string(i.param); }); +#define YNN_ELEMENTWISE_KERNEL(arch_flags, kernel, op, type_a, type_x) \ + class kernel##_test : public testing::TestWithParam {}; \ + TEST_P(kernel##_test, no_broadcast) { \ + ynn::KernelInfo kernel_info(arch_flags, kernel, \ + ynn::get_unary_params(ynn_unary_##op)); \ + ynn::TestImpl(type_a{}, type_x{}, kernel_info, \ + ynn::op(kernel_info.params), GetParam()); \ + } \ + INSTANTIATE_TEST_SUITE_P( \ + test, kernel##_test, ValuesIn(ynn::all_shapes), \ + [](const auto& i) { return ynn::to_string(i.param); }); #include "ynnpack/kernels/unary/kernels.inc" #undef YNN_ELEMENTWISE_KERNEL diff --git a/ynnpack/kernels/unary/unary.cc b/ynnpack/kernels/unary/unary.cc index e479cff161d..393e210b9b0 100644 --- a/ynnpack/kernels/unary/unary.cc +++ b/ynnpack/kernels/unary/unary.cc @@ -29,11 +29,11 @@ namespace { // intend to give the compiler a reasonable chance at optimizing them. template void unary_impl(size_t m, size_t n, size_t stride_x, const void* vx, - size_t stride_y, void* vy) { + size_t stride_y, void* vy, const unary_params* params) { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); - Operator op; + Operator op(*params); for (size_t i = 0; i < m; ++i) { for (size_t j = 0; j < n; ++j) { y[j] = static_cast(op(x[j])); @@ -45,6 +45,7 @@ void unary_impl(size_t m, size_t n, size_t stride_x, const void* vx, template struct convert_op { + explicit convert_op(const unary_params& = {}) {} TOut operator()(TIn x) const { if constexpr (std::is_integral::value) { if constexpr (std::is_integral::value) { @@ -61,6 +62,7 @@ struct convert_op { #if XNN_HAVE_FLOAT16 template <> struct convert_op { + explicit convert_op(const unary_params& = {}) {} _Float16 operator()(bfloat16 x) const { return static_cast<_Float16>(static_cast(x)); } @@ -93,28 +95,34 @@ unary_kernel_fn get_convert_kernel(ynn_type output) { } struct abs_op { + explicit abs_op(const unary_params& = {}) {} float operator()(float x) const { return std::abs(x); } int32_t operator()(int32_t x) const { return std::abs(x); } }; struct negate_op { + explicit negate_op(const unary_params& = {}) {} float operator()(float x) const { return -x; } int32_t operator()(int32_t x) const { return -x; } }; struct round_op { + explicit round_op(const unary_params& = {}) {} float operator()(float x) const { return std::nearbyint(x); } }; struct ceil_op { + explicit ceil_op(const unary_params& = {}) {} float operator()(float x) const { return std::ceil(x); } }; struct floor_op { + explicit floor_op(const unary_params& = {}) {} float operator()(float x) const { return std::floor(x); } }; struct square_op { + explicit square_op(const unary_params& = {}) {} float operator()(float x) const { return x * x; } int32_t operator()(int32_t x) const { return static_cast(static_cast(x) * @@ -123,59 +131,76 @@ struct square_op { }; struct square_root_op { + explicit square_root_op(const unary_params& = {}) {} float operator()(float x) const { return std::sqrt(x); } }; struct cube_root_op { + explicit cube_root_op(const unary_params& = {}) {} float operator()(float x) const { return std::cbrt(x); } }; struct tanh_op { + explicit tanh_op(const unary_params& = {}) {} float operator()(float x) const { return std::tanh(x); } }; struct reciprocal_square_root_op { + explicit reciprocal_square_root_op(const unary_params& = {}) {} float operator()(float x) const { return 1 / std::sqrt(x); } }; struct log_op { + explicit log_op(const unary_params& = {}) {} float operator()(float x) const { return std::log(x); } }; struct log1p_op { + explicit log1p_op(const unary_params& = {}) {} float operator()(float x) const { return std::log1p(x); } }; struct exp_op { - float operator()(float x) const { return std::exp(x); } + explicit exp_op(const unary_params& params) : params(params) {} + float operator()(float x) const { + return std::exp2(params.exp.input_multiplier * x); + } + unary_params params; }; struct expm1_op { + explicit expm1_op(const unary_params& = {}) {} float operator()(float x) const { return std::expm1(x); } }; struct erf_op { + explicit erf_op(const unary_params& = {}) {} float operator()(float x) const { return std::erf(x); } }; struct sign_op { + explicit sign_op(const unary_params& = {}) {} float operator()(float x) const { return x < 0 ? -1 : x > 0 ? 1 : 0; } int32_t operator()(int32_t x) const { return x < 0 ? -1 : x > 0 ? 1 : 0; } }; struct sine_op { + explicit sine_op(const unary_params& = {}) {} float operator()(float x) const { return std::sin(x); } }; struct cosine_op { + explicit cosine_op(const unary_params& = {}) {} float operator()(float x) const { return std::cos(x); } }; struct sigmoid_op { + explicit sigmoid_op(const unary_params& = {}) {} float operator()(float x) const { return 1.0f / (1.0f + std::exp(-x)); } }; struct hardswish_op { + explicit hardswish_op(const unary_params& = {}) {} float operator()(float x) const { return (x * (1.0f / 6.0f)) * std::max(std::min(x + 3.0f, 6.0f), 0.0f); } @@ -293,4 +318,16 @@ unary_kernel_fn get_unary_kernel(ynn_unary_operator op, ynn_type a_type, } } +unary_params get_unary_params(ynn_unary_operator op) { + switch (op) { + case ynn_unary_exp: + return unary_params{ + .exp = exp_params{static_cast(std::log2(std::exp(1.0)))}}; + default: + return unary_params{}; + } + + return unary_params{}; +} + } // namespace ynn diff --git a/ynnpack/kernels/unary/unary.h b/ynnpack/kernels/unary/unary.h index d33011fceed..a5ca3556a14 100644 --- a/ynnpack/kernels/unary/unary.h +++ b/ynnpack/kernels/unary/unary.h @@ -15,12 +15,21 @@ namespace ynn { +struct exp_params { + float input_multiplier; +}; + +union unary_params { + exp_params exp; +}; + typedef void (*unary_kernel_fn)(size_t width, size_t height, size_t stride_a, - const void* a, size_t stride_x, void* x); + const void* a, size_t stride_x, void* x, + const unary_params* params); #define YNN_ELEMENTWISE_KERNEL(arch, name, op, type_a, type_c) \ void name(size_t m, size_t n, size_t stride_a_m, const void* a, \ - size_t stride_x_m, void* x); + size_t stride_x_m, void* x, const unary_params* params); #include "ynnpack/kernels/unary/kernels.inc" #undef YNN_ELEMENTWISE_KERNEL @@ -36,6 +45,8 @@ unary_kernel_fn get_unary_kernel( ynn_unary_operator op, ynn_type a_type, ynn_type x_type, uint64_t supported_arch_flags = get_supported_arch_flags()); +unary_params get_unary_params(ynn_unary_operator op); + } // namespace ynn #endif // XNNPACK_YNNPACK_KERNELS_UNARY_H_ diff --git a/ynnpack/subgraph/elementwise.cc b/ynnpack/subgraph/elementwise.cc index cbb4678f3a2..cd5acf68ec3 100644 --- a/ynnpack/subgraph/elementwise.cc +++ b/ynnpack/subgraph/elementwise.cc @@ -35,29 +35,29 @@ namespace ynn { namespace { // Call a unary kernel. -auto make_unary_elementwise_impl(unary_kernel_fn kernel) { - return - [kernel](slinky::raw_buffer a, slinky::raw_buffer x) -> slinky::index_t { - slinky::dim a_dims[2], x_dims[2]; +auto make_unary_elementwise_impl(unary_kernel_fn kernel, unary_params params) { + return [kernel, params](slinky::raw_buffer a, + slinky::raw_buffer x) -> slinky::index_t { + slinky::dim a_dims[2], x_dims[2]; - fuse_and_slice_leading_dims<2>(&x_dims[0], x, &a_dims[0], a); + fuse_and_slice_leading_dims<2>(&x_dims[0], x, &a_dims[0], a); - // We don't support broadcasting of `a` here in the innermost - // dimension (and it would waste computation). - assert(is_continguous(a_dims[0], a.elem_size)); + // We don't support broadcasting of `a` here in the innermost + // dimension (and it would waste computation). + assert(is_continguous(a_dims[0], a.elem_size)); - const slinky::dim& x_n = x_dims[0]; - const slinky::dim& a_m = a_dims[1]; - const slinky::dim& x_m = x_dims[1]; + const slinky::dim& x_n = x_dims[0]; + const slinky::dim& a_m = a_dims[1]; + const slinky::dim& x_m = x_dims[1]; - slinky::for_each_element( - [&](void* x, const void* a) { - kernel(x_m.extent(), x_n.extent(), a_m.stride(), a, x_m.stride(), - x); - }, - x, a); - return 0; - }; + slinky::for_each_element( + [&](void* x, const void* a) { + kernel(x_m.extent(), x_n.extent(), a_m.stride(), a, x_m.stride(), x, + ¶ms); + }, + x, a); + return 0; + }; } // Call a lut kernel. @@ -100,7 +100,7 @@ auto make_binary_elementwise_impl(binary_kernel_fn kernel) { slinky::for_each_element( [&](void* x, const void* a, const void* b) { kernel(x_m.extent(), x_n.extent(), a_m.stride(), a_n.stride(), a, - b_m.stride(), b_n.stride(), b, x_m.stride(), x); + b_m.stride(), b_n.stride(), b, x_m.stride(), x, nullptr); }, x, a, b); return 0; @@ -140,7 +140,7 @@ auto make_ternary_elementwise_impl(ternary_kernel_fn kernel) { [&](void* x, const void* a, const void* b, const void* c) { kernel(x_m.extent(), x_n.extent(), a_m.stride(), a_n.stride(), a, b_m.stride(), b_n.stride(), b, c_m.stride(), c_n.stride(), - c, x_m.stride(), x); + c, x_m.stride(), x, nullptr); }, x, a, b, c); return 0; @@ -162,7 +162,7 @@ std::pair GetScalarQuantization( } ynn_status create_unary(const ynn_node& node, ynn_runtime& runtime, - unary_kernel_fn kernel) { + unary_kernel_fn kernel, unary_params params) { assert(node.inputs.size() == 1); assert(node.outputs.size() == 1); @@ -193,9 +193,9 @@ ynn_status create_unary(const ynn_node& node, ynn_runtime& runtime, attrs.name = to_string(std::get(node.op).op); attrs.allow_in_place = compute_allow_in_place(node, *runtime.subgraph); - slinky::func func = slinky::func::make(make_unary_elementwise_impl(kernel), - {{a.buffer, std::move(bounds)}}, - {{x.buffer, dims}}, std::move(attrs)); + slinky::func func = slinky::func::make( + make_unary_elementwise_impl(kernel, params), + {{a.buffer, std::move(bounds)}}, {{x.buffer, dims}}, std::move(attrs)); auto sched = runtime.make_schedule(dims, x.buffer, node.outputs[0]); func.user_data() = sched.get(); @@ -332,14 +332,14 @@ void infer_shape(ynn_node& node, ynn_subgraph& subgraph) { void define_unary(ynn_subgraph& subgraph, ynn_node& node, uint32_t input_a_id, uint32_t output_id, ynn_unary_operator op, - unary_kernel_fn kernel) { + unary_kernel_fn kernel, unary_params params) { // Make the node. node.inputs = {input_a_id}; node.outputs = {output_id}; node.op = ynn_node::unary_elementwise{op}; infer_shape(node, subgraph); - node.create = [kernel](const ynn_node& node, ynn_runtime& runtime) { - return create_unary(node, runtime, kernel); + node.create = [kernel, params](const ynn_node& node, ynn_runtime& runtime) { + return create_unary(node, runtime, kernel, params); }; } @@ -457,9 +457,12 @@ ynn_status ynn_define_unary(ynn_subgraph_t subgraph, ynn_unary_operator op, return ynn_status_unsupported_parameter; } + unary_params params = get_unary_params(op); + // Make the node. ynn_node node; - ynn::define_unary(*subgraph, node, input_a_id, *output_id, op, kernel); + ynn::define_unary(*subgraph, node, input_a_id, *output_id, op, kernel, + params); subgraph->add_node(std::move(node)); return ynn_status_success; } @@ -650,7 +653,7 @@ ynn_status ynn_define_convert(ynn_subgraph_t subgraph, uint32_t input_id, node.op = ynn_node::unary_elementwise{ynn_unary_convert}; infer_shape(node, *subgraph); node.create = [kernel](const ynn_node& node, ynn_runtime& runtime) { - return create_unary(node, runtime, kernel); + return create_unary(node, runtime, kernel, unary_params{}); }; subgraph->add_node(std::move(node)); return ynn_status_success; diff --git a/ynnpack/subgraph/test/unary.cc b/ynnpack/subgraph/test/unary.cc index 0cf21aac3c3..96ee10b8c3b 100644 --- a/ynnpack/subgraph/test/unary.cc +++ b/ynnpack/subgraph/test/unary.cc @@ -73,7 +73,6 @@ void TestOp(A, X, const unary_op_info& op_info, ynn_unary_operator op) { ASSERT_EQ(runtime.GetExternalTensorShape(1), shape); runtime.SetupExternalTensor(output.data(), 1).InvokeRuntime(); - check_results(op_info, a, output, a_quantization, output_quantization); } } @@ -88,8 +87,8 @@ class RealOps template void TestOp(T type, ynn_unary_operator op) { - const unary_op_info& op_info = *get_unary_op_info(op); - TestOp(type, type, op_info, op); + auto op_info = get_unary_op_info(op, get_unary_params(op)); + TestOp(type, type, *op_info, op); } TEST_P(IntegerOps, op) {