Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,11 +689,11 @@ def register_fake(
)

lib.define(
"quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor"
"quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden) -> Tensor"
)

lib.define(
"quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)"
"quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
Expand Down Expand Up @@ -3060,11 +3060,20 @@ def quantized_w8a32_gru_meta(
weights_hidden: torch.Tensor,
w_h_scale: float,
bias_inputs: torch.Tensor,
b_i_scale: float,
b_scale: float,
bias_hidden: torch.Tensor,
b_h_scale: float,
) -> torch.Tensor:
return hidden.new_empty((2, *hidden.shape), dtype=torch.float32)
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# The result is an uninitialized float32 tensor of shape [2, batch, seq_len, hidden_size];
# its leading dimension stacks the output and the new hidden state.
return hidden.new_empty(
(2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32
)
Comment on lines +3066 to +3076


@register_fake("cadence::slice_scatter_")
Expand Down
45 changes: 26 additions & 19 deletions backends/cadence/aot/quantizer/fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,26 +438,36 @@ def get_args_and_kwargs_mixed_w8a32_conv(
torch.ops.aten.permute.default,
(other_inputs[0], [0, 2, 1]), # NCL -> NLC
)
assert "val" in other_inputs[0].meta, "Missing val metadata on input node"
original_val = other_inputs[0].meta["val"]
assert original_val.fake_mode is not None, "fake_mode is None on input node"
with original_val.fake_mode:
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
original_val, [0, 2, 1]
)
# Propagate val metadata for transposed_inputs
if "val" in other_inputs[0].meta:
original_val = other_inputs[0].meta["val"]
fake_mode = original_val.fake_mode
if fake_mode is not None:
with fake_mode:
transposed_val = torch.ops.aten.permute.default(original_val, [0, 2, 1])
transposed_inputs.meta["val"] = transposed_val
else:
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
original_val, [0, 2, 1]
)
copy_node_metadata(transposed_inputs, other_inputs[0])

transposed_weights = graph_module.graph.call_function(
torch.ops.aten.permute.default,
(weights_inputs[0], [2, 0, 1]), # NCL -> LNC
)
assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node"
original_val = weights_inputs[0].meta["val"]
assert original_val.fake_mode is not None, "fake_mode is None on weight node"
with original_val.fake_mode:
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
original_val, [2, 0, 1]
)
# Propagate val metadata for transposed_weights
if "val" in weights_inputs[0].meta:
original_val = weights_inputs[0].meta["val"]
fake_mode = original_val.fake_mode
if fake_mode is not None:
with fake_mode:
transposed_val = torch.ops.aten.permute.default(original_val, [2, 0, 1])
transposed_weights.meta["val"] = transposed_val
else:
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
original_val, [2, 0, 1]
)
copy_node_metadata(transposed_weights, weights_inputs[0])

