Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion backends/arm/_passes/insert_rescales_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import math
import operator
from copy import copy
from typing import cast, Dict, Optional, Set, Tuple, Type

Expand Down Expand Up @@ -34,22 +35,67 @@ class InsertRescalePass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = set()

def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None:
"""Ensure uint8 tensors only appear at IO boundaries.

TOSA has no true uint8 tensor type; unsigned semantics are carried via
RESCALE input/output flags. If uint8 appears for other nodes, it means
unsigned data leaked past IO.

"""
for node in graph_module.graph.nodes:
meta_val = node.meta.get("val")
if not isinstance(meta_val, torch.Tensor):
continue
if meta_val.dtype != torch.uint8:
continue
if node.op in ("placeholder", "output"):
continue
if node.op == "call_function" and node.target == operator.getitem:
if all(user.op == "output" for user in node.users):
continue
if (
node.op == "call_function"
and node.target
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
):
# dim_order is a view-like transform; allow it to preserve uint8 at IO.
continue
if (
node.op == "call_function"
and node.target == exir_ops.backend.tosa.RESCALE.default
):
continue
raise ValueError(
f"Found internal uint8 tensor at node {node.name} "
f"({node.target}). Uint8 is only allowed at IO boundaries."
)

def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
dq_args = QuantArgs.from_operator(node.target, node.args)
q_args = QuantArgs.from_operator(user.target, user.args)
new_scale = dq_args.scale / q_args.scale
input_unsigned = dq_args.dtype == torch.uint8
output_unsigned = q_args.dtype == torch.uint8
# TOSA has no true uint8 tensors; unsigned semantics are handled via
# the RESCALE flags, so uint8 does not propagate as a tensor dtype.
output_dtype = torch.int8 if output_unsigned else q_args.dtype

with graph_module.graph.inserting_before(node):
rescale_node = create_node(
graph_module.graph,
exir_ops.backend.tosa.RESCALE.default,
(
node.all_input_nodes[0],
q_args.dtype,
output_dtype,
[new_scale],
dq_args.zp,
q_args.zp,
),
kwargs={
"input_unsigned": input_unsigned,
"output_unsigned": output_unsigned,
},
)
rescale_node.meta = copy(user.meta)
user.replace_all_uses_with(rescale_node)
Expand All @@ -74,6 +120,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
graph_module.recompile()
return PassResult(graph_module, modified)

def ensures(self, graph_module: GraphModule) -> None:
    """Pass post-condition hook: verify uint8 only survives at IO boundaries."""
    self._ensure_uint8_io_only(graph_module)


