diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 92e82e6e7de..f70a6d7faee 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -689,11 +689,11 @@ def register_fake( ) lib.define( - "quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor" + "quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden) -> Tensor" ) lib.define( - "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( @@ -3060,11 +3060,20 @@ def quantized_w8a32_gru_meta( weights_hidden: torch.Tensor, w_h_scale: float, bias_inputs: torch.Tensor, - b_i_scale: float, + b_scale: float, bias_hidden: torch.Tensor, - b_h_scale: float, ) -> torch.Tensor: - return hidden.new_empty((2, *hidden.shape), dtype=torch.float32) + seq_len = inputs.shape[1] + assert seq_len == 1 + # inputs comes in shape [batch, seq_len, input_size] + # hidden comes in shape [batch, seq_len, hidden_size] + # weights_inputs comes in shape [3 * hidden_size, input_size] + # weights_hidden comes in shape [3 * hidden_size, hidden_size] + # output comes in empty with shape [2, batch, seq_len, hidden_size] + # The first dimension stacks the output and the new hidden state + return hidden.new_empty( + (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32 + ) @register_fake("cadence::slice_scatter_") diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 2853136081b..434d0712d49 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -438,26 +438,36 @@ def get_args_and_kwargs_mixed_w8a32_conv( torch.ops.aten.permute.default, (other_inputs[0], [0, 2, 1]), # NCL -> NLC ) - assert "val" in other_inputs[0].meta, "Missing val metadata on input node" - original_val = other_inputs[0].meta["val"] - assert original_val.fake_mode is not None, "fake_mode is None on input node" - with original_val.fake_mode: - transposed_inputs.meta["val"] = torch.ops.aten.permute.default( - original_val, [0, 2, 1] - ) + # Propagate val metadata for transposed_inputs + if "val" in other_inputs[0].meta: + original_val = other_inputs[0].meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_val = torch.ops.aten.permute.default(original_val, [0, 2, 1]) + transposed_inputs.meta["val"] = transposed_val + else: + transposed_inputs.meta["val"] = torch.ops.aten.permute.default( + original_val, [0, 2, 1] + ) copy_node_metadata(transposed_inputs, other_inputs[0]) transposed_weights = graph_module.graph.call_function( torch.ops.aten.permute.default, (weights_inputs[0], [2, 0, 1]), # NCL -> LNC ) - assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node" - original_val = weights_inputs[0].meta["val"] - assert original_val.fake_mode is not None, "fake_mode is None on weight node" - with original_val.fake_mode: - transposed_weights.meta["val"] = torch.ops.aten.permute.default( - original_val, [2, 0, 1] - ) + # Propagate val metadata for transposed_weights + if "val" in weights_inputs[0].meta: + original_val = weights_inputs[0].meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_val = torch.ops.aten.permute.default(original_val, [2, 0, 1]) + transposed_weights.meta["val"] = transposed_val + else: + transposed_weights.meta["val"] = torch.ops.aten.permute.default( + original_val, [2, 0, 1] + ) copy_node_metadata(transposed_weights, weights_inputs[0]) args = ( @@ -511,12 +521,10 @@ def get_args_and_kwargs_mixed_w8a32_gru( ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: # Stride, padding, dilation, groups not supported yet - assert len(dequants_weights) == 2 assert len(dequants_biases) == 2 w_i_scale = dequants_weights[0].args[1] w_h_scale = dequants_weights[1].args[1] - b_i_scale = dequants_biases[0].args[1] - b_h_scale = dequants_biases[1].args[1] + b_scale = dequants_biases[0].args[1] args = ( other_inputs[0], @@ -526,9 +534,8 @@ def get_args_and_kwargs_mixed_w8a32_gru( weights_inputs[1], w_h_scale, bias_inputs[0], - b_i_scale, + b_scale, bias_inputs[1], - b_h_scale, ) kwargs = {} diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 4fd672a4a6c..2ce50871fc0 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -718,7 +718,7 @@ def get_anchors( ) cnn_weights = conv_layer.args[1] - if hasattr(cnn_weights.meta, "tensor_meta"): + if "tensor_meta" in cnn_weights.meta: cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape # Bail if the channels are not multiple of 4 (SIMD) if cnn_weights_shape[0] % 4 != 0: @@ -744,6 +744,18 @@ def get_anchors( conv_layer, ) + inputs = conv_layer.args[0] + if "tensor_meta" in inputs.meta: + inputs_shape = inputs.meta["tensor_meta"].shape + # Bail if length != kernel size - Not yet supported + if inputs_shape[-1] != cnn_weights_shape[2]: + return ( + PartitionAnchors( + empty=True, + ), + conv_layer, + ) + return ( PartitionAnchors( inputs=[], @@ -777,14 +789,16 @@ def get_anchors( ) # Bail if input or states are not multiple of 4 (SIMD) - if gru_layer.args[0].meta["tensor_meta"].shape[-1] % 4 != 0: + tensor_meta_0 = gru_layer.args[0].meta.get("tensor_meta", None) + if tensor_meta_0 is None or tensor_meta_0.shape[-1] % 4 != 0: return ( PartitionAnchors( empty=True, ), gru_layer, ) - if gru_layer.args[1].meta["tensor_meta"].shape[-1] % 4 != 0: + tensor_meta_1 = gru_layer.args[1].meta.get("tensor_meta", None) + if tensor_meta_1 is None or tensor_meta_1.shape[-1] % 4 != 0: return ( PartitionAnchors( empty=True, @@ -799,13 +813,26 @@ def __init__(self, args, meta): wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta) + # Using SharedQuantizationSpec so that bias_hh has the same observer as bias_ih + # Both biases get the same quantization scale to match the cpp operator + bias_ih_node = wrapper.args[2] + bias_ih_edge = (bias_ih_node, gru_layer) + shared_bias_qspec = SharedQuantizationSpec(edge_or_node=bias_ih_edge) + return ( PartitionAnchors( inputs=[], # pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`. weights=[(wrapper, 0), (wrapper, 1)], # pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`. - biases=[(wrapper, 2), (wrapper, 3)], + biases=[ + (wrapper, 2), # bias_ih gets normal qspec + ( + wrapper, + 3, + shared_bias_qspec, + ), # bias_hh shares observer with bias_ih + ], output=[], others=[(gru_layer, 0), (gru_layer, 1)], ), diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 9399efe632a..8de5f9d7317 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -90,6 +90,15 @@ observer_or_fake_quant_ctr=MinMaxObserver, ) +wgt_qspec_sym8s_127 = QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=MinMaxObserver, +) + bias_qspec: Optional[QuantizationSpec] = None qconfig_A8W8 = QuantizationConfig( @@ -113,11 +122,11 @@ None, ) -qconfig_A32W8sym = QuantizationConfig( +qconfig_A32W8sym_127 = QuantizationConfig( input_activation=None, output_activation=None, - weight=wgt_qspec_sym8s, - bias=wgt_qspec_sym8s, + weight=wgt_qspec_sym8s_127, + bias=wgt_qspec_sym8s_127, ) @@ -350,13 +359,13 @@ class CadenceW8A32MixedQuantizer(CadenceQuantizer): def __init__(self) -> None: quantizers = [] quantizers.append( - CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym) + CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym_127) ) quantizers.append( - CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym) + CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym_127) ) quantizers.append( - CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym) + CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym_127) ) super().__init__(quantizers) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 8404fe25268..447fb30402c 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1257,9 +1257,8 @@ def quantized_w8a32_gru( weights_hidden: torch.Tensor, w_h_scale: float, bias_inputs: torch.Tensor, - b_i_scale: float, + b_scale: float, bias_hidden: torch.Tensor, - b_h_scale: float, ) -> torch.Tensor: assert weights_inputs.dtype == torch.int8 assert weights_hidden.dtype == torch.int8 @@ -1288,10 +1287,8 @@ def quantized_w8a32_gru( dequant_weights_inputs = weights_inputs.float() * w_i_scale dequant_weights_hidden = weights_hidden.float() * w_h_scale - # C++ implementation averages the two bias scales - avg_bias_scale = (b_i_scale + b_h_scale) / 2 - dequant_bias_inputs = bias_inputs.float() * avg_bias_scale - dequant_bias_hidden = bias_hidden.float() * avg_bias_scale + dequant_bias_inputs = bias_inputs.float() * b_scale + dequant_bias_hidden = bias_hidden.float() * b_scale gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs) gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden) @@ -1310,8 +1307,13 @@ def quantized_w8a32_gru( assert new_hidden.shape == original_hidden_shape - new_hidden = new_hidden.view(original_hidden_shape) - return torch.stack([new_hidden, new_hidden], dim=0) + batch_size = inputs.shape[0] + input_dim = inputs.shape[1] + hidden_dim = hidden.shape[-1] + + new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim) + + return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0) @impl_tracked(m, "quantized_conv2d_nhwc.per_tensor") diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 63077d373a7..29d74258ed4 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -2901,9 +2901,8 @@ def test_softmax_f32_f32(self) -> None: torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 (3*4 x 4) 0.1, # w_h_scale torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 - 0.1, # b_i_scale + 0.1, # b_scale torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 - 0.1, # b_h_scale ), ( "invalid_batch_size_2", @@ -2918,9 +2917,8 @@ def test_softmax_f32_f32(self) -> None: torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 0.1, # w_h_scale torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 - 0.1, # b_i_scale + 0.1, # b_scale torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 - 0.1, # b_h_scale ), ( "non_zero_biases", @@ -2933,11 +2931,10 @@ def test_softmax_f32_f32(self) -> None: torch.tensor( [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8 ), # bias_inputs: 12 - 0.1, # b_i_scale + 0.1, # b_scale torch.tensor( [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8 ), # bias_hidden: 12 - 0.1, # b_h_scale ), ( "negative_weights", @@ -2954,9 +2951,8 @@ def test_softmax_f32_f32(self) -> None: ), # weights_hidden: 12x4 (alternating pattern) 0.1, # w_h_scale torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 - 0.1, # b_i_scale + 0.1, # b_scale torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 - 0.1, # b_h_scale ), ( "hidden_dim_8", @@ -2969,9 +2965,8 @@ def test_softmax_f32_f32(self) -> None: torch.ones((24, 8), dtype=torch.int8), # weights_hidden: 24x8 (3*8 x 8) 0.1, # w_h_scale torch.zeros(24, dtype=torch.int8), # bias_inputs: 24 - 0.1, # b_i_scale + 0.1, # b_scale torch.zeros(24, dtype=torch.int8), # bias_hidden: 24 - 0.1, # b_h_scale ), ] ) @@ -2985,9 +2980,8 @@ def test_quantized_w8a32_gru( weights_hidden: torch.Tensor, w_h_scale: float, bias_inputs: torch.Tensor, - b_i_scale: float, + b_scale: float, bias_hidden: torch.Tensor, - b_h_scale: float, ) -> None: if name == "invalid_batch_size_2": @@ -3000,9 +2994,8 @@ def test_quantized_w8a32_gru( weights_hidden, w_h_scale, bias_inputs, - b_i_scale, + b_scale, bias_hidden, - b_h_scale, ) self.assertIn( "Leading dimension 0 of hidden state must be 1", str(context.exception) @@ -3017,9 +3010,8 @@ def test_quantized_w8a32_gru( weights_hidden, w_h_scale, bias_inputs, - b_i_scale, + b_scale, bias_hidden, - b_h_scale, ) # Verify output properties @@ -3028,10 +3020,11 @@ def test_quantized_w8a32_gru( torch.float32, f"Output dtype should be float32 in {name}", ) + expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]) self.assertEqual( output.shape, - (2, *hidden.shape), - f"Output shape should match {(2, *hidden.shape)} in {name}", + expected_shape, + f"Output shape should match {expected_shape} in {name}", ) assert isinstance(output, torch.Tensor) @@ -3064,7 +3057,6 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None: bias_inputs, 0.1, bias_hidden, - 0.1, ) self.assertIn(