args = (
Expand Down Expand Up @@ -511,12 +521,10 @@ def get_args_and_kwargs_mixed_w8a32_gru(
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
# Stride, padding, dilation, groups not supported yet

assert len(dequants_weights) == 2
assert len(dequants_biases) == 2
w_i_scale = dequants_weights[0].args[1]
w_h_scale = dequants_weights[1].args[1]
b_i_scale = dequants_biases[0].args[1]
b_h_scale = dequants_biases[1].args[1]
b_scale = dequants_biases[0].args[1]

Comment on lines 524 to 528
args = (
other_inputs[0],
Expand All @@ -526,9 +534,8 @@ def get_args_and_kwargs_mixed_w8a32_gru(
weights_inputs[1],
w_h_scale,
bias_inputs[0],
b_i_scale,
b_scale,
bias_inputs[1],
b_h_scale,
)
kwargs = {}

Expand Down
35 changes: 31 additions & 4 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def get_anchors(
)

cnn_weights = conv_layer.args[1]
if hasattr(cnn_weights.meta, "tensor_meta"):
if "tensor_meta" in cnn_weights.meta:
cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
# Bail if the channels are not multiple of 4 (SIMD)
if cnn_weights_shape[0] % 4 != 0:
Expand All @@ -744,6 +744,18 @@ def get_anchors(
conv_layer,
)

inputs = conv_layer.args[0]
if "tensor_meta" in inputs.meta:
inputs_shape = inputs.meta["tensor_meta"].shape
# Bail if length != kernel size - Not yet supported
if inputs_shape[-1] != cnn_weights_shape[2]:
return (
PartitionAnchors(
empty=True,
),
conv_layer,
)

return (
PartitionAnchors(
inputs=[],
Expand Down Expand Up @@ -777,14 +789,16 @@ def get_anchors(
)

# Bail if input or states are not multiple of 4 (SIMD)
if gru_layer.args[0].meta["tensor_meta"].shape[-1] % 4 != 0:
tensor_meta_0 = gru_layer.args[0].meta.get("tensor_meta", None)
if tensor_meta_0 is None or tensor_meta_0.shape[-1] % 4 != 0:
return (
PartitionAnchors(
empty=True,
),
gru_layer,
)
if gru_layer.args[1].meta["tensor_meta"].shape[-1] % 4 != 0:
tensor_meta_1 = gru_layer.args[1].meta.get("tensor_meta", None)
if tensor_meta_1 is None or tensor_meta_1.shape[-1] % 4 != 0:
return (
PartitionAnchors(
empty=True,
Expand All @@ -799,13 +813,26 @@ def __init__(self, args, meta):

wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta)

# Using SharedQuantizationSpec so that bias_hh has the same observer as bias_ih
# Both biases get the same quantization scale to match the cpp operator
bias_ih_node = wrapper.args[2]
bias_ih_edge = (bias_ih_node, gru_layer)
shared_bias_qspec = SharedQuantizationSpec(edge_or_node=bias_ih_edge)

return (
PartitionAnchors(
inputs=[],
# pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`.
weights=[(wrapper, 0), (wrapper, 1)],
# pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`.
biases=[(wrapper, 2), (wrapper, 3)],
biases=[
(wrapper, 2), # bias_ih gets normal qspec
(
wrapper,
3,
shared_bias_qspec,
), # bias_hh shares observer with bias_ih
],
output=[],
others=[(gru_layer, 0), (gru_layer, 1)],
),
Expand Down
21 changes: 15 additions & 6 deletions backends/cadence/aot/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@
observer_or_fake_quant_ctr=MinMaxObserver,
)

# Static (non-dynamic), per-tensor symmetric int8 quantization spec whose
# quantized range is restricted to [-127, 127], i.e. the int8 value -128 is
# never produced. Ranges are observed with MinMaxObserver.
wgt_qspec_sym8s_127 = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=MinMaxObserver,
)

bias_qspec: Optional[QuantizationSpec] = None

qconfig_A8W8 = QuantizationConfig(
Expand All @@ -113,11 +122,11 @@
None,
)

qconfig_A32W8sym = QuantizationConfig(
qconfig_A32W8sym_127 = QuantizationConfig(
input_activation=None,
output_activation=None,
weight=wgt_qspec_sym8s,
bias=wgt_qspec_sym8s,
weight=wgt_qspec_sym8s_127,
bias=wgt_qspec_sym8s_127,
)


Expand Down Expand Up @@ -350,13 +359,13 @@ class CadenceW8A32MixedQuantizer(CadenceQuantizer):
def __init__(self) -> None:
quantizers = []
quantizers.append(
CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym_127)
)
quantizers.append(
CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym_127)
)
quantizers.append(
CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym)
CadenceAtenQuantizer(MixedW8A32GruPattern(), qconfig_A32W8sym_127)
)
super().__init__(quantizers)

Expand Down
18 changes: 10 additions & 8 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -1257,9 +1257,8 @@
weights_hidden: torch.Tensor,
w_h_scale: float,
bias_inputs: torch.Tensor,
b_i_scale: float,
b_scale: float,
bias_hidden: torch.Tensor,
b_h_scale: float,
) -> torch.Tensor:
assert weights_inputs.dtype == torch.int8
assert weights_hidden.dtype == torch.int8
Expand Down Expand Up @@ -1288,10 +1287,8 @@
dequant_weights_inputs = weights_inputs.float() * w_i_scale
dequant_weights_hidden = weights_hidden.float() * w_h_scale

# C++ implementation averages the two bias scales
avg_bias_scale = (b_i_scale + b_h_scale) / 2
dequant_bias_inputs = bias_inputs.float() * avg_bias_scale
dequant_bias_hidden = bias_hidden.float() * avg_bias_scale
dequant_bias_inputs = bias_inputs.float() * b_scale
dequant_bias_hidden = bias_hidden.float() * b_scale

gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs)
gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden)
Expand All @@ -1310,8 +1307,13 @@