class InsertRescaleInt32Pass(ArmPass):
"""Numerous TOSA ops require inputs and outputs to be 32-bit integers in
Expand Down
36 changes: 36 additions & 0 deletions backends/arm/common/arm_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class DebugMode(Enum):
compiler_flags: list[str] = field(default_factory=list)
path_for_intermediates: str | None = None
tosa_debug_mode: DebugMode | None = None
preserve_io_quantization: bool = False

_TOSA_SPEC_KEY = "tosa_spec"
_COMPILE_FLAGS_KEY = "compile_flags"
Expand All @@ -44,6 +45,7 @@ class DebugMode(Enum):
_DEBUG_MODE_KEY = "dump_debug_info"
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
_TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config"
_PRESERVE_IO_QUANT_KEY = "preserve_io_quantization"

def _set_compile_specs(
self,
Expand All @@ -53,6 +55,7 @@ def _set_compile_specs(
tosa_debug_mode: DebugMode | None = None,
output_order_workaround: bool = False,
pipeline_config: ArmPassPipelineConfig | None = None,
preserve_io_quantization: bool = False,
):
"""Set all values of dataclass directly."""
self.tosa_spec = tosa_spec
Expand All @@ -61,6 +64,8 @@ def _set_compile_specs(
self.tosa_debug_mode = tosa_debug_mode
self._pipeline_config = pipeline_config
self.output_order_workaround = output_order_workaround
self.preserve_io_quantization = preserve_io_quantization
self._warn_if_redundant_preserve_io_quantization()
if output_order_workaround:
warnings.warn(
"ArmCompileSpec(output_order_workaround=True) is deprecated and will be "
Expand All @@ -78,6 +83,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
output_order_workaround: bool = False
pipeline_config: ArmPassPipelineConfig | None = None
preserve_io_quantization: bool = False
unknown_specs: dict[str, str] = {}
for spec in compile_specs:
key = spec.key
Expand Down Expand Up @@ -128,6 +134,8 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
"More than one transform pipeline entry in compile spec."
)
pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val))
elif key == ArmCompileSpec._PRESERVE_IO_QUANT_KEY:
preserve_io_quantization = str(val).lower() in ("1", "true", "yes")
else:
unknown_specs[key] = val

Expand All @@ -151,6 +159,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
tosa_debug_mode=tosa_debug_mode,
output_order_workaround=output_order_workaround,
pipeline_config=pipeline_config,
preserve_io_quantization=preserve_io_quantization,
)
cls._from_list_hook(compile_spec, unknown_specs)
compile_spec._validate()
Expand Down Expand Up @@ -227,8 +236,35 @@ def _to_list(self):
self._pipeline_config.serialize(),
)
)
compile_spec.append(
CompileSpec(
ArmCompileSpec._PRESERVE_IO_QUANT_KEY,
str(bool(self.preserve_io_quantization)).encode(),
)
)
return compile_spec

def _set_preserve_io_quantization(self, enabled: bool) -> "ArmCompileSpec":
"""Preserve Q/DQ nodes at IO boundaries when lowering."""
self.preserve_io_quantization = enabled
self._warn_if_redundant_preserve_io_quantization()
return self

def _warn_if_redundant_preserve_io_quantization(self) -> None:
"""Warn when preserve_io_quantization has no effect for INT-only
specs.
"""
if (
self.preserve_io_quantization
and self.tosa_spec.support_integer()
and not self.tosa_spec.support_float()
):
warnings.warn(
"preserve_io_quantization=True is redundant for INT-only TOSA "
"specifications because boundary Q/DQ are already de-tagged.",
stacklevel=3,
)

def _get_pass_pipeline_config(self) -> ArmPassPipelineConfig:
"""Returns configuration that controls how the Arm pass pipeline should
behave.
Expand Down
17 changes: 14 additions & 3 deletions backends/arm/operators/op_tosa_rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def _build_rescale(
rounding_mode: ts.RoundingMode,
per_channel: bool = False,
is_scale32: bool = True,
input_unsigned: bool = False,
output_unsigned: bool = False,
):
"""Insert a TOSA RESCALE operator configured for the quantized path.

Expand Down Expand Up @@ -198,8 +200,8 @@ def _build_rescale(
scale32=is_scale32,
rounding_mode=rounding_mode,
per_channel=per_channel,
input_unsigned=False,
output_unsigned=False,
input_unsigned=input_unsigned,
output_unsigned=output_unsigned,
)

tosa_fb.addOperator(
Expand Down Expand Up @@ -228,6 +230,14 @@ def define_node(
scales = cast(list[float], node.args[2])
input_zp = cast(int, node.args[3])
output_zp = cast(int, node.args[4])
if "input_unsigned" in node.kwargs:
input_unsigned = cast(bool, node.kwargs.get("input_unsigned", False))
else:
input_unsigned = cast(bool, node.args[5]) if len(node.args) > 5 else False
if "output_unsigned" in node.kwargs:
output_unsigned = cast(bool, node.kwargs.get("output_unsigned", False))
else:
output_unsigned = cast(bool, node.args[6]) if len(node.args) > 6 else False

if (
input_dtype
Expand All @@ -244,7 +254,6 @@ def define_node(
raise ValueError(
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
)

_build_rescale(
tosa_graph,
scale=scales,
Expand All @@ -255,4 +264,6 @@ def define_node(
output_zp=[output_zp],
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
per_channel=len(scales) > 1,
input_unsigned=input_unsigned,
output_unsigned=output_unsigned,
)
1 change: 1 addition & 0 deletions backends/arm/quantizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
EthosUQuantizer,
get_symmetric_a16w8_quantization_config,
get_symmetric_quantization_config,
get_uint8_io_quantization_config,
TOSAQuantizer,
VgfQuantizer,
)
Expand Down
48 changes: 48 additions & 0 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
"VgfQuantizer",
"get_symmetric_a16w8_quantization_config",
"get_symmetric_quantization_config",
"get_uint8_io_quantization_config",
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -234,6 +235,53 @@ def get_symmetric_quantization_config(
return quantization_config


@functools.lru_cache
def get_uint8_io_quantization_config(
    is_qat: bool = False,
    is_dynamic: bool = False,
    eps: float = 2**-16,
) -> QuantizationConfig:
    """Create a uint8 IO quantization config for TOSA backends.

    This config is intended for model inputs/outputs only. Internal tensors
    should remain int8 for TOSA INT lowering.

    Args:
        is_qat: select fake-quant constructors suitable for QAT.
        is_dynamic: build a dynamic-quantization config.
        eps: smallest representable scale passed to the observer.
    """
    observer_args: Dict[str, Any] = {"eps": eps}

    # Pick the activation observer / fake-quant constructor for the four
    # (is_qat, is_dynamic) combinations.
    if is_qat and is_dynamic:
        act_ctr = FakeQuantize
        observer_args["observer"] = MovingAverageMinMaxObserver.with_args(
            averaging_constant=1
        )
    elif is_qat:
        act_ctr = FusedMovingAvgObsFakeQuantize
    elif is_dynamic:
        act_ctr = PlaceholderObserver
    else:
        act_ctr = HistogramObserver

    uint8_info = torch.iinfo(torch.uint8)
    act_spec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=uint8_info.min,
        quant_max=uint8_info.max,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_ctr.with_args(**observer_args),
    )

    # Only activations (inputs/outputs) are quantized; weight and bias specs
    # are left unset, as expected for an IO-only config.
    return TOSAQuantizationConfig(act_spec, act_spec, None, None)


def get_symmetric_a8w4_quantization_config(
is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False
):
Expand Down
72 changes: 44 additions & 28 deletions backends/arm/quantizer/arm_quantizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,45 +532,61 @@ def _get_input_nodes_with_float_output(self, node: Node) -> list[Node]:
def _get_user_nodes_with_float_input(self, node: Node) -> list[Node]:
    """Return `node`'s users when `node` itself produces a float output.

    NOTE(review): the predicate is evaluated on ``node`` (the producer), so
    this returns either every user or none — presumably intentional, since a
    user's input dtype is exactly this node's output dtype; confirm against
    the sibling `_get_input_nodes_with_float_output`.
    """
    return [user for user in node.users if has_float_output(node)]

def _skip_shared_qspec_from_io(self, node: Node, qspec: QuantizationSpec) -> bool:
return node.op in ("placeholder", "output") and qspec.dtype == torch.uint8

def _maybe_enqueue_shared_node(
self, neighbor: Node, shared_nodes: set[Node], bfs_queue: list[Node]
) -> None:
if neighbor.target in self.targets and neighbor not in shared_nodes:
if not self._is_annotated(neighbor):
bfs_queue.append(neighbor)

def _append_output_qspec(self, node: Node, adjacent_qspecs: list[Any]) -> None:
if not self._is_annotated(node):
return
output_qspec = node.meta.get( # type: ignore[union-attr]
Q_ANNOTATION_KEY
).output_qspec
if output_qspec is None:
return
if self._skip_shared_qspec_from_io(node, output_qspec):
return
adjacent_qspecs.append(output_qspec)

def _append_input_qspec(
self, user_node: Node, input_node: Node, adjacent_qspecs: list[Any]
) -> None:
if not self._is_annotated(user_node):
return
qspec_map = user_node.meta.get(Q_ANNOTATION_KEY)
if qspec_map is None:
return
if input_node not in qspec_map.input_qspec_map:
return
input_qspec = qspec_map.input_qspec_map[input_node]
if input_qspec is None:
return
if self._skip_shared_qspec_from_io(user_node, input_qspec):
return
adjacent_qspecs.append(input_qspec)

def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]:
shared_nodes = set()
bfs_queue = [root_node]
adjacent_qspecs = []
adjacent_qspecs: list[Any] = []

while bfs_queue:
node = bfs_queue.pop(0)
shared_nodes.add(node)

for input_node in node.all_input_nodes:
if input_node.target in self.targets and input_node not in shared_nodes:
if not self._is_annotated(input_node):
bfs_queue.append(input_node)
if self._is_annotated(input_node):
output_qspec = input_node.meta.get( # type: ignore[union-attr]
Q_ANNOTATION_KEY
).output_qspec
if output_qspec is not None:
adjacent_qspecs.append(output_qspec)
self._maybe_enqueue_shared_node(input_node, shared_nodes, bfs_queue)
self._append_output_qspec(input_node, adjacent_qspecs)

for output_node in node.users.keys():
if (
output_node.target in self.targets
and output_node not in shared_nodes
):
if not self._is_annotated(output_node):
bfs_queue.append(output_node)
if (
self._is_annotated(output_node)
and node
in output_node.meta.get( # type: ignore[union-attr]
Q_ANNOTATION_KEY
).input_qspec_map
):
input_qspec = output_node.meta.get( # type: ignore[union-attr]
Q_ANNOTATION_KEY
).input_qspec_map[node]
if input_qspec is not None:
adjacent_qspecs.append(input_qspec)
self._maybe_enqueue_shared_node(output_node, shared_nodes, bfs_queue)
self._append_input_qspec(output_node, node, adjacent_qspecs)

return shared_nodes, adjacent_qspecs

Expand Down
Loading
Loading