From 0f0db5f040116f0c89ba5385586433c225454514 Mon Sep 17 00:00:00 2001 From: Adam Hillier <7688302+AdamHillier@users.noreply.github.com> Date: Mon, 21 Jun 2021 14:24:32 +0100 Subject: [PATCH] Extend quantiser support so as to accelerate more binary models. Add the ability to convert `tf.where`-style binary quantisers, and add support for boolean input to `LceQuantize` and `LceDequantize`. --- .../core/bitpacking/bitpack_aarch64.h | 9 +++ larq_compute_engine/mlir/ir/lce_ops.td | 8 +-- larq_compute_engine/mlir/tests/optimize.mlir | 44 ++++++++++++ .../mlir/tests/prepare-tf.mlir | 67 ++++++++++++++++++- .../transforms/optimize_patterns_common.td | 24 +++++++ .../transforms/prepare_patterns_common.td | 51 +++++++++++++- .../mlir/transforms/prepare_tf.cc | 8 +++ larq_compute_engine/tests/end2end_test.py | 36 +++++++--- larq_compute_engine/tflite/kernels/BUILD | 1 + .../tflite/kernels/quantization.cc | 37 ++++++++-- .../tflite/tests/quantization_test.cc | 2 + 11 files changed, 264 insertions(+), 23 deletions(-) diff --git a/larq_compute_engine/core/bitpacking/bitpack_aarch64.h b/larq_compute_engine/core/bitpacking/bitpack_aarch64.h index b716a776c..f76dd8559 100644 --- a/larq_compute_engine/core/bitpacking/bitpack_aarch64.h +++ b/larq_compute_engine/core/bitpacking/bitpack_aarch64.h @@ -9,12 +9,20 @@ #include "larq_compute_engine/core/types.h" #include "ruy/profiler/instrumentation.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace compute_engine { namespace core { namespace bitpacking { +template <typename T> +inline void bitpack_aarch64_4x32(const T* input, std::size_t num_blocks, + TBitpacked* output, const T zero_point) { + TFLITE_ASSERT_FALSE; +} + // Bitpack an array of `4 * 32 * num_blocks` floats. 
+template <> inline void bitpack_aarch64_4x32(const float* input, std::size_t num_blocks, TBitpacked* output, const float zero_point /*ignored*/) { @@ -227,6 +235,7 @@ inline void bitpack_aarch64_4x32(const float* input, std::size_t num_blocks, } // Bitpack an array of `4 * 32 * num_blocks` int8 bytes. +template <> inline void bitpack_aarch64_4x32(const std::int8_t* input, std::size_t num_blocks, TBitpacked* output, const std::int8_t zero_byte) { diff --git a/larq_compute_engine/mlir/ir/lce_ops.td b/larq_compute_engine/mlir/ir/lce_ops.td index 9c790333d..c763ae065 100644 --- a/larq_compute_engine/mlir/ir/lce_ops.td +++ b/larq_compute_engine/mlir/ir/lce_ops.td @@ -70,11 +70,11 @@ def LQ_QuantizeOp : LQ_Op<"Quantize", [NoSideEffect]> { let summary = "Binary quantize operator"; let description = [{ -Converts floating point or integer tensors to binarized bitpacked tensors. +Converts floating point, integer, or boolean tensors to binarized bitpacked tensors. }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$x + TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16, I1]>:$x ); let results = (outs @@ -90,7 +90,7 @@ def LQ_DequantizeOp : LQ_Op<"Dequantize", [NoSideEffect]> { let summary = "Binary dequantize operator"; let description = [{ -Converts binarized bitpacked tensors to floating point or integer tensors. +Converts binarized bitpacked tensors to floating point, integer, or boolean tensors. }]; let arguments = (ins @@ -98,7 +98,7 @@ Converts binarized bitpacked tensors to floating point or integer tensors. 
); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16]>:$y + TensorOf<[BF16, F16, F32, F64, I32, I64, QI8, QI16, I1]>:$y ); let hasFolder = 1; diff --git a/larq_compute_engine/mlir/tests/optimize.mlir b/larq_compute_engine/mlir/tests/optimize.mlir index eb467227e..354d6954c 100644 --- a/larq_compute_engine/mlir/tests/optimize.mlir +++ b/larq_compute_engine/mlir/tests/optimize.mlir @@ -1,6 +1,50 @@ // RUN: lce-tf-opt %s -tfl-optimize-lce=target=arm -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM // RUN: lce-tf-opt %s -tfl-optimize-lce=target=xcore -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-XCORE +// CHECK-LABEL: @optimize_quantize_greater_equal_zero +func @optimize_quantize_greater_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> { + %cst = constant dense<0.0> : tensor<f32> + %0 = "tfl.greater_equal"(%arg0, %cst) : (tensor<48x48x64xf32>, tensor<f32>) -> tensor<48x48x64xi1> + %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32> + return %1 : tensor<48x48x2xi32> + + // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32> + // CHECK-NEXT: return %0 +} + +// CHECK-LABEL: @optimize_quantize_greater_equal_non_zero +func @optimize_quantize_greater_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> { + %0 = "tfl.greater_equal"(%arg0, %arg1) : (tensor<48x48x64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1> + %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32> + return %1 : tensor<48x48x2xi32> + + // CHECK-NEXT: %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<48x48x64xf32> + // CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32> + // CHECK-NEXT: return %1 +} + +// CHECK-LABEL: @optimize_quantize_less_equal_zero +func @optimize_quantize_less_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> { + %cst = constant dense<0.0> : 
tensor<64xf32> + %0 = "tfl.less_equal"(%cst, %arg0) : (tensor<64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1> + %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32> + return %1 : tensor<48x48x2xi32> + + // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32> + // CHECK-NEXT: return %0 +} + +// CHECK-LABEL: @optimize_quantize_less_equal_non_zero +func @optimize_quantize_less_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> { + %0 = "tfl.less_equal"(%arg0, %arg1) : (tensor<48x48x64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1> + %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32> + return %1 : tensor<48x48x2xi32> + + // CHECK-NEXT: %0 = tfl.sub %arg1, %arg0 {fused_activation_function = "NONE"} : tensor<48x48x64xf32> + // CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<48x48x64xf32>) -> tensor<48x48x2xi32> + // CHECK-NEXT: return %1 +} + // CHECK-LABEL: @fuse_add_into_bconv2d func @fuse_add_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> { %cst = constant dense<1.5> : tensor<16xf32> diff --git a/larq_compute_engine/mlir/tests/prepare-tf.mlir b/larq_compute_engine/mlir/tests/prepare-tf.mlir index 68bf05cd3..aaeb4d5c4 100644 --- a/larq_compute_engine/mlir/tests/prepare-tf.mlir +++ b/larq_compute_engine/mlir/tests/prepare-tf.mlir @@ -1,8 +1,71 @@ // RUN: lce-tf-opt %s -tfl-prepare-lce=target=arm -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-ARM // RUN: lce-tf-opt %s -tfl-prepare-lce=target=xcore -verify-diagnostics | FileCheck %s --check-prefixes CHECK,CHECK-XCORE -// CHECK-LABEL: @fuse_bsign -func @fuse_bsign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { +// CHECK-LABEL: @fuse_bsign_tf_where +func @fuse_bsign_tf_where(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { + %cst_l = constant dense<1.0> : tensor<8x16xf32> + %cst_r = constant 
dense<-1.0> : tensor<8x16xf32> + %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> + + // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32> + // CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32> + // CHECK-NEXT: return %1 +} + +// CHECK-LABEL: @fuse_bsign_tf_where_inverted +func @fuse_bsign_tf_where_inverted(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { + %cst_l = constant dense<-1.0> : tensor<8x16xf32> + %cst_r = constant dense<1.0> : tensor<8x16xf32> + %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> + + // CHECK-NEXT: %0 = "tf.LogicalNot"(%arg0) : (tensor<8x16xi1>) -> tensor<8x16xi1> + // CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<8x16xi1>) -> tensor<8x1xi32> + // CHECK-NEXT: %2 = "lq.Dequantize"(%1) : (tensor<8x1xi32>) -> tensor<8x16xf32> + // CHECK-NEXT: return %2 +} + +// CHECK-LABEL: @fuse_bsign_tf_where_broadcast_cond +func @fuse_bsign_tf_where_broadcast_cond(%arg0: tensor<8x1xi1>) -> tensor<8x16xf32> { + %cst_l = constant dense<1.0> : tensor<8x16xf32> + %cst_r = constant dense<-1.0> : tensor<8x16xf32> + %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x1xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> + + // CHECK-NEXT: %cst = constant dense<[8, 16]> : tensor<2xi64> + // CHECK-NEXT: %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<8x1xi1>, tensor<2xi64>) -> tensor<8x16xi1> + // CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<8x16xi1>) -> tensor<8x1xi32> + // CHECK-NEXT: %2 = "lq.Dequantize"(%1) : (tensor<8x1xi32>) -> tensor<8x16xf32> + // CHECK-NEXT: return %2 +} + +// CHECK-LABEL: @fuse_bsign_tf_where_broadcast_lhs_rhs +func @fuse_bsign_tf_where_broadcast_lhs_rhs(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { + %cst_l = constant dense<1.0> : tensor<f32> + %cst_r = 
constant dense<-1.0> : tensor<8x1xf32> + %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<f32>, tensor<8x1xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> + + // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32> + // CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32> + // CHECK-NEXT: return %1 +} + +// CHECK-LABEL: @fuse_bsign_tf_where_select_v1_op +func @fuse_bsign_tf_where_select_v1_op(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { + %cst_l = constant dense<1.0> : tensor<8x16xf32> + %cst_r = constant dense<-1.0> : tensor<8x16xf32> + %0 = "tf.Select"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> + + // CHECK-NEXT: %0 = "lq.Quantize"(%arg0) : (tensor<8x16xi1>) -> tensor<8x1xi32> + // CHECK-NEXT: %1 = "lq.Dequantize"(%0) : (tensor<8x1xi32>) -> tensor<8x16xf32> + // CHECK-NEXT: return %1 +} + +// CHECK-LABEL: @fuse_bsign_legacy_tf_sign +func @fuse_bsign_legacy_tf_sign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Sign"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> %cst = constant dense<0.1> : tensor<f32> %2 = "tf.AddV2"(%0, %cst) : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32> diff --git a/larq_compute_engine/mlir/transforms/optimize_patterns_common.td b/larq_compute_engine/mlir/transforms/optimize_patterns_common.td index 8234f422d..e48b4a735 100644 --- a/larq_compute_engine/mlir/transforms/optimize_patterns_common.td +++ b/larq_compute_engine/mlir/transforms/optimize_patterns_common.td @@ -11,6 +11,30 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>; class ConstantValue<string val> : AttrConstraint<CPred<"IsConstantValue($_self, " # val # ")">>; +def : Pat<(LQ_QuantizeOp + (TFL_GreaterEqualOp:$ge_op + $input, + (ConstantOp ConstantValue<"0.0f">))), + (LQ_QuantizeOp $input), + [(HasOneUse $ge_op)], + (addBenefit 150)>; + +def : Pat<(LQ_QuantizeOp + (TFL_GreaterEqualOp:$ge_op + $input, + $threshold)), + (LQ_QuantizeOp + (TFL_SubOp $input, $threshold, TFL_AF_None)), + 
[(HasOneUse $ge_op)], + (addBenefit 100)>; + +def : Pat<(LQ_QuantizeOp + (TFL_LessEqualOp:$ge_op $lhs, $rhs)), + (LQ_QuantizeOp + (TFL_GreaterEqualOp $rhs, $lhs)), + [(HasOneUse $ge_op)], + (addBenefit 100)>; + // TODO: Check shapes before fusing multiclass FuseAddOrSubWithBConv2D { def : Pat<(binaryOp diff --git a/larq_compute_engine/mlir/transforms/prepare_patterns_common.td b/larq_compute_engine/mlir/transforms/prepare_patterns_common.td index bb39042d8..36b6d4249 100644 --- a/larq_compute_engine/mlir/transforms/prepare_patterns_common.td +++ b/larq_compute_engine/mlir/transforms/prepare_patterns_common.td @@ -4,9 +4,56 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "larq_compute_engine/mlir/ir/lce_ops.td" include "larq_compute_engine/mlir/transforms/op_removal_patterns.td" +class ConstantValue : AttrConstraint>; -// This relies on implementation details of larq.math.sign. We should make -// this more general in the future +def CreateTFBroadcastToOp : NativeCodeCall< + "$_builder.create(" + "$0.getLoc()," + "RankedTensorType::get(" + "$0.getType().cast().getShape()," + "getElementTypeOrSelf($1.getType()))," + "$1," + "$2)">; + +def CreateTFShapeOp : NativeCodeCall< + "$_builder.create($0.getLoc(), $1, $2)">; + +// Base quantiser patterns that match the `tf.where` implementation of `ste_sign`. 
+multiclass QuantDequantPatterns { + def : Pat<(SelectOp:$select_op + $cond, + (ConstantOp ConstantValue<"1.0f">), + (ConstantOp ConstantValue<"-1.0f">)), + (LQ_DequantizeOp + (LQ_QuantizeOp + (CreateTFBroadcastToOp + $select_op, + $cond, + (CreateTFShapeOp + $select_op, + $select_op, + /*use 32bit*/ConstBoolAttrFalse)))), + [], (addBenefit 100)>; + def : Pat<(SelectOp:$select_op + $cond, + (ConstantOp ConstantValue<"-1.0f">), + (ConstantOp ConstantValue<"1.0f">)), + (LQ_DequantizeOp + (LQ_QuantizeOp + (CreateTFBroadcastToOp + $select_op, + (TF_LogicalNotOp $cond), + (CreateTFShapeOp + $select_op, + $select_op, + /*use 32bit*/ConstBoolAttrFalse)))), + [], (addBenefit 100)>; +} +foreach SelectOp = [TF_SelectOp, TF_SelectV2Op] in + defm : QuantDequantPatterns; + +// A fallback for the old version of `ste_sign` that uses a specific `tf.sign` +// based implementation of `larq.math.sign`. def : Pat<(TF_SignOp (TF_AddV2Op (TF_SignOp $arg), $c)), (LQ_DequantizeOp (LQ_QuantizeOp $arg)), [], (addBenefit 100)>; def : Pat<(TF_SignOp (TF_AddV2Op $c, (TF_SignOp $arg))), diff --git a/larq_compute_engine/mlir/transforms/prepare_tf.cc b/larq_compute_engine/mlir/transforms/prepare_tf.cc index 83683512b..a07e6ebf1 100644 --- a/larq_compute_engine/mlir/transforms/prepare_tf.cc +++ b/larq_compute_engine/mlir/transforms/prepare_tf.cc @@ -36,6 +36,14 @@ struct PrepareLCE : public PassWrapper { clEnumValN(LCETarget::XCORE, "xcore", "XCORE target"))}; }; +bool IsConstantValue(Attribute values, float expected_value) { + if (!values.isa()) return false; + + for (auto value : values.cast().getValues()) { + if (value != expected_value) return false; + } + return true; +} DenseElementsAttr GetConstantVector(Attribute filter, float val) { auto filter_type = filter.getType().cast(); auto filter_shape = filter_type.getShape(); diff --git a/larq_compute_engine/tests/end2end_test.py b/larq_compute_engine/tests/end2end_test.py index 1c55eea01..3ad190a0d 100644 --- 
a/larq_compute_engine/tests/end2end_test.py +++ b/larq_compute_engine/tests/end2end_test.py @@ -1,3 +1,4 @@ +import functools import os import sys import tempfile @@ -23,7 +24,7 @@ def convert_keras_model_as_saved_model(model, **kwargs): return convert_saved_model(saved_model_dir, **kwargs) -def toy_model(**kwargs): +def toy_model(binary_quantizer="ste_sign", **kwargs): def block(padding, pad_values, activation): def dummy(x): shortcut = x @@ -32,8 +33,8 @@ def dummy(x): kernel_size=3, padding=padding, pad_values=pad_values, - input_quantizer="ste_sign", - kernel_quantizer="ste_sign", + input_quantizer=binary_quantizer, + kernel_quantizer=binary_quantizer, use_bias=False, activation=activation, )(x) @@ -59,7 +60,7 @@ def dummy(x): return tf.keras.Model(inputs=img_input, outputs=out) -def toy_model_sequential(**kwargs): +def toy_model_sequential(binary_quantizer="ste_sign", **kwargs): return tf.keras.models.Sequential( [ tf.keras.layers.Input((224, 224, 3)), @@ -70,8 +71,8 @@ def toy_model_sequential(**kwargs): lq.layers.QuantConv2D( 32, (3, 3), - input_quantizer="ste_sign", - kernel_quantizer="ste_sign", + input_quantizer=binary_quantizer, + kernel_quantizer=binary_quantizer, padding="same", pad_values=1.0, use_bias=False, @@ -85,8 +86,8 @@ def toy_model_sequential(**kwargs): lq.layers.QuantConv2D( 32, (3, 3), - input_quantizer="ste_sign", - kernel_quantizer="ste_sign", + input_quantizer=binary_quantizer, + kernel_quantizer=binary_quantizer, strides=(2, 2), padding="same", pad_values=1.0, @@ -104,8 +105,8 @@ def toy_model_sequential(**kwargs): lq.layers.QuantConv2D( 32, (3, 3), - input_quantizer="ste_sign", - kernel_quantizer="ste_sign", + input_quantizer=binary_quantizer, + kernel_quantizer=binary_quantizer, padding="same", pad_values=1.0, use_bias=False, @@ -165,12 +166,25 @@ def dataset(): ) +def tf_where_binary_quantizer(x): + return tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x)) + + @pytest.mark.parametrize( "conversion_function", [convert_keras_model, 
convert_keras_model_as_saved_model] ) @pytest.mark.parametrize( "model_cls", - [toy_model, toy_model_sequential, toy_model_int8, lqz.sota.QuickNetSmall], + [ + toy_model, + functools.partial(toy_model, binary_quantizer=tf_where_binary_quantizer), + toy_model_sequential, + functools.partial( + toy_model_sequential, binary_quantizer=tf_where_binary_quantizer + ), + toy_model_int8, + lqz.sota.QuickNetSmall, + ], ) def test_simple_model(dataset, conversion_function, model_cls): model = model_cls(weights="imagenet") diff --git a/larq_compute_engine/tflite/kernels/BUILD b/larq_compute_engine/tflite/kernels/BUILD index 8217c8a27..cfefeee54 100644 --- a/larq_compute_engine/tflite/kernels/BUILD +++ b/larq_compute_engine/tflite/kernels/BUILD @@ -41,6 +41,7 @@ cc_library( "//larq_compute_engine/core/indirect_bgemm:kernels", "@flatbuffers", "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:type_to_tflitetype", "@org_tensorflow//tensorflow/lite/kernels/internal:kernel_utils", "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", "@ruy//ruy/profiler:instrumentation", diff --git a/larq_compute_engine/tflite/kernels/quantization.cc b/larq_compute_engine/tflite/kernels/quantization.cc index e804b9af4..68b826d40 100644 --- a/larq_compute_engine/tflite/kernels/quantization.cc +++ b/larq_compute_engine/tflite/kernels/quantization.cc @@ -1,9 +1,12 @@ +#include + #include "larq_compute_engine/core/bitpacking/utils.h" #include "ruy/profiler/instrumentation.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" using namespace tflite; @@ -20,8 +23,9 @@ TfLiteStatus QuantizePrepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - TF_LITE_ENSURE(context, 
- input->type == kTfLiteFloat32 || input->type == kTfLiteInt8); + TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 || + input->type == kTfLiteInt8 || + input->type == kTfLiteBool); TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt32); int num_dims = NumDimensions(input); @@ -44,8 +48,9 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32); - TF_LITE_ENSURE(context, - output->type == kTfLiteFloat32 || output->type == kTfLiteInt8); + TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 || + output->type == kTfLiteInt8 || + output->type == kTfLiteBool); int num_dims = NumDimensions(input); @@ -80,6 +85,27 @@ TfLiteStatus QuantizeEval(TfLiteContext* context, TfLiteNode* node) { } else if (input->type == kTfLiteInt8) { bitpack_tensor(GetTensorShape(input), GetTensorData<std::int8_t>(input), input->params.zero_point, GetTensorData<TBitpacked>(output)); + } else if (input->type == kTfLiteBool) { + // The strategy here is to interpret the input data as an unsigned integer + // (of the same width as the bool type for the target). We then call + // bitpacking, with a 'zero point' of 1. This means that the value with all + // zero bits will be bitpacked as bit 1, and all other values will be + // bitpacked as bit 0. Assuming that `false` is represented by a value with + // all zero bits, this gives the correct result of bitpacking `false` as bit + // 1 and `true` as bit 0. 
+ static_assert(std::is_same<::tflite::TfLiteTypeToType<kTfLiteBool>::Type, + bool>::value, + ""); + using BOOL_UINT = std::conditional< + sizeof(bool) == 1, std::uint8_t, + std::conditional<sizeof(bool) == 2, std::uint16_t, + std::conditional<sizeof(bool) == 4, std::uint32_t, + std::uint64_t>::type>::type>::type; + static_assert(sizeof(bool) == sizeof(BOOL_UINT), ""); + + bitpack_tensor(GetTensorShape(input), GetTensorData<BOOL_UINT>(input), + BOOL_UINT(1), GetTensorData<TBitpacked>(output)); } else { return kTfLiteError; } @@ -110,6 +136,9 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) { unpack_matrix(GetTensorData<TBitpacked>(input), num_rows, num_cols, GetTensorData<std::int8_t>(output), zero_bit_result, one_bit_result); + } else if (output->type == kTfLiteBool) { + unpack_matrix(GetTensorData<TBitpacked>(input), num_rows, num_cols, + GetTensorData<bool>(output), true, false); } else { return kTfLiteError; } diff --git a/larq_compute_engine/tflite/tests/quantization_test.cc b/larq_compute_engine/tflite/tests/quantization_test.cc index 292994814..8f396b862 100644 --- a/larq_compute_engine/tflite/tests/quantization_test.cc +++ b/larq_compute_engine/tflite/tests/quantization_test.cc @@ -116,6 +116,8 @@ TEST_P(QuantizationOpTest, Float) { TestQuantization<float>(GetParam()); } TEST_P(QuantizationOpTest, Int8) { TestQuantization<std::int8_t>(GetParam()); } +TEST_P(QuantizationOpTest, Bool) { TestQuantization<bool>(GetParam()); } + INSTANTIATE_TEST_SUITE_P(AllCombinations, QuantizationOpTest, ::testing::Values(std::array<int, 4>{1, 1, 1, 1}, std::array<int, 4>{1, 4, 4, 1},