assert new_hidden.shape == original_hidden_shape

new_hidden = new_hidden.view(original_hidden_shape)
return torch.stack([new_hidden, new_hidden], dim=0)
batch_size = inputs.shape[0]
input_dim = inputs.shape[1]
hidden_dim = hidden.shape[-1]

new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim)

return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0)
Comment on lines +1310 to +1316


@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
Expand Down
30 changes: 11 additions & 19 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2901,9 +2901,8 @@ def test_softmax_f32_f32(self) -> None:
torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 (3*4 x 4)
0.1, # w_h_scale
torch.zeros(12, dtype=torch.int8), # bias_inputs: 12
0.1, # b_i_scale
0.1, # b_scale
torch.zeros(12, dtype=torch.int8), # bias_hidden: 12
0.1, # b_h_scale
),
(
"invalid_batch_size_2",
Expand All @@ -2918,9 +2917,8 @@ def test_softmax_f32_f32(self) -> None:
torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4
0.1, # w_h_scale
torch.zeros(12, dtype=torch.int8), # bias_inputs: 12
0.1, # b_i_scale
0.1, # b_scale
torch.zeros(12, dtype=torch.int8), # bias_hidden: 12
0.1, # b_h_scale
),
(
"non_zero_biases",
Expand All @@ -2933,11 +2931,10 @@ def test_softmax_f32_f32(self) -> None:
torch.tensor(
[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
), # bias_inputs: 12
0.1, # b_i_scale
0.1, # b_scale
torch.tensor(
[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
), # bias_hidden: 12
0.1, # b_h_scale
),
(
"negative_weights",
Expand All @@ -2954,9 +2951,8 @@ def test_softmax_f32_f32(self) -> None:
), # weights_hidden: 12x4 (alternating pattern)
0.1, # w_h_scale
torch.zeros(12, dtype=torch.int8), # bias_inputs: 12
0.1, # b_i_scale
0.1, # b_scale
torch.zeros(12, dtype=torch.int8), # bias_hidden: 12
0.1, # b_h_scale
),
(
"hidden_dim_8",
Expand All @@ -2969,9 +2965,8 @@ def test_softmax_f32_f32(self) -> None:
torch.ones((24, 8), dtype=torch.int8), # weights_hidden: 24x8 (3*8 x 8)
0.1, # w_h_scale
torch.zeros(24, dtype=torch.int8), # bias_inputs: 24
0.1, # b_i_scale
0.1, # b_scale
torch.zeros(24, dtype=torch.int8), # bias_hidden: 24
0.1, # b_h_scale
),
]
)
Expand All @@ -2985,9 +2980,8 @@ def test_quantized_w8a32_gru(
weights_hidden: torch.Tensor,
w_h_scale: float,
bias_inputs: torch.Tensor,
b_i_scale: float,
b_scale: float,
bias_hidden: torch.Tensor,
b_h_scale: float,
) -> None:

if name == "invalid_batch_size_2":
Expand All @@ -3000,9 +2994,8 @@ def test_quantized_w8a32_gru(
weights_hidden,
w_h_scale,
bias_inputs,
b_i_scale,
b_scale,
bias_hidden,
b_h_scale,
)
self.assertIn(
"Leading dimension 0 of hidden state must be 1", str(context.exception)
Expand All @@ -3017,9 +3010,8 @@ def test_quantized_w8a32_gru(
weights_hidden,
w_h_scale,
bias_inputs,
b_i_scale,
b_scale,
bias_hidden,
b_h_scale,
)

# Verify output properties
Expand All @@ -3028,10 +3020,11 @@ def test_quantized_w8a32_gru(
torch.float32,
f"Output dtype should be float32 in {name}",
)
expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
self.assertEqual(
output.shape,
(2, *hidden.shape),
f"Output shape should match {(2, *hidden.shape)} in {name}",
expected_shape,
f"Output shape should match {expected_shape} in {name}",
Comment on lines +3023 to +3027
)
assert isinstance(output, torch.Tensor)

Expand Down Expand Up @@ -3064,7 +3057,6 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
bias_inputs,
0.1,
bias_hidden,
0.1,
)

self.assertIn(
Expand Down
Loading