diff --git a/backends/nordic/.gitignore b/backends/nordic/.gitignore new file mode 100644 index 00000000000..840a5bb672c --- /dev/null +++ b/backends/nordic/.gitignore @@ -0,0 +1,6 @@ +__pycache__/ +*.pyc +*.pyo +*.so +*.egg-info/ +build/ diff --git a/backends/nordic/CMakeLists.txt b/backends/nordic/CMakeLists.txt new file mode 100644 index 00000000000..ee0aa04650d --- /dev/null +++ b/backends/nordic/CMakeLists.txt @@ -0,0 +1,53 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +cmake_minimum_required(VERSION 3.19) +project(nordic_backend) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_include_directories + ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10 +) +add_compile_definitions(C10_USING_CUSTOM_GENERATED_MACROS) + +# Nordic AXON NPU delegate build +if(EXECUTORCH_BUILD_NORDIC_AXON) + + add_compile_options("-Wall" "-Werror") + + set(_axon_backend_sources + backends/nordic/runtime/AxonBackend.cpp + backends/nordic/runtime/axon_op_extensions.c + ) + list(TRANSFORM _axon_backend_sources PREPEND "${EXECUTORCH_ROOT}/") + + add_library(executorch_delegate_axon STATIC ${_axon_backend_sources}) + target_link_libraries(executorch_delegate_axon PUBLIC executorch_core) + + # The Nordic sdk-edge-ai headers must be on the include path. + # Users should set SDK_EDGE_AI_INCLUDE_PATH to their sdk-edge-ai/include + # directory, or ensure the AXON driver headers are available via the + # Zephyr module system. 
+ if(DEFINED SDK_EDGE_AI_INCLUDE_PATH) + target_include_directories( + executorch_delegate_axon PRIVATE ${SDK_EDGE_AI_INCLUDE_PATH} + ) + endif() + + target_include_directories( + executorch_delegate_axon PRIVATE ${_common_include_directories} + ) + + install(TARGETS executorch_delegate_axon EXPORT ExecuTorchTargets) + +endif() diff --git a/backends/nordic/README.md b/backends/nordic/README.md new file mode 100644 index 00000000000..af7a77235ff --- /dev/null +++ b/backends/nordic/README.md @@ -0,0 +1,258 @@ +# Nordic AXON NPU Backend for ExecuTorch + +ExecuTorch backend for Nordic Semiconductor's **AXON NPU** on the +nRF54LM20B (ARM Cortex-M33 + hardware neural network accelerator). + +Compiles PyTorch models to AXON command buffers via TOSA, then executes +them on the NPU at inference time, offloading compute-intensive layers +from the CPU. + +## Architecture + +``` +PyTorch Model + │ + ▼ +torch.export ─── ExecuTorch Edge Lowering + │ + ▼ +AxonPartitioner ─── identifies ops for AXON delegation + │ + ▼ +TOSABackend._preprocess() ─── shared ARM TOSA lowering + │ + ▼ +tosa_reader ─── parse TOSA flatbuffer + │ + ▼ +axon_compiler ─── convert TOSA layers to AXON layer descriptors + │ + ▼ +axon_binary ─── pack intermediate binary (cffi structs) + │ + ▼ +Nordic compiler lib ─── produce AXON command buffers (.h headers) + │ + ▼ +.pte file + generated headers ─── deploy to nRF54LM20DK +``` + +## Supported Hardware + +| Device | NPU | Status | +|--------|-----|--------| +| nRF54LM20DK (nRF54LM20B) | AXON (~300 MACs, 3-8 GOPS) | Supported | + +## Supported Operations + +### AXON-accelerated (hardware) + +| Operation | Max dimensions | Notes | +|-----------|---------------|-------| +| Fully Connected | 2048 in/out | INT8 weights + bias | +| Conv2D | 16x16 filter, stride ≤ 31 | INT8, with padding | +| Depthwise Conv2D | 16x16 filter | INT8 | +| Average Pool 2D | 32x32 filter | | +| Max Pool 2D | 32x32 filter | | +| Add | element-wise | INT8 | +| Multiply | element-wise | INT8 | 
+| ReLU / ReLU6 | fused with preceding layer | Zero overhead |
+| Leaky ReLU | fused with preceding layer | |
+
+### Op extensions (AXON + CPU hybrid)
+
+| Operation | Preceding layer output | CPU callback |
+|-----------|----------------------|--------------|
+| Sigmoid | INT16 q3.12 | `axon_op_extension_sigmoid` |
+| Tanh | INT16 q3.12 | `axon_op_extension_tanh` |
+| Softmax | INT32 q11.12 | Nordic's reference implementation |
+
+## Prerequisites
+
+- **nRF54LM20DK**: Nordic's development kit with the AXON NPU.
+- **nRF Connect SDK (NCS)**: Nordic's Zephyr-based SDK. Install via
+  [nRF Connect for Desktop](https://www.nordicsemi.com/Products/Development-tools/nRF-Connect-for-Desktop)
+  or manually:
+  ```bash
+  # Install NCS toolchain
+  nrfutil sdk-manager install --ncs-version v3.3.0-preview3
+
+  # Initialize west workspace
+  west init -m https://github.com/nrfconnect/sdk-nrf --mr v3.3.0-preview3 ~/ncs-workspace
+  cd ~/ncs-workspace && west update
+
+  # Generate toolchain environment script
+  nrfutil sdk-manager toolchain env --as-script sh --ncs-version v3.3.0-preview3 > nrf-connect-sdk-env.sh
+  ```
+- **Nordic sdk-edge-ai**: Contains the AXON compiler library (proprietary,
+  not redistributed). Available to nRF54LM20DK owners via Nordic's
+  [Edge AI documentation](https://docs.nordicsemi.com/bundle/addon-edge-ai_latest/page/index.html).
+  ```bash
+  # Obtain sdk-edge-ai from Nordic (see the Edge AI documentation above),
+  # place it at ~/sdk-edge-ai, then point the build at it:
+  export SDK_EDGE_AI_PATH=~/sdk-edge-ai
+  ```
+- **ExecuTorch**: This repository (with the ARM TOSA backend).
+- **Python packages** (for model export, separate from NCS Python): + ```bash + pip install cffi numpy tosa-tools + ``` + +Verify the setup: +```bash +bash backends/nordic/scripts/setup.sh +``` + +## Quick Start + +### One-command flow + +```bash +# Source NCS toolchain +source ~/ncs-workspace/nrf-connect-sdk-env.sh + +# Export model, build firmware, flash — all in one +./backends/nordic/scripts/run.sh +``` + +### Step-by-step flow + +#### Step 1: Export a model + +```bash +# Use a Python environment with ExecuTorch installed (NOT the NCS Python) +python examples/nordic/hello_axon/export_model.py +``` + +This trains a small sin(x) model, quantizes it to INT8, compiles the +FC layers to AXON command buffers, and produces: +- `build/hello_axon.pte` — ExecuTorch program +- `src/model_pte.h` — embedded model as C array +- `src/generated/axon_subgraph_*.h` — AXON command buffers +- `src/generated/axon_subgraphs_table.h` — delegate lookup table + +### Step 2: Build firmware + +```bash +# Source the NCS toolchain environment +source ~/ncs-workspace/nrf-connect-sdk-env.sh + +# Build for nRF54LM20DK +west build -b nrf54lm20dk/nrf54lm20b/cpuapp examples/nordic/hello_axon \ + --no-sysbuild -- \ + -DZEPHYR_EXTRA_MODULES="$(pwd);$SDK_EDGE_AI_PATH" +``` + +### Step 3: Flash and verify + +```bash +west flash + +# Open serial console (115200 baud) +# Linux: screen /dev/ttyACM0 115200 +# macOS: screen /dev/cu.usbmodem* 115200 +``` + +Expected output: +``` +Hello AXON - ExecuTorch + Nordic AXON NPU +AXON NPU: enabled +Loading model (2084 bytes)... +AxonBackend::init (delegate 0, processed=36 bytes) + AXON model 'hello_axon_...' bound (out: 1x1x1 byte_width=1) +Running inference (x=1.57, expected sin~1.0)... +Inference: 20871 cycles (163 us @ 128 MHz) + output[0] = 0.997794 +Done. 
+``` + +### Using the AXON backend in your own model + +```python +from executorch.backends.nordic.axon import ( + AxonQuantizer, + AxonCompileSpec, + AxonPartitioner, +) +from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e +from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig + +# 1. Quantize +quantizer = AxonQuantizer() # Symmetric INT8, per-channel weights +exported = torch.export.export(model.eval(), example_input, strict=False) +prepared = prepare_pt2e(exported.module(), quantizer) +prepared(*calibration_data) # Calibrate +quantized = convert_pt2e(prepared) +re_exported = torch.export.export(quantized, example_input, strict=False) + +# 2. Partition to AXON +compile_spec = AxonCompileSpec( + model_name="my_model", + axon_generated_dir="build/generated", +) +partitioner = AxonPartitioner(compile_spec) +edge = to_edge_transform_and_lower( + re_exported, + partitioner=[partitioner], + compile_config=EdgeCompileConfig(_check_ir_validity=False), +) + +# 3. 
Save .pte +edge.to_executorch().save("build/my_model.pte") +``` + +## Directory Structure + +``` +backends/nordic/ +├── axon/ # Core backend +│ ├── backend.py # AxonBackend (BackendDetails) +│ ├── compile_spec.py # AxonCompileSpec +│ ├── partitioner.py # AxonPartitioner (extends TOSAPartitioner) +│ └── codegen.py # Marker format, naming, header generation +├── axon_compiler.py # TOSA → AXON layer conversion +├── axon_binary.py # Binary builder for Nordic compiler +├── tosa_reader.py # TOSA flatbuffer parser +├── operator_support/ # AXON hardware constraint checks +├── runtime/ # C++ delegate for on-device execution +│ ├── AxonBackend.cpp # BackendInterface implementation +│ ├── AxonBackend.h # Profiling API +│ └── axon_op_extensions.c # Sigmoid/tanh CPU callbacks +├── test/ # Unit tests (23 tests, no SDK required) +├── CMakeLists.txt # Build configuration +└── README.md # This file +``` + +## How It Works + +The backend follows the same **composition pattern** as the Ethos-U backend: + +1. **Partitioner** identifies INT8-quantized operations that AXON supports +2. **TOSABackend** (shared with Ethos-U) lowers the subgraph to TOSA IR +3. **AXON compiler** converts TOSA → AXON layer descriptors → intermediate binary +4. **Nordic compiler lib** (external) produces command buffers as C headers +5. **C++ delegate** on-device parses the marker, looks up the compiled model, + and calls `nrf_axon_nn_model_infer_sync()` for hardware execution + +Multi-subgraph models are supported: each delegated subgraph gets a unique +name (content-hash based), and a generated lookup table maps names to compiled +models at runtime. 
+ +## Running Tests + +```bash +# TOSA lowering tests (no SDK required) +pytest backends/nordic/test/test_tosa_lowering.py -v + +# Full compilation tests (requires SDK_EDGE_AI_PATH) +pytest backends/nordic/test/test_axon_compile.py -v +``` + +## Tutorials and Examples + +See [ioteai/axon-ai](https://github.com/ioteai/axon-ai) for Jupyter +notebooks, Docker setup, and detailed guides. + +## License + +Copyright (c) 2026 iote.ai. BSD 3-Clause License — see the root +[LICENSE](../../LICENSE) file. diff --git a/backends/nordic/TARGETS b/backends/nordic/TARGETS new file mode 100644 index 00000000000..aa8a9eb0a36 --- /dev/null +++ b/backends/nordic/TARGETS @@ -0,0 +1,51 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# @noautodeps +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +runtime.python_library( + name = "axon", + srcs = glob(["axon/*.py"]), + deps = [ + "//executorch/backends/arm:tosa", + "//executorch/exir:lib", + ], + visibility = ["//executorch/..."], +) + +runtime.python_library( + name = "operator_support", + srcs = glob(["operator_support/*.py"]), + deps = [ + ":axon", + "//executorch/exir:lib", + ], + visibility = ["//executorch/..."], +) + +runtime.python_library( + name = "axon_compiler", + srcs = [ + "axon_compiler.py", + "axon_binary.py", + "tosa_reader.py", + ], + deps = [ + ":axon", + ], + visibility = ["//executorch/..."], +) + +runtime.python_test( + name = "test_tosa_lowering", + srcs = ["test/test_tosa_lowering.py"], + deps = [ + ":axon", + ":axon_compiler", + ":operator_support", + ], +) diff --git a/backends/nordic/__init__.py b/backends/nordic/__init__.py new file mode 100644 index 00000000000..ac080621ffa --- /dev/null +++ b/backends/nordic/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Public entry points for the Nordic backend. + +Public API is defined by explicit module exports (``.axon``). +Selected symbols are re-exported here for convenience. +""" + +from __future__ import annotations + +import importlib +from typing import Any + +# Public for tooling (manifest generation and API validation). +LAZY_IMPORTS = { + "AxonBackend": ("executorch.backends.nordic.axon", "AxonBackend"), + "AxonCompileSpec": ("executorch.backends.nordic.axon", "AxonCompileSpec"), + "AxonPartitioner": ("executorch.backends.nordic.axon", "AxonPartitioner"), + "AxonQuantizer": ("executorch.backends.nordic.axon", "AxonQuantizer"), +} + + +def __getattr__(name: str) -> Any: + if name in LAZY_IMPORTS: + module_name, attr = LAZY_IMPORTS[name] + module = importlib.import_module(module_name) + value = getattr(module, attr) + globals()[name] = value + return value + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +def __dir__() -> list[str]: + return sorted(list(globals()) + list(LAZY_IMPORTS)) diff --git a/backends/nordic/axon/__init__.py b/backends/nordic/axon/__init__.py new file mode 100644 index 00000000000..ad0145e192a --- /dev/null +++ b/backends/nordic/axon/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+"""ExecuTorch backend for Nordic AXON NPU.""" + +__version__ = "0.1.0" + +from .backend import AxonBackend +from .compile_spec import AxonCompileSpec +from .partitioner import AxonPartitioner +from .quantizer import AxonQuantizer + +__all__ = ["AxonBackend", "AxonCompileSpec", "AxonPartitioner", "AxonQuantizer"] diff --git a/backends/nordic/axon/backend.py b/backends/nordic/axon/backend.py new file mode 100644 index 00000000000..03808e98310 --- /dev/null +++ b/backends/nordic/axon/backend.py @@ -0,0 +1,345 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""AXON NPU backend for ExecuTorch. + +Compiles delegated subgraphs for execution on Nordic's AXON NPU. +Composes with TOSABackend for TOSA lowering, then compiles TOSA to +AXON command buffers via our converter + Nordic's compiler lib. + +Pipeline:: + + ExportedProgram -> TOSABackend._preprocess() -> TOSA flatbuffer + -> tosa_reader -> axon_compiler -> axon_binary -> Nordic compiler lib + -> compiled AXON model (.h with command buffers) +""" + +from __future__ import annotations + +import logging +import os +import tempfile +from typing import final, List + +from executorch.backends.arm.tosa.backend import TOSABackend +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult +from executorch.exir.backend.compile_spec_schema import CompileSpec + +from .codegen import ( + derive_subgraph_name, + make_marker, + regenerate_table, + rewrite_header_symbols, + rewrite_op_extension_symbols, + write_subgraph_header, +) +from executorch.backends.nordic.axon_binary import AxonBinaryBuilder +from executorch.backends.nordic.axon_types import ActivationQuantInfo +from executorch.backends.nordic.axon_compiler import 
tosa_to_axon_layers +from executorch.backends.nordic.tosa_reader import parse_tosa_flatbuffer + +logger = logging.getLogger(__name__) + + +def _extract_activation_quant_info(edge_program) -> list[ActivationQuantInfo]: + """Scan the edge program FX graph for sigmoid/tanh/softmax ops and + extract their input/output quantization parameters. + + AXON op extensions (sigmoid=101, tanh=102, softmax=100) require the + preceding layer to output INT16 q3.12 (sigmoid/tanh) or INT32 q11.12 + (softmax). To recompute the rescale, we need the input/output scales + of the activation, which are NOT directly available in the TOSA graph + (TOSA uses TABLE ops with the scales baked in). So we extract them + here from the still-quantize/dequantize-annotated FX graph, before + TOSA lowering destroys that information. + + Returns: + List of ActivationQuantInfo records, in the order activations + appear in the graph. + """ + info: list[ActivationQuantInfo] = [] + nodes = list(edge_program.graph_module.graph.nodes) + + sigmoid_targets = ("aten.sigmoid", "aten_sigmoid") + tanh_targets = ("aten.tanh", "aten_tanh") + softmax_targets = ("aten.softmax", "aten._softmax", "aten_softmax", "aten__softmax") + amax_targets = ("aten.amax", "aten_amax") + + def matches(node, names): + s = str(node.target) + return any(n in s for n in names) + + def get_quant_args(qnode): + if qnode is None or len(qnode.args) < 3: + return None + scale = qnode.args[1] + zp = qnode.args[2] + return (float(scale), int(zp)) + + def find_input_quant(act_node): + if not act_node.args: + return None + cur = act_node.args[0] + for _ in range(4): + if cur is None or not hasattr(cur, "op"): + return None + if "dequantize" in str(cur.target): + cur = cur.args[0] if cur.args else None + continue + if "quantize" in str(cur.target): + return get_quant_args(cur) + return None + return None + + def find_output_quant(act_node): + for user in act_node.users: + if "quantize" in str(user.target) and "dequantize" not in str(user.target): 
+ return get_quant_args(user) + return None + + def find_softmax_output_quant(amax_node): + seen = set() + frontier = [amax_node] + steps = 0 + while frontier and steps < 30: + steps += 1 + nxt = [] + for n in frontier: + for u in n.users: + if id(u) in seen: + continue + seen.add(id(u)) + s = str(u.target) + if "aten" in s and "mul" in s and "tensor" in s.lower(): + return find_output_quant(u) + nxt.append(u) + frontier = nxt + return None + + for node in nodes: + if node.op != "call_function": + continue + if matches(node, sigmoid_targets): + op_type = "sigmoid" + elif matches(node, tanh_targets): + op_type = "tanh" + elif matches(node, softmax_targets): + op_type = "softmax" + elif matches(node, amax_targets): + in_q = find_input_quant(node) + out_q = find_softmax_output_quant(node) + if in_q is None or out_q is None: + continue + info.append(ActivationQuantInfo( + op_type="softmax", + input_scale=in_q[0], input_zp=in_q[1], + output_scale=out_q[0], output_zp=out_q[1], + )) + logger.debug(f" Activation softmax (from amax): in_scale={in_q[0]:.6f} " + f"in_zp={in_q[1]} out_scale={out_q[0]:.6f} out_zp={out_q[1]}") + continue + else: + continue + + in_q = find_input_quant(node) + out_q = find_output_quant(node) + if in_q is None or out_q is None: + logger.warning(f" Could not extract quant info for {op_type} ({node.name})") + continue + + info.append(ActivationQuantInfo( + op_type=op_type, + input_scale=in_q[0], + input_zp=in_q[1], + output_scale=out_q[0], + output_zp=out_q[1], + )) + logger.debug(f" Activation {op_type}: in_scale={in_q[0]:.6f} in_zp={in_q[1]} " + f"out_scale={out_q[0]:.6f} out_zp={out_q[1]}") + + return info + + +@final +class AxonBackend(BackendDetails): + """ExecuTorch backend for Nordic AXON NPU. + + Follows the same composition pattern as EthosUBackend: + reuses TOSABackend for TOSA lowering, then compiles + TOSA -> AXON command buffers. 
+ """ + + @staticmethod + def preprocess( + edge_program, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """Compile a delegated subgraph for AXON NPU execution. + + Args: + edge_program: ExportedProgram from ExecuTorch edge lowering. + compile_specs: List of CompileSpec. We look for: + - tosa_spec: TOSA version/profile (default: TOSA-1.0+INT) + - sdk_edge_ai_path: Path to Nordic sdk-edge-ai repo + - model_name: Name for the compiled model + - axon_generated_dir: Where to write generated headers + + Returns: + PreprocessResult with AXON marker as processed_bytes. + """ + # Parse compile specs + sdk_path = os.environ.get("SDK_EDGE_AI_PATH", "") + model_name = "axon_model" + tosa_spec_str = "TOSA-1.0+INT" + generated_dir_override: str | None = None + + for spec in compile_specs: + if spec.key == "sdk_edge_ai_path": + sdk_path = spec.value.decode() if isinstance(spec.value, bytes) else spec.value + elif spec.key == "model_name": + model_name = spec.value.decode() if isinstance(spec.value, bytes) else spec.value + elif spec.key == "tosa_spec": + tosa_spec_str = spec.value.decode() if isinstance(spec.value, bytes) else spec.value + elif spec.key == "axon_generated_dir": + generated_dir_override = ( + spec.value.decode() if isinstance(spec.value, bytes) else spec.value + ) + + logger.info("AxonBackend.preprocess: model=%s, tosa_spec=%s", + model_name, tosa_spec_str) + + try: + return AxonBackend._do_preprocess( + edge_program, model_name, tosa_spec_str, + sdk_path, generated_dir_override, + ) + except Exception: + logger.exception("AxonBackend.preprocess failed for model=%s", model_name) + # Return marker-only so the .pte is still valid (firmware will + # report "subgraph not found" instead of crashing at load time). 
+ return PreprocessResult( + processed_bytes=make_marker( + derive_subgraph_name(model_name, b"error") + ) + ) + + @staticmethod + def _do_preprocess( + edge_program, + model_name: str, + tosa_spec_str: str, + sdk_path: str, + generated_dir_override: str | None, + ) -> PreprocessResult: + """Internal preprocess implementation. Separated so the public + ``preprocess()`` can catch exceptions and return a safe fallback.""" + # Extract activation quantization info BEFORE TOSA lowering. + activation_info = _extract_activation_quant_info(edge_program) + if activation_info: + logger.info("Found %d activation op(s) for AXON op extensions", + len(activation_info)) + + # 1. Reuse TOSA lowering (shared with Ethos-U, VGF, etc.) + tosa_spec = TosaSpecification.create_from_string(tosa_spec_str) + tosa_compile_spec = TosaCompileSpec(tosa_spec) + tosa_result = TOSABackend._preprocess(edge_program, tosa_compile_spec) + tosa_flatbuffer = tosa_result.processed_bytes + logger.info("TOSA lowering produced %d bytes", len(tosa_flatbuffer)) + + # Save TOSA flatbuffer for debugging (unique per model_name) + tosa_debug_path = os.path.join( + tempfile.gettempdir(), + f"axon_tosa_debug_{model_name}.tosa", + ) + with open(tosa_debug_path, "wb") as f: + f.write(tosa_flatbuffer) + + # 2. Parse TOSA -> AXON layers + graph = parse_tosa_flatbuffer(tosa_flatbuffer) + layers = tosa_to_axon_layers(graph, activation_info=activation_info) + logger.info("Converted to %d AXON layers", len(layers)) + + # 3. Pack intermediate binary, derive stable subgraph name + builder = AxonBinaryBuilder() + intermediate_binary = builder.build(layers, model_name=model_name) + logger.info("Intermediate binary: %d bytes", len(intermediate_binary)) + + subgraph_name = derive_subgraph_name(model_name, intermediate_binary) + logger.info("Subgraph unique name: %s", subgraph_name) + + # 4. 
Call Nordic compiler lib + import platform + system = platform.system() + lib_names = { + "Linux": "libnrf-axon-nn-compiler-lib-amd64.so", + "Darwin": "libnrf-axon-nn-compiler-lib-arm64.dylib", + "Windows": "nrf-axon-nn-compiler-lib-amd64.dll", + } + lib_name = lib_names.get(system, lib_names["Linux"]) + compiler_lib_path = os.path.join( + sdk_path, "tools", "axon", "compiler", "bin", system, lib_name + ) + + if not os.path.exists(compiler_lib_path): + logger.warning("AXON compiler lib not found: %s", compiler_lib_path) + logger.warning( + "Returning marker only — firmware will not find a model. " + "Set SDK_EDGE_AI_PATH to your Nordic sdk-edge-ai directory." + ) + return PreprocessResult(processed_bytes=make_marker(subgraph_name)) + + # Re-pack with the unique name baked into the binary + intermediate_binary = builder.build(layers, model_name=subgraph_name) + + # Write intermediate binary to temp file, compile, read result + compile_dir = tempfile.mkdtemp(prefix="axon_compile_") + bin_path = os.path.join(compile_dir, f"{subgraph_name}.bin") + with open(bin_path, "wb") as f: + f.write(intermediate_binary) + + output_prefix = os.path.join(compile_dir, f"nrf_axon_model_{subgraph_name}") + + from executorch.backends.nordic.axon_compiler import _call_compiler_lib + result = _call_compiler_lib(compiler_lib_path, bin_path, output_prefix) + + if result != 0: + logger.error("AXON compilation failed with code %d", result) + return PreprocessResult(processed_bytes=make_marker(subgraph_name)) + + # Read the compiled header file + header_path = f"{output_prefix}_.h" + if not os.path.exists(header_path): + logger.error("Compiled header not found: %s", header_path) + return PreprocessResult(processed_bytes=make_marker(subgraph_name)) + + with open(header_path, "rb") as f: + compiled_header = f.read().decode("utf-8", errors="replace") + logger.info("AXON compilation successful: %d bytes header", + len(compiled_header)) + + # Rewrite op extension symbols to generic names + 
compiled_header = rewrite_op_extension_symbols(compiled_header) + + # Rewrite header symbols for unique naming + compiled_header = rewrite_header_symbols( + compiled_header, subgraph_name, subgraph_name, + ) + + # 5. Write per-subgraph .h and regenerate master table + if generated_dir_override: + from pathlib import Path + generated_dir = Path(generated_dir_override) + write_subgraph_header(generated_dir, subgraph_name, compiled_header) + regenerate_table(generated_dir) + else: + logger.info( + "No axon_generated_dir specified — skipping header generation. " + "Set axon_generated_dir in AxonCompileSpec to write firmware headers." + ) + + # 6. Return the marker as processed_bytes + return PreprocessResult(processed_bytes=make_marker(subgraph_name)) diff --git a/backends/nordic/axon/codegen.py b/backends/nordic/axon/codegen.py new file mode 100644 index 00000000000..283fc488b91 --- /dev/null +++ b/backends/nordic/axon/codegen.py @@ -0,0 +1,265 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Code generation for AXON delegate integration. + +Handles subgraph naming, marker format, header symbol rewriting, and +generated table management. These are the generic parts needed by any +firmware project using the AXON ExecuTorch delegate. + +Every call to ``AxonBackend.preprocess`` produces one AXON-compiled +subgraph. Each subgraph gets a **stable unique name** derived from a +hash of the intermediate binary, so the per-subgraph C symbols never +collide and rebuilding the same model gives the same names. 
+ +The generated directory layout:: + + / + axon_subgraph_.h -- one per delegated subgraph + axon_subgraphs_table.h -- regenerated on every preprocess() call + .gitignore -- so this dir doesn't get committed +""" +from __future__ import annotations + +import hashlib +import logging +import re +from pathlib import Path + +logger = logging.getLogger(__name__) + + +# ── Marker format ───────────────────────────────────────────────── +# +# The .pte's processed_bytes for an AXON delegate handle: +# +# offset size field +# ------ ---- ----- +# 0 4 magic "AXNG" (Axon NN, Generated) +# 4 4 version little-endian uint32, currently 1 +# 8 4 name_len little-endian uint32, length of name in bytes +# 12 N name ASCII subgraph name, no NUL terminator +# 12+N pad pad to 4-byte alignment +# +# The total size is small (<256 bytes) and constant per subgraph. + +_AXON_MARKER_MAGIC = b"AXNG" +_AXON_MARKER_VERSION = 1 + + +def make_marker(subgraph_name: str) -> bytes: + """Build the binary marker that goes into processed_bytes.""" + name_bytes = subgraph_name.encode("ascii") + if len(name_bytes) > 255: + raise ValueError( + f"subgraph name too long ({len(name_bytes)} > 255 bytes): " + f"{subgraph_name!r}" + ) + payload = _AXON_MARKER_MAGIC + payload += _AXON_MARKER_VERSION.to_bytes(4, "little") + payload += len(name_bytes).to_bytes(4, "little") + payload += name_bytes + # Pad to 4-byte alignment so consumers can over-read safely. + if len(payload) % 4: + payload += b"\x00" * (4 - len(payload) % 4) + return payload + + +# ── Subgraph naming ─────────────────────────────────────────────── + +# 12 hex chars = 48 bits of SHA-256 hash. At 48 bits, the probability +# of collision among N subgraphs is ~N^2 / 2^49 (birthday bound). +# For 1000 subgraphs: ~1.8e-9. Safe for any practical firmware build. +_NAME_HASH_LEN = 12 + + +def derive_subgraph_name(model_name_prefix: str, intermediate_binary: bytes) -> str: + """Stable unique subgraph name from the model name + binary content. 
+ + The result is a valid C identifier and starts with the prefix so it's + grep-friendly in the firmware build output. + """ + digest = hashlib.sha256(intermediate_binary).hexdigest()[:_NAME_HASH_LEN] + safe_prefix = re.sub(r"[^a-zA-Z0-9_]", "_", model_name_prefix) + return f"{safe_prefix}_{digest}" + + +# ── Header rewriting ────────────────────────────────────────────── + +_RENAME_TOKEN_RE = re.compile(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\b") + + +def rewrite_header_symbols(header_text: str, old_name: str, new_name: str) -> str: + """Rename all per-model C symbols in a Nordic-generated .h. + + The Nordic compiler embeds the model name into many C identifiers: + + - ``axon_model_const_`` + - ``cmd_buffer_`` + - ``model_`` + - ``axon_model__packed_output_buf`` + - ``NRF_AXON_MODEL__PACKED_OUTPUT_SIZE`` (uppercase) + - ``NRF_AXON_MODEL__MAX_IL_BUFFER_USED`` + - ``NRF_AXON_MODEL__MAX_PSUM_BUFFER_USED`` + - ``.model_name = ""`` (a string literal) + """ + if old_name == new_name: + return header_text + + old_upper = old_name.upper() + new_upper = new_name.upper() + + def replace_token(match: re.Match) -> str: + tok = match.group(1) + for old_marker, new_marker in ( + (f"axon_model_const_{old_name}", f"axon_model_const_{new_name}"), + (f"axon_model_{old_name}_packed_output_buf", + f"axon_model_{new_name}_packed_output_buf"), + (f"NRF_AXON_MODEL_{old_upper}_PACKED_OUTPUT_SIZE", + f"NRF_AXON_MODEL_{new_upper}_PACKED_OUTPUT_SIZE"), + (f"NRF_AXON_MODEL_{old_upper}_MAX_IL_BUFFER_USED", + f"NRF_AXON_MODEL_{new_upper}_MAX_IL_BUFFER_USED"), + (f"NRF_AXON_MODEL_{old_upper}_MAX_PSUM_BUFFER_USED", + f"NRF_AXON_MODEL_{new_upper}_MAX_PSUM_BUFFER_USED"), + (f"cmd_buffer_{old_name}", f"cmd_buffer_{new_name}"), + (f"model_{old_name}", f"model_{new_name}"), + ): + if tok == old_marker: + return new_marker + return tok + + rewritten = _RENAME_TOKEN_RE.sub(replace_token, header_text) + + # Also rewrite the .model_name string literal. 
+ rewritten = re.sub( + r'\.model_name\s*=\s*"' + re.escape(old_name) + r'"', + f'.model_name = "{new_name}"', + rewritten, + ) + return rewritten + + +def rewrite_op_extension_symbols(header_text: str) -> str: + """Rewrite Nordic op extension symbols to generic names. + + Nordic's compiler generates ``nrf_axon_nn_op_extension_sigmoid`` etc. + in the compiled headers. We rewrite these to ``axon_op_extension_*`` + so the firmware provides a consistent, non-vendor-prefixed interface. + """ + replacements = { + "nrf_axon_nn_op_extension_sigmoid": "axon_op_extension_sigmoid", + "nrf_axon_nn_op_extension_tanh": "axon_op_extension_tanh", + "nrf_axon_nn_op_extension_softmax": "axon_op_extension_softmax", + } + for old, new in replacements.items(): + header_text = header_text.replace(old, new) + return header_text + + +# ── Generated directory layout ──────────────────────────────────── + +_TABLE_FILENAME = "axon_subgraphs_table.h" +_SUBGRAPH_PREFIX = "axon_subgraph_" + + +def write_subgraph_header( + generated_dir: Path, + subgraph_name: str, + header_text: str, +) -> Path: + """Write a single subgraph .h into the generated directory. + + Returns the path written. Creates ``generated_dir`` if missing. + """ + generated_dir.mkdir(parents=True, exist_ok=True) + out_path = generated_dir / f"{_SUBGRAPH_PREFIX}{subgraph_name}.h" + out_path.write_text(header_text) + logger.info(f"AXON subgraph header -> {out_path}") + return out_path + + +def regenerate_table(generated_dir: Path) -> Path: + """Regenerate ``axon_subgraphs_table.h`` from the current contents. + + Scans ``generated_dir`` for ``axon_subgraph_*.h`` files and emits a + deterministic master table. Idempotent: re-running with the same + set of subgraph headers produces the same output bytes. 
+ """ + generated_dir.mkdir(parents=True, exist_ok=True) + + subgraph_paths = sorted( + p for p in generated_dir.iterdir() + if p.is_file() and p.name.startswith(_SUBGRAPH_PREFIX) and p.name.endswith(".h") + ) + names = [p.name[len(_SUBGRAPH_PREFIX) : -2] for p in subgraph_paths] + + lines: list[str] = [ + "/* Auto-generated AXON subgraphs table — do NOT edit by hand.", + " * Regenerated by the Nordic AXON backend on every AxonBackend.preprocess() call.", + " * Owns the lookup from delegate marker name -> compiled model struct.", + " */", + "#pragma once", + "#include ", + "", + '#include "axon/nrf_axon_platform.h"', + '#include "drivers/axon/nrf_axon_nn_infer.h"', + "", + "/* Each subgraph header allocates its own packed output buffer */", + "#define NRF_AXON_MODEL_ALLOCATE_PACKED_OUTPUT_BUFFER 1", + "", + ] + for path in subgraph_paths: + lines.append(f'#include "{path.name}"') + lines.append("") + lines.append("typedef struct {") + lines.append(" const char *name;") + lines.append(" const nrf_axon_nn_compiled_model_s *model;") + lines.append("} axon_subgraph_entry_t;") + lines.append("") + lines.append(f"#define AXON_SUBGRAPHS_COUNT {len(names)}") + lines.append("") + if names: + lines.append("static const axon_subgraph_entry_t axon_subgraphs[] = {") + for name in names: + lines.append(f' {{"{name}", &model_{name}}},') + lines.append("};") + else: + lines.append("/* No subgraphs registered yet. */") + lines.append("static const axon_subgraph_entry_t axon_subgraphs[1] = {{0}};") + lines.append("") + + table_path = generated_dir / _TABLE_FILENAME + table_path.write_text("\n".join(lines)) + logger.info( + f"AXON subgraphs table -> {table_path} ({len(names)} subgraph(s))" + ) + + # Drop a tiny .gitignore so the directory is self-cleaning if it + # ever gets committed by accident. 
+ gitignore = generated_dir / ".gitignore" + if not gitignore.exists(): + gitignore.write_text("# Auto-generated by AXON backend; do not commit.\n*\n!.gitignore\n") + return table_path + + +def clean_generated_dir(generated_dir: Path) -> int: + """Remove every ``axon_subgraph_*.h`` and the master table from a dir. + + Returns the number of files removed. Useful when re-exporting a + different model and you want to drop the previous model's + subgraphs from the firmware build. + """ + if not generated_dir.exists(): + return 0 + removed = 0 + for path in generated_dir.iterdir(): + if path.is_file() and ( + path.name.startswith(_SUBGRAPH_PREFIX) or path.name == _TABLE_FILENAME + ): + path.unlink() + removed += 1 + if removed: + logger.info(f"Removed {removed} stale AXON file(s) from {generated_dir}") + return removed diff --git a/backends/nordic/axon/compile_spec.py b/backends/nordic/axon/compile_spec.py new file mode 100644 index 00000000000..9868cc456e3 --- /dev/null +++ b/backends/nordic/axon/compile_spec.py @@ -0,0 +1,66 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""AXON NPU compile specification for ExecuTorch.""" + +from __future__ import annotations + +from executorch.exir.backend.compile_spec_schema import CompileSpec + +# AXON hardware constraints (from nrf_axon_nn_compiler_types.h and Nordic docs) +AXON_MAX_FC_INPUT = 2048 +AXON_MAX_FC_OUTPUT = 2048 +AXON_MAX_CONV2D_FILTER = 16 # Max filter height/width for Conv2D +AXON_MAX_CONV_STRIDE = 31 +AXON_MAX_POOL_FILTER = 32 # Max filter height/width for pooling +AXON_MAX_TENSOR_DIM = 1024 # Max height/width/channels +AXON_MAX_INPUTS_PER_NODE = 2 + + +class AxonCompileSpec: + """Configuration for compiling models targeting the AXON NPU. + + Args: + sdk_edge_ai_path: Path to Nordic sdk-edge-ai repo. Can also be + set via the ``SDK_EDGE_AI_PATH`` environment variable. 
Required + for compilation to AXON command buffers; not needed for TOSA + lowering only. + model_name: Human-readable prefix for delegated subgraphs. + The actual C symbols in the generated headers append a + content-derived hash suffix so multiple subgraphs in the + same firmware build never collide. + tosa_spec: TOSA version string (default: "TOSA-1.0+INT"). + axon_generated_dir: Where ``preprocess()`` writes the per-subgraph + ``axon_subgraph_*.h`` files and the master + ``axon_subgraphs_table.h``. Required when writing generated + headers for firmware integration. + """ + + def __init__( + self, + sdk_edge_ai_path: str | None = None, + model_name: str = "axon_model", + tosa_spec: str = "TOSA-1.0+INT", + axon_generated_dir: str | None = None, + ): + self.sdk_edge_ai_path = sdk_edge_ai_path + self.model_name = model_name + self.tosa_spec = tosa_spec + self.axon_generated_dir = axon_generated_dir + + def to_compile_specs(self) -> list[CompileSpec]: + """Convert to ExecuTorch CompileSpec list.""" + specs = [ + CompileSpec("tosa_spec", self.tosa_spec.encode()), + CompileSpec("output_format", b"tosa"), + CompileSpec("model_name", self.model_name.encode()), + ] + if self.sdk_edge_ai_path: + specs.append(CompileSpec("sdk_edge_ai_path", self.sdk_edge_ai_path.encode())) + if self.axon_generated_dir: + specs.append( + CompileSpec("axon_generated_dir", self.axon_generated_dir.encode()) + ) + return specs diff --git a/backends/nordic/axon/partitioner.py b/backends/nordic/axon/partitioner.py new file mode 100644 index 00000000000..60885921130 --- /dev/null +++ b/backends/nordic/axon/partitioner.py @@ -0,0 +1,66 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""AXON NPU partitioner for ExecuTorch. + +Delegates supported operations to the AXON NPU backend. 
+Reuses TOSAPartitioner for TOSA-compatible operation checking, +with additional AXON hardware constraint checks that reject nodes +exceeding tensor size limits, input count, or filter dimensions. +""" + +from __future__ import annotations + +from typing import final + +from executorch.backends.arm.tosa.partitioner import TOSAPartitioner +from executorch.exir.backend.partitioner import DelegationSpec + +from .backend import AxonBackend +from .compile_spec import AxonCompileSpec + + +@final +class AxonPartitioner(TOSAPartitioner): + """Partitioner that delegates supported operations to AXON NPU. + + Inherits from TOSAPartitioner to reuse TOSA-based operator support + checks. Adds AXON-specific hardware constraint checks that reject + nodes exceeding: + + - Max tensor dimensions (1024 height/width/channels) + - Max input count per node (2) + - Max FC input/output (2048) + - Max Conv2D filter size (16x16) and stride (31) + + Nodes that pass TOSA checks but fail AXON constraints fall back + to CPU execution via ExecuTorch's portable kernels. + """ + + def __init__( + self, + compile_spec: AxonCompileSpec, + additional_checks=None, + ): + self.compile_spec = compile_spec + self.delegation_spec = DelegationSpec( + AxonBackend.__name__, + compile_spec.to_compile_specs(), + ) + + # AXON hardware constraint checks can be added via additional_checks. + # By default, TOSA-level checks handle partitioning. 
For stricter + # enforcement of AXON-specific limits (tensor dim 1024, FC 2048, + # Conv filter 16, max 2 inputs), pass get_axon_constraint_checks(): + # + # from executorch.backends.nordic.operator_support.axon_constraints import ( + # get_axon_constraint_checks, + # ) + # partitioner = AxonPartitioner(spec, additional_checks=get_axon_constraint_checks()) + self.additional_checks = additional_checks or [] + + # Use TOSA INT profile for operator support checking + from executorch.backends.arm.tosa.specification import TosaSpecification + self.tosa_spec = TosaSpecification.create_from_string(compile_spec.tosa_spec) diff --git a/backends/nordic/axon/quantizer.py b/backends/nordic/axon/quantizer.py new file mode 100644 index 00000000000..dd49e845f45 --- /dev/null +++ b/backends/nordic/axon/quantizer.py @@ -0,0 +1,65 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""AXON NPU quantizer for ExecuTorch. + +Provides a quantizer configured for the AXON NPU's INT8 requirements. +Wraps the ARM backend's quantizer infrastructure with AXON-specific +defaults: + +- Symmetric INT8 quantization (AXON requirement) +- Per-channel weights for FC/Conv (better accuracy) +- TOSA-1.0+INT profile (matches the AXON compilation pipeline) + +Usage:: + + from executorch.backends.nordic.axon import AxonQuantizer + + quantizer = AxonQuantizer() + prepared = prepare_pt2e(exported.module(), quantizer) + prepared(*calibration_data) + quantized = convert_pt2e(prepared) +""" +from __future__ import annotations + +from executorch.backends.arm.quantizer import ( + TOSAQuantizer, + get_symmetric_quantization_config, +) +from executorch.backends.arm.tosa.specification import TosaSpecification + + +class AxonQuantizer(TOSAQuantizer): + """Quantizer configured for the AXON NPU. 
+ + Defaults to symmetric INT8 quantization with per-channel weights, + targeting the TOSA-1.0+INT profile that the AXON compilation + pipeline requires. + + Args: + per_channel: Use per-channel quantization for weights (default True). + Per-channel gives better accuracy; per-tensor gives smaller + command buffers (single shared shift instruction). + quantize_io: Also quantize the model's input and output tensors + (default False). When False, the model accepts fp32 input + and produces fp32 output, with q/dq ops at the AXON + delegation boundaries. + """ + + def __init__( + self, + per_channel: bool = True, + quantize_io: bool = False, + ): + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + super().__init__(tosa_spec) + + config = get_symmetric_quantization_config( + is_per_channel=per_channel, + ) + self.set_global(config) + + if quantize_io: + self.set_io(config) diff --git a/backends/nordic/axon_binary.py b/backends/nordic/axon_binary.py new file mode 100644 index 00000000000..c8c27edd48f --- /dev/null +++ b/backends/nordic/axon_binary.py @@ -0,0 +1,348 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""AXON intermediate binary file builder. + +Creates the binary file that Nordic's AXON compiler lib reads as input. +Uses cffi to create C structs from nrf_axon_nn_compiler_types.h, +guaranteeing exact binary compatibility. 
+ +Binary format:: + + [header struct: nrf_axon_nn_model_desc_hdr_s] (offsets to sections) + [model name string] + [meta info: nrf_axon_nn_model_meta_info_s] + [layer descs: nrf_axon_nn_model_layer_desc_s[]] + [constants: weights, biases, multipliers, shifts] + [compilation options: nrf_axon_nn_model_compilation_options_s] + [title string: "AXON_INTERMEDIATE_REPRESENTATION_FILE"] + [version: uint32] +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +from cffi import FFI + +from .axon_types import AxonLayer + +logger = logging.getLogger(__name__) + +# Binary format constants +BINARY_TITLE = "AXON_INTERMEDIATE_REPRESENTATION_FILE" +VERSION_MAJOR = 0 +VERSION_MINOR = 17 +VERSION_PATCH = 0 +MODEL_BIN_VERSION = (VERSION_MAJOR << 16) + (VERSION_MINOR << 8) + VERSION_PATCH + + +def _create_ffi() -> FFI: + """Create cffi FFI with AXON compiler structs defined manually. + + We define the structs explicitly rather than parsing the header file, + because cffi can't handle all the preprocessor macros and enum patterns + in nrf_axon_nn_compiler_types.h. 
+ """ + ffi = FFI() + ffi.cdef(""" + typedef struct { + uint16_t height; + uint16_t width; + uint16_t channel_cnt; + int32_t byte_width; /* nrf_axon_nn_byte_width_e: 1=INT8, 2=INT16, 4=INT32 */ + } nrf_axon_nn_compiler_model_layer_dimensions_s; + + typedef struct { + uint32_t begin[3]; + uint32_t end[3]; + uint32_t strides[3]; + } nrf_axon_nn_compiler_strided_slice_parameters_s; + + typedef struct { + uint8_t input_id_cnt; + int16_t input_ids[4]; + int32_t nn_operation; /* nrf_axon_nn_op_e */ + nrf_axon_nn_compiler_model_layer_dimensions_s input_dimensions[4]; + nrf_axon_nn_compiler_model_layer_dimensions_s filter_dimensions; + nrf_axon_nn_compiler_model_layer_dimensions_s output_dimensions; + uint8_t concatenate_axis; + uint8_t stride_x; + uint8_t stride_y; + uint8_t dilation_x; + uint8_t dilation_y; + int8_t input_zero_point; + int8_t output_zero_point; + uint64_t bias_prime; /* offset into consts */ + uint64_t output_multipliers; /* offset into consts */ + uint64_t scale_shifts; /* offset into consts */ + uint16_t scale_shift_cnt; + int32_t activation_function; /* nrf_axon_nn_activation_function_e */ + uint8_t pad_left; + uint8_t pad_right; + uint8_t pad_top; + uint8_t pad_bottom; + uint64_t filter; /* offset into consts */ + uint32_t cpu_op_additional_attributes_count; + uint64_t cpu_op_additional_attributes; /* offset into consts */ + } nrf_axon_nn_model_layer_desc_s; + + typedef struct { + uint32_t offset; + uint32_t length; + } nrf_axon_nn_model_bin_item_s; + + typedef struct { + uint32_t mult; + uint8_t round; + int8_t zero_point; + } nrf_axon_nn_model_quant_paramters_s; + + typedef struct { + nrf_axon_nn_model_bin_item_s model_name; + nrf_axon_nn_model_bin_item_s model_labels; + uint32_t model_layer_cnt; + nrf_axon_nn_model_quant_paramters_s input_quant; + nrf_axon_nn_model_quant_paramters_s output_dequant; + } nrf_axon_nn_model_meta_info_s; + + typedef struct { + nrf_axon_nn_model_bin_item_s title; + nrf_axon_nn_model_bin_item_s version; + 
nrf_axon_nn_model_bin_item_s meta; + nrf_axon_nn_model_bin_item_s layers; + nrf_axon_nn_model_bin_item_s consts; + nrf_axon_nn_model_bin_item_s compilation_option; + } nrf_axon_nn_model_desc_hdr_s; + + typedef struct { + uint32_t interlayer_buffer_size; + uint32_t psum_buffer_size; + uint32_t header_file_test_vector_cnt; + int32_t convolution_2d_setting; + int32_t log_level; + int32_t psum_buffer_placement; + } nrf_axon_nn_model_compilation_options_s; + """) + return ffi + + +class AxonBinaryBuilder: + """Builds the AXON intermediate binary file using cffi structs.""" + + def __init__(self, compiler_types_hdr_path: str | None = None): + self.ffi = _create_ffi() + self._data = bytearray() + self._header_size = self.ffi.sizeof("nrf_axon_nn_model_desc_hdr_s") + + def build( + self, + layers: list[AxonLayer], + model_name: str = "model", + interlayer_buffer_size: int = 125000, + psum_buffer_size: int = 4096, + input_quant_mult: int = 1, + input_quant_round: int = 0, + input_quant_zp: int = 0, + output_quant_mult: int = 1, + output_quant_round: int = 0, + output_quant_zp: int = 0, + ) -> bytes: + """Build the complete intermediate binary. + + Returns: + bytes: The binary file contents ready for the compiler lib. + """ + self._data = bytearray() + header = self.ffi.new("nrf_axon_nn_model_desc_hdr_s *") + + # 1. Model name string + model_name_bytes = model_name.lower().encode("utf-8") + b"\x00" + model_name_item = self._append_data(model_name_bytes) + + # 2. 
Meta info + meta = self.ffi.new("nrf_axon_nn_model_meta_info_s *") + meta.model_name.offset = model_name_item[0] + meta.model_name.length = model_name_item[1] + meta.model_labels.offset = 0 + meta.model_labels.length = 0 + meta.model_layer_cnt = len(layers) + meta.input_quant.mult = input_quant_mult + meta.input_quant.round = input_quant_round + meta.input_quant.zero_point = input_quant_zp + meta.output_dequant.mult = output_quant_mult + meta.output_dequant.round = output_quant_round + meta.output_dequant.zero_point = output_quant_zp + meta_item = self._append_struct(meta, "nrf_axon_nn_model_meta_info_s") + + # 3. Layer descriptors + constants + consts_data = bytearray() + layer_structs_data = bytearray() + + for i, layer in enumerate(layers): + layer_struct = self.ffi.new("nrf_axon_nn_model_layer_desc_s *") + + # Input IDs + layer_struct.input_id_cnt = len(layer.input_ids) + for j in range(min(len(layer.input_ids), 4)): + layer_struct.input_ids[j] = layer.input_ids[j] + + # Operation + layer_struct.nn_operation = layer.operation + + # Input dimensions + for j in range(min(len(layer.input_dimensions), 4)): + d = layer.input_dimensions[j] + layer_struct.input_dimensions[j].height = d.height + layer_struct.input_dimensions[j].width = d.width + layer_struct.input_dimensions[j].channel_cnt = d.channel_cnt + layer_struct.input_dimensions[j].byte_width = d.byte_width + + # Filter dimensions + layer_struct.filter_dimensions.height = layer.filter_dimensions.height + layer_struct.filter_dimensions.width = layer.filter_dimensions.width + layer_struct.filter_dimensions.channel_cnt = layer.filter_dimensions.channel_cnt + layer_struct.filter_dimensions.byte_width = layer.filter_dimensions.byte_width + + # Output dimensions + layer_struct.output_dimensions.height = layer.output_dimensions.height + layer_struct.output_dimensions.width = layer.output_dimensions.width + layer_struct.output_dimensions.channel_cnt = layer.output_dimensions.channel_cnt + 
layer_struct.output_dimensions.byte_width = layer.output_dimensions.byte_width + + # Stride, dilation + layer_struct.stride_x = layer.stride_x + layer_struct.stride_y = layer.stride_y + layer_struct.dilation_x = layer.dilation_x + layer_struct.dilation_y = layer.dilation_y + + # Zero points + layer_struct.input_zero_point = layer.input_zero_point + layer_struct.output_zero_point = layer.output_zero_point + logger.debug(f" Binary layer {i}: in_zp={layer.input_zero_point} out_zp={layer.output_zero_point}") + + # Activation + layer_struct.activation_function = layer.activation + + # Padding + layer_struct.pad_left = layer.pad_left + layer_struct.pad_right = layer.pad_right + layer_struct.pad_top = layer.pad_top + layer_struct.pad_bottom = layer.pad_bottom + + # Constants — store as offsets into the consts section. + # Unused offsets must be 0xFFFFFFFFFFFFFFFF (sentinel), + # not 0 (which points to filter data and corrupts compilation). + if layer.filter_data: + layer_struct.filter = len(consts_data) + consts_data.extend(layer.filter_data) + self._pad_to_4(consts_data) + else: + layer_struct.filter = 0xFFFFFFFFFFFFFFFF + + if layer.bias_data: + layer_struct.bias_prime = len(consts_data) + consts_data.extend(layer.bias_data) + self._pad_to_4(consts_data) + else: + layer_struct.bias_prime = 0xFFFFFFFFFFFFFFFF + + from .axon_types import AxonOp + + # Determine number of output channels + n_out_ch = layer.output_dimensions.channel_cnt + if n_out_ch <= 1: + n_out_ch = max(n_out_ch, layer.output_dimensions.width) + n_out_ch = max(n_out_ch, layer.filter_dimensions.height) + + needs_per_ch_mult = layer.operation in ( + AxonOp.FULLY_CONNECTED, AxonOp.CONV2D, + AxonOp.DEPTHWISE_CONV2D, AxonOp.POINTWISE_CONV2D, + ) + + if layer.multiplier_data: + layer_struct.output_multipliers = len(consts_data) + mult_arr = np.frombuffer(layer.multiplier_data, dtype=np.int32) + if needs_per_ch_mult and len(mult_arr) < n_out_ch: + mult_arr = np.tile(mult_arr, (n_out_ch + len(mult_arr) - 1) // 
len(mult_arr))[:n_out_ch] + consts_data.extend(mult_arr.tobytes()) + self._pad_to_4(consts_data) + + if layer.shift_data: + layer_struct.scale_shifts = len(consts_data) + consts_data.extend(layer.shift_data) + self._pad_to_4(consts_data) + + layer_struct.scale_shift_cnt = layer.scale_shift_cnt + + layer_struct.cpu_op_additional_attributes_count = 0 + layer_struct.cpu_op_additional_attributes = 0xFFFFFFFFFFFFFFFF + + # Serialize layer struct + layer_bin = bytes(self.ffi.buffer(layer_struct)) + layer_structs_data.extend(layer_bin) + + # Append layers section + layers_item = self._append_data(bytes(layer_structs_data)) + + # Append constants section + consts_item = self._append_data(bytes(consts_data)) + + # 4. Compilation options + options = self.ffi.new("nrf_axon_nn_model_compilation_options_s *") + options.interlayer_buffer_size = interlayer_buffer_size + options.psum_buffer_size = psum_buffer_size + options.header_file_test_vector_cnt = 0 + options.convolution_2d_setting = 0 + options.log_level = 0 + options.psum_buffer_placement = 0 # INTERLAYER_BUFFER (Nordic's default) + options_item = self._append_struct(options, "nrf_axon_nn_model_compilation_options_s") + + # 5. Title string + title_bytes = BINARY_TITLE.encode("utf-8") + b"\x00" + title_item = self._append_data(title_bytes) + + # 6. 
Version (uint32) + version_bytes = np.array([MODEL_BIN_VERSION], dtype=np.uint32).tobytes() + version_item = self._append_data(version_bytes) + + # Fill header with offsets + header.title.offset = title_item[0] + header.title.length = title_item[1] + header.version.offset = version_item[0] + header.version.length = version_item[1] + header.meta.offset = meta_item[0] + header.meta.length = meta_item[1] + header.layers.offset = layers_item[0] + header.layers.length = layers_item[1] + header.consts.offset = consts_item[0] + header.consts.length = consts_item[1] + header.compilation_option.offset = options_item[0] + header.compilation_option.length = options_item[1] + + # Assemble: header + data + header_bin = bytes(self.ffi.buffer(header)) + return header_bin + bytes(self._data) + + def _append_data(self, data: bytes) -> tuple[int, int]: + """Append data to the binary, return (offset, length).""" + offset = self._header_size + len(self._data) + length = len(data) + self._data.extend(data) + self._pad_to_4(self._data) + return (offset, length) + + def _append_struct(self, struct_ptr, struct_type: str) -> tuple[int, int]: + """Serialize a cffi struct and append to binary.""" + struct_bytes = bytes(self.ffi.buffer(struct_ptr)) + return self._append_data(struct_bytes) + + @staticmethod + def _pad_to_4(buf: bytearray): + """Pad bytearray to 4-byte alignment.""" + while len(buf) % 4 != 0: + buf.append(0) diff --git a/backends/nordic/axon_compiler.py b/backends/nordic/axon_compiler.py new file mode 100644 index 00000000000..43fee9be260 --- /dev/null +++ b/backends/nordic/axon_compiler.py @@ -0,0 +1,888 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""TOSA → AXON compiler bridge. 
+ +Reads a TOSA flatbuffer, converts operations to AXON layer descriptors, +packs them into the AXON intermediate binary format, and calls Nordic's +compiler library to produce command buffers. + +Supported AXON operators: + - FULLY_CONNECTED (CONV2D 1x1 on 1x1 spatial) + - CONV2D (standard 2D convolution) + - DEPTHWISE_CONV2D (depthwise separable convolution) + - POINTWISE_CONV2D (1x1 conv on spatial input) + - ADD2 (element-wise add with broadcast) + - MULTIPLY (element-wise multiply with broadcast) + - AVERAGE_POOLING (average pool 2D) + - MAX_POOLING (max pool 2D) + - MEAN (global average pooling / reduce) + +Architecture: + TOSA flatbuffer → parse (tosa_reader.py) + → fuse ops + RESCALE into AXON layers + → pack into AXON intermediate binary (nrf_axon_nn_model_desc_hdr_s format) + → call nrf_axon_compile_model() via ctypes + → produces C header with command buffers +""" + +from __future__ import annotations + +import ctypes +import logging +import os +import struct +import tempfile +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np + +from .axon.compile_spec import ( + AXON_MAX_CONV2D_FILTER, + AXON_MAX_CONV_STRIDE, + AXON_MAX_FC_INPUT, + AXON_MAX_FC_OUTPUT, + AXON_MAX_POOL_FILTER, + AXON_MAX_TENSOR_DIM, +) +from .axon_types import ( + ActivationQuantInfo, + AxonActivation, + AxonByteWidth, + AxonDimensions, + AxonLayer, + AxonOp, +) +from .tosa_reader import TosaGraph, TosaOperator, TosaTensor, parse_tosa_flatbuffer +from .axon_converters import ( + _convert_concat, + _convert_conv2d, + _convert_depthwise_conv2d, + _convert_elementwise, + _convert_pad, + _convert_persistent_var, + _convert_pool2d, + _convert_reduce_sum, + _convert_slice, + _create_rescale_layer, + _extract_rescale_params, + _handle_softmax_pattern, + _handle_table_op, + _optimized_scaling_shift, + _resolve_input_id, + _validate_axon_layer, +) + +logger = logging.getLogger(__name__) + + +def tosa_to_axon_layers( + graph: TosaGraph, + activation_info: 
list["ActivationQuantInfo"] | None = None, +) -> list[AxonLayer]: + """Convert a TOSA graph to AXON layer descriptors. + + Fuses CONV2D + RESCALE pairs into single AXON layers with + quantization parameters. Skips RESHAPE/CONST/CONST_SHAPE ops. + + The TOSA graph for a quantized Linear layer looks like: + CONV2D(input, weights[O,1,1,I], bias[O], input_zp, weight_zp) → INT16 + RESCALE(result, multiplier, shift, input_zp, output_zp) → INT8 + + This fuses into one AXON FULLY_CONNECTED layer with: + - weights from CONV2D input[1] + - bias from CONV2D input[2] + - zero points from CONV2D input[3,4] + - multiplier/shift from RESCALE input[1,2] + - output zero point from RESCALE input[4] + + Args: + activation_info: Quantization info for sigmoid/tanh/softmax ops, + extracted from the PyTorch FX graph before TOSA lowering. Used + to convert TOSA TABLE ops into AXON op extensions. The list + order matches the order of TABLE ops in the TOSA graph. + """ + ops = graph.get_non_const_operators() + layers = [] + axon_layer_idx = 0 + # Index into activation_info — incremented as we consume TABLE ops + activation_info = activation_info or [] + activation_idx = 0 + # Map TOSA output tensor name → AXON layer index + tensor_to_layer = {} + # Track spatial shapes before flatten (for CHW→HWC weight permutation) + # Maps tensor name → (H, W, C) shape before the flatten RESHAPE + tensor_pre_flatten_shape = {} + # Track output zero points through the graph (for ops like MAX_POOL that + # don't have explicit zp tensors but need to inherit from their input) + tensor_to_zp = {} # tensor_name → int zero_point + # Track skipped rescale info for ADD (scale ratios from standalone RESCALEs + # that were too large to implement as identity layers) + tensor_rescale_info = {} # tensor_name → (scale_ratio, input_zp) + + if logger.isEnabledFor(logging.DEBUG): + for idx, op in enumerate(ops): + ins = [f"{t.shape}" for t in op.input_tensors if not t.has_data] + outs = [f"{t.shape}" for t in 
op.output_tensors] + logger.debug(f" TOSA[{idx}] {op.op_name}: in={ins} -> out={outs}") + + i = 0 + while i < len(ops): + op = ops[i] + + if op.op_name in ("RESHAPE", "TRANSPOSE"): + # Transparent — just pass through tensor mapping + if op.input_tensors and op.output_tensors: + in_name = op.input_tensors[0].name + out_name = op.output_tensors[0].name + if in_name in tensor_to_layer: + tensor_to_layer[out_name] = tensor_to_layer[in_name] + + # Track flatten: RESHAPE from [N,H,W,C] to [N,1,1,H*W*C] or [N,H*W*C] + # The pre-flatten spatial shape is needed to permute FC weights + # from HWC (TOSA) to CHW (AXON) order. + if op.op_name == "RESHAPE": + in_shape = op.input_tensors[0].shape + out_shape = op.output_tensors[0].shape + in_numel = 1 + for s in in_shape: + in_numel *= s + out_numel = 1 + for s in out_shape: + out_numel *= s + # Detect flatten: same element count, input has spatial dims, output is flat + if (in_numel == out_numel and len(in_shape) >= 3 + and in_shape[-1] > 1 and in_shape[-2] > 1): + # Input is [N, H, W, C] — record spatial shape + if len(in_shape) == 4: + tensor_pre_flatten_shape[out_name] = (in_shape[1], in_shape[2], in_shape[3]) + elif len(in_shape) == 3: + tensor_pre_flatten_shape[out_name] = (in_shape[0], in_shape[1], in_shape[2]) + logger.debug(f" Flatten detected: {in_shape} → {out_shape}, " + f"pre-flatten HWC={tensor_pre_flatten_shape.get(out_name)}") + + # Propagate pre-flatten shape through pass-through ops + if in_name in tensor_pre_flatten_shape: + tensor_pre_flatten_shape[out_name] = tensor_pre_flatten_shape[in_name] + + # Propagate zero points through pass-through ops + if in_name in tensor_to_zp: + tensor_to_zp[out_name] = tensor_to_zp[in_name] + + i += 1 + continue + + if op.op_name == "CLAMP": + # Fuse CLAMP into the preceding AXON layer's activation field. + # CLAMP(min=0, max=127) = ReLU, CLAMP(min=-128, max=X) where X < 127 = custom clamp. 
+ if op.input_tensors and op.output_tensors: + in_name = op.input_tensors[0].name + out_name = op.output_tensors[0].name + attrs = op.attributes + min_int = attrs.get("min_int", -128) + max_int = attrs.get("max_int", 127) + + if in_name in tensor_to_layer: + prev_layer_idx = tensor_to_layer[in_name] + if min_int >= 0 and max_int > 0: + # ReLU or ReLU6: clamp to [0, X]. + # Nordic: "ReLU6 is mapped to ReLU because quantization + # causes saturation at 6" — the clip at 6 is already + # handled by INT8 quantization range clipping. + layers[prev_layer_idx].activation = AxonActivation.RELU + if max_int < 127: + logger.debug(f" Fused ReLU6/CLAMP(0,{max_int}) as ReLU into layer {prev_layer_idx}") + else: + logger.debug(f" Fused ReLU into layer {prev_layer_idx}") + elif min_int == -128 and max_int == 127: + # No-op clamp (full INT8 range) + pass + else: + # Negative min with non-full range — could be LeakyReLU territory + logger.warning(f" CLAMP({min_int},{max_int}) not fused — " + f"AXON only supports ReLU/LeakyReLU activation fusion") + tensor_to_layer[out_name] = tensor_to_layer[in_name] + i += 1 + continue + + if op.op_name == "RESCALE": + # Standalone RESCALE: chain into the preceding AXON layer by + # combining the rescale parameters. This avoids creating extra + # identity conv layers that add quantization error. + # Nordic's TFLite pipeline doesn't produce standalone RESCALEs. + if op.input_tensors and op.output_tensors: + in_name = op.input_tensors[0].name + out_name = op.output_tensors[0].name + if in_name in tensor_to_layer: + prev_idx = tensor_to_layer[in_name] + prev_layer = layers[prev_idx] + + # Get the standalone RESCALE's parameters + new_zp, new_mult_data, new_shift_data, new_cnt = _extract_rescale_params(op) + + can_chain = False + if prev_layer.multiplier_data and new_mult_data: + # Check if chaining is safe: combined scale must be < 1.0 + # to keep multiplier within AXON's INT32 range. 
+ # Standalone RESCALEs with large scale (>1) are for + # requantizing between different domains (e.g., for ADD) + # and produce multipliers near INT32 max when chained. + prev_mult = np.frombuffer(prev_layer.multiplier_data, dtype=np.int32) + prev_shift = np.frombuffer(prev_layer.shift_data, dtype=np.int8) + new_mult = np.frombuffer(new_mult_data, dtype=np.int32) + new_shift = np.frombuffer(new_shift_data, dtype=np.int8) + + # Check combined scale magnitude + ps0 = float(prev_mult[0]) / (2.0 ** int(prev_shift[0])) + ns0 = float(new_mult[0]) / (2.0 ** int(new_shift[0])) + combined_scale = ps0 * ns0 + + if combined_scale < 1.0: + # Safe to chain — combined scale fits in AXON range + combined_mult = np.zeros_like(prev_mult) + combined_shift = np.zeros_like(prev_shift) + for ch in range(len(prev_mult)): + ps = float(prev_mult[ch]) / (2.0 ** int(prev_shift[ch])) + ns_idx = min(ch, len(new_mult) - 1) + ns = float(new_mult[ns_idx]) / (2.0 ** int(new_shift[ns_idx])) + cs = ps * ns + m, s = _optimized_scaling_shift(cs, new_zp) + combined_mult[ch] = m + combined_shift[ch] = s + + prev_layer.multiplier_data = combined_mult.tobytes() + prev_layer.shift_data = combined_shift.tobytes() + prev_layer.output_zero_point = new_zp + can_chain = True + logger.debug(f" Chained RESCALE into layer {prev_idx}: " + f"new out_zp={new_zp}") + else: + logger.debug(f" Cannot chain RESCALE: combined_scale={combined_scale:.2f} > 1.0, " + f"creating identity conv instead") + + if not can_chain: + # Combined scale > 1.0 means the RESCALE is amplifying. + # Identity conv layers with large multipliers destroy + # precision (int8 clips all values). Skip and keep the + # original tensor — the ADD will use the rescale info + # to compute proper per-input multipliers and bias. + # + # Save the skipped rescale's scale ratio for ADD: + # The standalone RESCALE converts from prev domain to ADD domain. 
+ # scale_ratio = prev_scale / add_scale (the ns0 value) + rescale_in_zp = 0 + if len(op.input_tensors) > 3 and op.input_tensors[3].data is not None: + rescale_in_zp = int(op.input_tensors[3].data.flat[0]) + tensor_rescale_info[out_name] = (ns0, rescale_in_zp) + logger.debug(f" Skipping amplifying RESCALE (scale={ns0:.0f}, " + f"in_zp={rescale_in_zp}), keeping zp={prev_layer.output_zero_point}") + tensor_to_layer[out_name] = prev_idx + tensor_to_zp[out_name] = prev_layer.output_zero_point + i += 1 + continue + + tensor_to_layer[out_name] = prev_idx + # Track the output zero point for downstream ops (e.g., MAX_POOL) + tensor_to_zp[out_name] = new_zp + else: + # Input is a graph-level input — pass through + pass + i += 1 + continue + + # Check if next op is RESCALE (many ops fuse with it) + rescale_op = None + if i + 1 < len(ops) and ops[i + 1].op_name == "RESCALE": + rescale_op = ops[i + 1] + + layer = None + + if op.op_name == "CONV2D": + # Check if this FC's input was spatially flattened (needs weight permutation) + input_name = op.input_tensors[0].name + pfs = tensor_pre_flatten_shape.get(input_name) + layer = _convert_conv2d(op, rescale_op, tensor_to_layer, axon_layer_idx, + pre_flatten_shape=pfs) + + elif op.op_name == "DEPTHWISE_CONV2D": + # TOSA uses DEPTHWISE_CONV2D for Conv2d with 1 input channel. + # But AXON should use regular CONV2D for this (Nordic's convention). + # True depthwise: c_in>1, depth_mult=1 (groups=channels). + # Fake depthwise: c_in=1, depth_mult>1 (really a standard conv). 
+ weight_shape = op.input_tensors[1].shape # [KH, KW, C_in, M] + c_in = weight_shape[2] if len(weight_shape) > 2 else 1 + depth_mult = weight_shape[3] if len(weight_shape) > 3 else 1 + if c_in == 1 and depth_mult > 1: + # Standard conv disguised as depthwise — use CONV2D + layer = _convert_depthwise_conv2d(op, rescale_op, tensor_to_layer, axon_layer_idx, + as_conv2d=True) + else: + layer = _convert_depthwise_conv2d(op, rescale_op, tensor_to_layer, axon_layer_idx) + + elif op.op_name == "ADD": + layer = _convert_elementwise(op, rescale_op, tensor_to_layer, axon_layer_idx, AxonOp.ADD2, graph, tensor_to_zp, tensor_rescale_info) + + elif op.op_name == "MUL": + layer = _convert_elementwise(op, rescale_op, tensor_to_layer, axon_layer_idx, AxonOp.MULTIPLY, graph, tensor_to_zp, tensor_rescale_info) + + elif op.op_name == "AVG_POOL2D": + layer = _convert_pool2d(op, rescale_op, tensor_to_layer, axon_layer_idx, AxonOp.AVERAGE_POOLING) + + elif op.op_name == "MAX_POOL2D": + # MAX_POOL has no zp tensors in TOSA — propagate from preceding layer + pool_input_zp = tensor_to_zp.get(op.input_tensors[0].name, 0) + layer = _convert_pool2d(op, rescale_op, tensor_to_layer, axon_layer_idx, AxonOp.MAX_POOLING, + input_zp_from_graph=pool_input_zp) + + elif op.op_name == "REDUCE_SUM": + layer = _convert_reduce_sum(op, rescale_op, tensor_to_layer, axon_layer_idx) + + elif op.op_name == "CONCAT": + # CONCAT doesn't fuse with RESCALE — no rescale_op consumed + layer = _convert_concat(op, tensor_to_layer, axon_layer_idx) + if layer is not None: + layers.append(layer) + out_name = op.output_tensors[0].name + tensor_to_layer[out_name] = axon_layer_idx + axon_layer_idx += 1 + i += 1 + continue + + # Softmax decomposition: TOSA decomposes quantized softmax into + # REDUCE_MAX → SUB → TABLE(exp) → REDUCE_SUM → TABLE(reciprocal) → MUL. + # We detect this at the REDUCE_MAX op and replace the entire chain + # with a single AXON SOFTMAX op extension (operation 100). 
+ if (op.op_name == "REDUCE_MAX" + and activation_idx < len(activation_info) + and activation_info[activation_idx].op_type == "softmax"): + handled = _handle_softmax_pattern( + op, ops, i, layers, tensor_to_layer, tensor_to_zp, + activation_info, activation_idx, axon_layer_idx, + ) + if handled is not None: + axon_layer_idx, activation_idx, advance = handled + i += advance + continue + + # TABLE op: sigmoid/tanh activation lookup table. + # ExecuTorch's quantized sigmoid/tanh becomes a TOSA TABLE. + # We convert this into an AXON op extension (operation=101 sigmoid, + # 102 tanh) and modify the preceding layer to output INT16 q3.12. + # + # NOTE: TABLE ops also appear in softmax decompositions (one for exp, + # one for reciprocal). The softmax decomposition has a characteristic + # surrounding pattern (REDUCE_MAX/SUB before, REDUCE_SUM/MUL after); + # for now we treat any TABLE that has matching activation_info as a + # standalone activation. Softmax handling is added separately. + if op.op_name == "TABLE": + handled = _handle_table_op( + op, ops, i, layers, tensor_to_layer, tensor_to_zp, + activation_info, activation_idx, axon_layer_idx, + ) + if handled is not None: + axon_layer_idx, activation_idx, advance = handled + i += advance + continue + + # Ops that don't fuse with RESCALE — handle and continue + _no_rescale_layer = None + if op.op_name == "SLICE": + _no_rescale_layer = _convert_slice(op, tensor_to_layer, axon_layer_idx) + elif op.op_name == "PAD": + _no_rescale_layer = _convert_pad(op, tensor_to_layer, axon_layer_idx) + elif op.op_name in ("VARIABLE", "VARIABLE_READ", "VARIABLE_WRITE"): + _no_rescale_layer = _convert_persistent_var(op, tensor_to_layer, axon_layer_idx) + + if _no_rescale_layer is not None: + layers.append(_no_rescale_layer) + if op.output_tensors: + tensor_to_layer[op.output_tensors[0].name] = axon_layer_idx + axon_layer_idx += 1 + i += 1 + continue + + if layer is not None: + # Detect ReLU from output_zp == -128 (fused into quantization) 
+ if layer.output_zero_point == -128 and layer.activation == AxonActivation.DISABLED: + layer.activation = AxonActivation.RELU + logger.debug(f" Detected ReLU from output_zp=-128 on layer {axon_layer_idx}") + + # Validate against AXON hardware constraints + constraint_warnings = _validate_axon_layer(layer, axon_layer_idx) + for w in constraint_warnings: + logger.warning(f"AXON CONSTRAINT: {w} — may fail on hardware or fall back to CPU") + + if logger.isEnabledFor(logging.DEBUG): + shift_vals = np.frombuffer(layer.shift_data, dtype=np.int8).tolist() if layer.shift_data else [] + mult_vals = np.frombuffer(layer.multiplier_data, dtype=np.int32).tolist() if layer.multiplier_data else [] + act_names = {0: "", 1: " ReLU", 2: " Softmax", 3: " LeakyReLU"} + logger.debug(f" Layer {axon_layer_idx} ({op.op_name}): shift={shift_vals}, mult={mult_vals}, " + f"cnt={layer.scale_shift_cnt}, in_zp={layer.input_zero_point}, " + f"out_zp={layer.output_zero_point}{act_names.get(layer.activation, '')}") + layers.append(layer) + if rescale_op: + out_name = rescale_op.output_tensors[0].name + i += 2 + else: + out_name = op.output_tensors[0].name + i += 1 + tensor_to_layer[out_name] = axon_layer_idx + # Track output zero point for downstream ops (MAX_POOL, etc.) + tensor_to_zp[out_name] = layer.output_zero_point + axon_layer_idx += 1 + continue + + logger.warning(f"Skipping unsupported TOSA op: {op.op_name}") + i += 1 + + # Nordic keeps output_zp on ALL layers (verified from binary). + # No clearing needed. + + return layers + + + +def pack_intermediate_binary( + layers: list[AxonLayer], + model_name: str = "model", + interlayer_buffer_size: int = 125000, + psum_buffer_size: int = 4096, +) -> bytes: + """Pack AXON layers into the intermediate binary format. 
+ + The binary format (nrf_axon_nn_model_desc_hdr_s) is: + Header: 6 × bin_item_s (offset, length pairs) + Title string: "AXON_INTERMEDIATE_REPRESENTATION_FILE" + Version string + Meta info: nrf_axon_nn_model_meta_info_s + Layer descriptors: nrf_axon_nn_model_layer_desc_s[] + Constants: weights, biases, multipliers, shifts (concatenated) + Compilation options: nrf_axon_nn_model_compilation_options_s + + Each pointer field in layer_desc_s uses the offset union member, + pointing into the constants section. + """ + # Build constants section, tracking offsets + consts = bytearray() + layer_const_offsets = [] # Per layer: (filter_off, bias_off, mult_off, shift_off) + + for layer in layers: + offsets = {} + if layer.filter_data: + offsets["filter"] = len(consts) + consts.extend(layer.filter_data) + # Align to 4 bytes + while len(consts) % 4 != 0: + consts.append(0) + + if layer.bias_data: + offsets["bias"] = len(consts) + consts.extend(layer.bias_data) + while len(consts) % 4 != 0: + consts.append(0) + + if layer.multiplier_data: + offsets["multiplier"] = len(consts) + consts.extend(layer.multiplier_data) + while len(consts) % 4 != 0: + consts.append(0) + + if layer.shift_data: + offsets["shift"] = len(consts) + consts.extend(layer.shift_data) + while len(consts) % 4 != 0: + consts.append(0) + + layer_const_offsets.append(offsets) + + # Build layer descriptors section + # sizeof(nrf_axon_nn_model_layer_desc_s) — we need to match the C struct exactly + # This is the tricky part — struct layout depends on alignment and platform. + # For now, build a simplified version and validate against Nordic's executor. 
+ layers_bin = bytearray() + for i, layer in enumerate(layers): + offsets = layer_const_offsets[i] + layer_bin = _pack_layer_desc(layer, offsets) + layers_bin.extend(layer_bin) + + # Build meta info + model_name_bytes = model_name.encode("utf-8") + b"\x00" + meta_bin = _pack_meta_info(len(layers), model_name_bytes) + + # Build compilation options + options_bin = _pack_compilation_options(interlayer_buffer_size, psum_buffer_size) + + # Build version string + version_str = b"1.0.0\x00" + + # Now build the header and assemble the file + # Header is 6 × bin_item_s = 6 × 8 = 48 bytes + header_size = 48 + title_str = b"AXON_INTERMEDIATE_REPRESENTATION_FILE\x00" + + # Calculate offsets for each section + title_offset = header_size + version_offset = title_offset + len(title_str) + meta_offset = version_offset + len(version_str) + # Align meta to 4 bytes + while (meta_offset) % 4 != 0: + meta_offset += 1 + model_name_offset = meta_offset + len(meta_bin) + layers_offset = model_name_offset + len(model_name_bytes) + # Align layers to 4 bytes + while (layers_offset) % 4 != 0: + layers_offset += 1 + consts_offset = layers_offset + len(layers_bin) + # Align consts to 4 bytes + while (consts_offset) % 4 != 0: + consts_offset += 1 + options_offset = consts_offset + len(consts) + while (options_offset) % 4 != 0: + options_offset += 1 + + # Pack header (6 × nrf_axon_nn_model_bin_item_s) + header = struct.pack(" bytes: + """Pack a single nrf_axon_nn_model_layer_desc_s. + + WARNING: This struct packing must exactly match the C struct layout. + The struct is complex with unions and alignment. This is a best-effort + implementation — will need validation against Nordic's executor output. 
+ """ + buf = bytearray() + + # input_id_cnt (uint8) + buf.append(len(layer.input_ids)) + + # padding for alignment (3 bytes to align input_ids to 2-byte boundary) + buf.append(0) + + # input_ids[4] (int16 × 4) + for j in range(4): + if j < len(layer.input_ids): + buf.extend(struct.pack(" bytes: + """Pack nrf_axon_nn_model_meta_info_s.""" + buf = bytearray() + # model_name: bin_item_s (offset relative to meta section start, length) + # The name follows immediately after the meta struct + meta_struct_size = 8 + 8 + 4 + 6 + 6 # approximate + buf.extend(struct.pack(" bytes: + """Pack nrf_axon_nn_model_compilation_options_s. + + psum_buffer_placement: 0=shared, 1=DEDICATED_MEM (required for Conv/Pool + to avoid PSUM/FILTER overlap). + """ + buf = struct.pack(" dict: + """Full compilation pipeline: TOSA → AXON command buffers. + + Args: + tosa_flatbuffer: Serialized TOSA graph bytes. + sdk_edge_ai_path: Path to sdk-edge-ai repo. + model_name: Name for the compiled model. + output_dir: Directory for compiler output. Uses tempdir if None. + + Returns: + Dict with compilation results and paths to output files. + """ + # 1. Parse TOSA + graph = parse_tosa_flatbuffer(tosa_flatbuffer) + logger.info(f"Parsed TOSA graph: {len(graph.tensors)} tensors, {len(graph.operators)} operators") + + # 2. Convert to AXON layers + layers = tosa_to_axon_layers(graph) + logger.info(f"Converted to {len(layers)} AXON layers") + + op_names = { + 0: "FC", 1: "CONV2D", 2: "DW_CONV2D", 3: "PW_CONV2D", + 4: "AVG_POOL", 5: "MAX_POOL", 6: "ADD", 7: "CH_PAD", + 8: "PERSIST_VAR", 9: "CONCAT", 10: "SLICE", 11: "MUL", 12: "MEAN", + } + for i, layer in enumerate(layers): + in_desc = f"in={layer.input_dimensions[0].channel_cnt}" if layer.input_dimensions else "in=?" + logger.info(f" Layer {i}: {op_names.get(layer.operation, '?')} " + f"{in_desc} " + f"out={layer.output_dimensions.channel_cnt} " + f"weights={len(layer.filter_data)}B") + + # 3. 
Pack intermediate binary + binary = pack_intermediate_binary(layers, model_name) + logger.info(f"Packed intermediate binary: {len(binary)} bytes") + + # 4. Write binary file and call compiler + if output_dir is None: + output_dir = tempfile.mkdtemp(prefix="axon_compile_") + + bin_path = os.path.join(output_dir, f"{model_name}.bin") + with open(bin_path, "wb") as f: + f.write(binary) + + logger.info(f"Wrote intermediate binary to {bin_path}") + + # 5. Call Nordic compiler lib + import platform + system = platform.system() + if system == "Linux": + lib_name = "libnrf-axon-nn-compiler-lib-amd64.so" + elif system == "Darwin": + lib_name = "libnrf-axon-nn-compiler-lib-arm64.dylib" + else: + lib_name = "nrf-axon-nn-compiler-lib-amd64.dll" + + compiler_lib_path = os.path.join( + sdk_edge_ai_path, "tools", "axon", "compiler", "bin", system, lib_name + ) + + if not os.path.exists(compiler_lib_path): + logger.warning(f"AXON compiler lib not found at {compiler_lib_path}") + logger.warning("Skipping compilation — intermediate binary saved for manual compilation") + return { + "binary_path": bin_path, + "layers": len(layers), + "compiled": False, + } + + output_prefix = os.path.join(output_dir, f"nrf_axon_model_{model_name}") + result = _call_compiler_lib(compiler_lib_path, bin_path, output_prefix) + + return { + "binary_path": bin_path, + "output_prefix": output_prefix, + "layers": len(layers), + "compiled": result == 0, + "return_code": result, + } + + +def _call_compiler_lib(compiler_lib_path: str, bin_path: str, output_prefix: str) -> int: + """Call Nordic's nrf_axon_compile_model() via ctypes. + + Args: + compiler_lib_path: Path to the shared library. + bin_path: Path to the intermediate binary file. + output_prefix: Output file prefix for compiled header. + + Returns: + Compiler return code (0 = success). + + Raises: + OSError: If the compiler library cannot be loaded. 
+ """ + logger.info("Loading compiler lib: %s", compiler_lib_path) + try: + lib = ctypes.CDLL(compiler_lib_path) + except OSError as e: + logger.error( + "Failed to load AXON compiler library: %s\n" + " Path: %s\n" + " Is SDK_EDGE_AI_PATH set correctly?", + e, compiler_lib_path, + ) + raise + + # Build command-line arguments (matches Nordic's compiler CLI) + args = [ + f"-c{compiler_lib_path}", + f"-b{bin_path}", + f"-f{output_prefix}", + ] + + # Convert to ctypes + argc = len(args) + argv_type = ctypes.c_char_p * argc + argv = argv_type(*[a.encode("utf-8") for a in args]) + + # Return struct: 5 × uint32 (model_const_size, interlayer_buf, + # psum_buf, cmd_buf_size, profiling_ticks) + return_type = ctypes.c_uint32 * 5 + return_buf = return_type() + + lib.nrf_axon_compile_model.argtypes = [ + ctypes.c_int, + ctypes.POINTER(argv_type), + ctypes.POINTER(return_type), + ] + lib.nrf_axon_compile_model.restype = ctypes.c_int + + logger.info("Calling nrf_axon_compile_model with args: %s", args) + result = lib.nrf_axon_compile_model( + argc, ctypes.pointer(argv), ctypes.pointer(return_buf) + ) + + if result == 0: + logger.info("Compilation successful") + logger.info(" Model const size: %d bytes", return_buf[0]) + logger.info(" Interlayer buffer: %d bytes", return_buf[1]) + logger.info(" PSUM buffer: %d bytes", return_buf[2]) + logger.info(" Command buffer: %d bytes", return_buf[3]) + else: + logger.error("Compilation failed with code %d", result) + + return result diff --git a/backends/nordic/axon_converters.py b/backends/nordic/axon_converters.py new file mode 100644 index 00000000000..62c38d95bba --- /dev/null +++ b/backends/nordic/axon_converters.py @@ -0,0 +1,1498 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""TOSA to AXON per-operation converters. 
+
+Each ``_convert_*`` function handles one TOSA operation type and
+produces the corresponding AXON layer descriptor(s). Called by
+``tosa_to_axon_layers()`` in ``axon_compiler.py``.
+"""
+from __future__ import annotations
+
+import logging
+import struct
+
+import numpy as np
+
+from .axon.compile_spec import (
+    AXON_MAX_CONV2D_FILTER,
+    AXON_MAX_CONV_STRIDE,
+    AXON_MAX_FC_INPUT,
+    AXON_MAX_FC_OUTPUT,
+    AXON_MAX_POOL_FILTER,
+    AXON_MAX_TENSOR_DIM,
+)
+from .axon_types import (
+    ActivationQuantInfo,
+    AxonActivation,
+    AxonByteWidth,
+    AxonDimensions,
+    AxonLayer,
+    AxonOp,
+)
+from .tosa_reader import TosaGraph, TosaOperator, TosaTensor
+
+logger = logging.getLogger(__name__)
+
+def _handle_table_op(
+    op: TosaOperator,
+    ops: list[TosaOperator],
+    op_idx: int,
+    layers: list[AxonLayer],
+    tensor_to_layer: dict[str, int],
+    tensor_to_zp: dict[str, int],
+    activation_info: list["ActivationQuantInfo"],
+    activation_idx: int,
+    axon_layer_idx: int,
+) -> tuple[int, int, int] | None:
+    """Convert a TOSA TABLE op into an AXON op extension layer.
+
+    Modifies the preceding AXON layer to output INT16 q3.12 (sigmoid/tanh)
+    and adds an op extension layer (operation 101 or 102) that runs
+    sigmoid/tanh on the q3.12 input.
+
+    Note: ``ops`` and ``op_idx`` are currently unused here; they are kept
+    for signature parity with ``_handle_softmax_pattern``.
+
+    Returns:
+        (new_axon_layer_idx, new_activation_idx, ops_to_advance) on success,
+        or None if this TABLE could not be matched / handled.
+    """
+    # Each TABLE consumes one entry of activation_info; bail out if the
+    # quantization metadata has run dry.
+    if activation_idx >= len(activation_info):
+        return None
+
+    info = activation_info[activation_idx]
+    if info.op_type not in ("sigmoid", "tanh"):
+        # Softmax TABLE is part of a multi-op decomposition; handled elsewhere.
+        return None
+
+    if not op.input_tensors or not op.output_tensors:
+        return None
+
+    in_name = op.input_tensors[0].name
+    out_name = op.output_tensors[0].name
+    if in_name not in tensor_to_layer:
+        logger.warning(f"    TABLE input tensor not produced by an AXON layer; skipping")
+        return None
+
+    prev_idx = tensor_to_layer[in_name]
+    prev_layer = layers[prev_idx]
+
+    # ────────────────────────────────────────────────────────────
+    # 1. Modify the preceding layer to output INT16 q3.12.
+    #
+    # The preceding layer's RESCALE currently encodes:
+    #     int8_out = (acc * mult >> shift) + zp_int8
+    #     mult/2^shift ≈ S_in_act / 1   (per Nordic)
+    # but in our pipeline mult/shift currently encodes
+    #     S_in_acc * S_w / S_int8_out
+    # i.e. acc → INT8.
+    #
+    # For q3.12 we need: q3_12 = float_value * 2^12.
+    # Where float_value = acc * S_in_acc * S_w = (acc * mult/2^shift) * S_int8_out.
+    # So new_factor = old_factor * S_int8_out * 2^12.
+    # ────────────────────────────────────────────────────────────
+    if not prev_layer.multiplier_data or not prev_layer.shift_data:
+        logger.warning(f"    Preceding layer has no rescale; cannot retarget to q3.12")
+        return None
+
+    # .copy() because np.frombuffer returns a read-only view of the bytes.
+    prev_mult = np.frombuffer(prev_layer.multiplier_data, dtype=np.int32).copy()
+    prev_shift = np.frombuffer(prev_layer.shift_data, dtype=np.int8).copy()
+
+    s_int8_out = info.input_scale  # output scale of the preceding INT8 RESCALE
+    new_mult = np.zeros_like(prev_mult)
+    new_shift = np.zeros_like(prev_shift)
+    for ch in range(len(prev_mult)):
+        old_factor = float(prev_mult[ch]) / (2.0 ** int(prev_shift[ch]))
+        # ×4096 because q3.12 stores float_value * 2^12
+        q312_scale = old_factor * s_int8_out * 4096.0
+        # Nordic limits the q3.12 preceding-layer rescale to max_shift=28.
+        m, s = _optimized_scaling_shift(q312_scale, output_zp=0,
+                                        min_shift=8, max_shift=28, bit_limit=31)
+        new_mult[ch] = m
+        new_shift[ch] = s
+
+    prev_layer.multiplier_data = new_mult.tobytes()
+    prev_layer.shift_data = new_shift.tobytes()
+    # The preceding layer's INT16 q3.12 output has zp=0 by definition.
+    prev_layer.output_zero_point = 0
+    prev_layer.output_dimensions.byte_width = AxonByteWidth.INT16
+
+    logger.debug(
+        f"    Retargeted layer {prev_idx} to INT16 q3.12 for {info.op_type}: "
+        f"new mult[0]={int(new_mult[0])} shift[0]={int(new_shift[0])}"
+    )
+
+    # ────────────────────────────────────────────────────────────
+    # 2. Build the op extension layer (sigmoid=101, tanh=102).
+    #
+    # The op extension takes a q3.12 INT16 input, applies the function,
+    # and produces an INT8 output. The layer's mult/shift encode 1/S_out
+    # so that: int8_out = (float_result * mult >> shift) + zp_out.
+    # ────────────────────────────────────────────────────────────
+    op_enum = AxonOp.SIGMOID if info.op_type == "sigmoid" else AxonOp.TANH
+    out_scale = info.output_scale
+    out_zp = info.output_zp
+
+    # Guard against a degenerate zero/negative output scale.
+    inv_scale = 1.0 / out_scale if out_scale > 0 else 0.0
+    ext_mult, ext_shift = _optimized_scaling_shift(
+        inv_scale, output_zp=out_zp,
+        min_shift=8, max_shift=31, bit_limit=31,
+    )
+
+    # Dimensions: same as preceding layer's output (which is now INT16 q3.12).
+    # AXON op extensions are element-wise — no spatial change.
+    out_h = prev_layer.output_dimensions.height
+    out_w = prev_layer.output_dimensions.width
+    out_c = prev_layer.output_dimensions.channel_cnt
+
+    ext_layer = AxonLayer(
+        input_ids=[prev_idx],
+        operation=op_enum,
+        input_dimensions=[AxonDimensions(
+            height=out_h, width=out_w, channel_cnt=out_c,
+            byte_width=AxonByteWidth.INT16,
+        )],
+        output_dimensions=AxonDimensions(
+            height=out_h, width=out_w, channel_cnt=out_c,
+            byte_width=AxonByteWidth.INT8,
+        ),
+        input_zero_point=0,
+        output_zero_point=out_zp,
+        activation=AxonActivation.DISABLED,
+        multiplier_data=np.array([ext_mult], dtype=np.int32).tobytes(),
+        shift_data=np.array([ext_shift], dtype=np.int8).tobytes(),
+        scale_shift_cnt=1,
+    )
+    layers.append(ext_layer)
+    tensor_to_layer[out_name] = axon_layer_idx
+    tensor_to_zp[out_name] = out_zp
+
+    logger.debug(
+        f"    Layer {axon_layer_idx} ({info.op_type.upper()} ext): "
+        f"mult={ext_mult} shift={ext_shift} out_zp={out_zp}"
+    )
+
+    return (axon_layer_idx + 1, activation_idx + 1, 1)
+
+
+# TOSA ops that appear in the quantized softmax decomposition.
+# A softmax (with stable variant) lowers to roughly:
+#   RESCALE → REDUCE_MAX → RESCALE → SUB → RESCALE → TABLE(exp)
+#   → RESCALE → RESCALE → REDUCE_SUM → RESCALE → TABLE(reciprocal)
+#   → RESCALE → MUL → RESCALE
+_SOFTMAX_DECOMP_OPS = frozenset({
+    "RESCALE", "REDUCE_MAX", "REDUCE_SUM", "SUB", "TABLE",
+    "MUL", "EXP", "RECIPROCAL", "RESHAPE",
+})
+
+
+def _handle_softmax_pattern(
+    start_op: TosaOperator,
+    ops: list[TosaOperator],
+    start_idx: int,
+    layers: list[AxonLayer],
+    tensor_to_layer: dict[str, int],
+    tensor_to_zp: dict[str, int],
+    activation_info: list["ActivationQuantInfo"],
+    activation_idx: int,
+    axon_layer_idx: int,
+) -> tuple[int, int, int] | None:
+    """Replace a TOSA softmax decomposition with a single AXON op extension.
+
+    Modifies the preceding layer to output INT32 q11.12 with PREPARE_SOFTMAX
+    activation, then adds an op extension layer (operation=100) that runs
+    softmax on-device via nrf_axon_nn_op_extension_softmax().
+
+    Returns:
+        (new_axon_layer_idx, new_activation_idx, ops_to_advance) on success.
+    """
+    info = activation_info[activation_idx]
+
+    # The REDUCE_MAX's input tensor (possibly through standalone RESCALEs that
+    # we already skipped/passed through) should map to the preceding AXON layer.
+    in_name = start_op.input_tensors[0].name
+    if in_name not in tensor_to_layer:
+        logger.warning("    Softmax: REDUCE_MAX input not produced by an AXON layer")
+        return None
+
+    prev_idx = tensor_to_layer[in_name]
+    prev_layer = layers[prev_idx]
+    if not prev_layer.multiplier_data or not prev_layer.shift_data:
+        logger.warning("    Softmax: preceding layer has no rescale; cannot retarget")
+        return None
+
+    # Walk forward to find the end of the softmax pattern. The chain should
+    # consist entirely of softmax-decomposition ops; the end is the last
+    # MUL plus any trailing RESCALE.
+    last_mul_idx = None
+    end_idx = start_idx
+    for k in range(start_idx, len(ops)):
+        if ops[k].op_name not in _SOFTMAX_DECOMP_OPS:
+            break
+        # NOTE(review): this assignment is superseded below — end_idx is
+        # recomputed from last_mul_idx once the scan finishes.
+        end_idx = k
+        if ops[k].op_name == "MUL":
+            last_mul_idx = k
+
+    if last_mul_idx is None:
+        logger.warning("    Softmax: no MUL found in expected decomposition pattern")
+        return None
+
+    # Include the trailing RESCALE if present.
+    if last_mul_idx + 1 < len(ops) and ops[last_mul_idx + 1].op_name == "RESCALE":
+        end_idx = last_mul_idx + 1
+    else:
+        end_idx = last_mul_idx
+
+    final_op = ops[end_idx]
+    if not final_op.output_tensors:
+        return None
+    final_tensor_name = final_op.output_tensors[0].name
+
+    # ────────────────────────────────────────────────────────────
+    # 1. Retarget preceding layer to INT32 q11.12 with PREPARE_SOFTMAX.
+    #
+    # Same q-format math as sigmoid/tanh (×4096 = 2^12), but with INT32
+    # output to give 11 integer bits of headroom for the unnormalised
+    # softmax inputs. Nordic uses a data-dependent scaleshift_max_range
+    # (31 - bits_needed(input_range)) for softmax; we use 28 like the
+    # sigmoid/tanh case as a safe upper bound.
+    # ────────────────────────────────────────────────────────────
+    # .copy() because np.frombuffer returns a read-only view of the bytes.
+    prev_mult = np.frombuffer(prev_layer.multiplier_data, dtype=np.int32).copy()
+    prev_shift = np.frombuffer(prev_layer.shift_data, dtype=np.int8).copy()
+
+    s_int8_out = info.input_scale
+    new_mult = np.zeros_like(prev_mult)
+    new_shift = np.zeros_like(prev_shift)
+    for ch in range(len(prev_mult)):
+        old_factor = float(prev_mult[ch]) / (2.0 ** int(prev_shift[ch]))
+        q11_12_scale = old_factor * s_int8_out * 4096.0
+        m, s = _optimized_scaling_shift(q11_12_scale, output_zp=0,
+                                        min_shift=8, max_shift=28, bit_limit=31)
+        new_mult[ch] = m
+        new_shift[ch] = s
+
+    prev_layer.multiplier_data = new_mult.tobytes()
+    prev_layer.shift_data = new_shift.tobytes()
+    prev_layer.output_zero_point = 0
+    prev_layer.output_dimensions.byte_width = AxonByteWidth.INT32
+    prev_layer.activation = AxonActivation.PREPARE_SOFTMAX
+
+    logger.debug(
+        f"    Retargeted layer {prev_idx} to INT32 q11.12 PREPARE_SOFTMAX: "
+        f"new mult[0]={int(new_mult[0])} shift[0]={int(new_shift[0])}"
+    )
+
+    # ────────────────────────────────────────────────────────────
+    # 2. Build the SOFTMAX op extension layer (operation 100).
+ # ──────────────────────────────────────────────────────────── + out_scale = info.output_scale + out_zp = info.output_zp + inv_scale = 1.0 / out_scale if out_scale > 0 else 0.0 + ext_mult, ext_shift = _optimized_scaling_shift( + inv_scale, output_zp=out_zp, + min_shift=8, max_shift=31, bit_limit=31, + ) + + out_h = prev_layer.output_dimensions.height + out_w = prev_layer.output_dimensions.width + out_c = prev_layer.output_dimensions.channel_cnt + + ext_layer = AxonLayer( + input_ids=[prev_idx], + operation=AxonOp.SOFTMAX, + input_dimensions=[AxonDimensions( + height=out_h, width=out_w, channel_cnt=out_c, + byte_width=AxonByteWidth.INT32, + )], + output_dimensions=AxonDimensions( + height=out_h, width=out_w, channel_cnt=out_c, + byte_width=AxonByteWidth.INT8, + ), + input_zero_point=0, + output_zero_point=out_zp, + activation=AxonActivation.DISABLED, + multiplier_data=np.array([ext_mult], dtype=np.int32).tobytes(), + shift_data=np.array([ext_shift], dtype=np.int8).tobytes(), + scale_shift_cnt=1, + ) + layers.append(ext_layer) + tensor_to_layer[final_tensor_name] = axon_layer_idx + tensor_to_zp[final_tensor_name] = out_zp + + logger.debug( + f" Layer {axon_layer_idx} (SOFTMAX ext): " + f"mult={ext_mult} shift={ext_shift} out_zp={out_zp}, " + f"replaced TOSA[{start_idx}..{end_idx}]" + ) + + # Skip ahead past the entire softmax decomposition. + advance = (end_idx - start_idx) + 1 + return (axon_layer_idx + 1, activation_idx + 1, advance) + + +def _optimized_scaling_shift(scale: float, output_zp: int, + min_shift: int = 8, max_shift: int = 30, + bit_limit: int = 31) -> tuple[int, int]: + """Find optimal multiplier and shift for a given scale. 
+ + Matches Nordic's optimized_ip_scaling_shift algorithm: + - Searches shift range [min_shift, max_shift) + - Ensures abs(mult) < 2^bit_limit + - Ensures abs(output_zp * 2^shift) < 2^bit_limit + - Picks highest shift that satisfies constraints (best precision) + + Returns: + (multiplier, shift) + """ + best_shift = min_shift + for s in range(min_shift, max_shift): + m = abs(int(np.round(scale * (2 ** s)))) + zp_scaled = abs(int(np.round(output_zp * (2 ** s)))) + if m < (1 << bit_limit) and zp_scaled < (1 << bit_limit): + best_shift = s + else: + break + mult = abs(int(np.round(scale * (2 ** best_shift)))) + return mult, best_shift + + +def _extract_rescale_params( + rescale_op: TosaOperator, +) -> tuple[int, bytes, bytes, int]: + """Extract quantization parameters from a TOSA RESCALE op. + + Recovers the floating-point scale from TOSA's mult/shift, then + recomputes optimal mult/shift using Nordic's algorithm (range [8, 30), + bit_limit=31). This is critical — TOSA's raw shift values (32-34) + are out of AXON's effective range and cause output_zp to be ignored. 
+
+    Returns:
+        (output_zp, multiplier_data, shift_data, scale_shift_cnt)
+    """
+    # RESCALE input layout: [data, multiplier, shift, input_zp, output_zp].
+    mult_tensor = rescale_op.input_tensors[1]
+    shift_tensor = rescale_op.input_tensors[2]
+    rescale_out_zp_tensor = rescale_op.input_tensors[4]
+
+    output_zp = 0
+    if rescale_out_zp_tensor.data is not None:
+        output_zp = int(rescale_out_zp_tensor.data.flat[0])
+
+    multiplier_data = b""
+    shift_data = b""
+    scale_shift_cnt = 0
+
+    if mult_tensor.raw_bytes and shift_tensor.data is not None:
+        # TOSA stores multiplier as INT16 dtype but raw bytes are INT32
+        tosa_mult = np.frombuffer(mult_tensor.raw_bytes, dtype=np.int32)
+        tosa_shift = shift_tensor.data.flatten().astype(np.int32)
+
+        num_channels = len(tosa_mult)
+
+        if num_channels > 1:
+            logger.debug(f"    Per-channel RESCALE: {num_channels} channels")
+
+        # Recover floating-point scale from TOSA's mult/shift:
+        #   scale = tosa_mult / 2^tosa_shift
+        # Then recompute with Nordic's algorithm
+        new_mult = np.zeros(num_channels, dtype=np.int32)
+        new_shift = np.zeros(num_channels, dtype=np.int8)
+
+        for ch in range(num_channels):
+            # Recover the original floating-point scale
+            tm = int(tosa_mult[ch])
+            ts = int(tosa_shift[ch])
+            if ts > 0 and tm != 0:
+                scale = tm / (2.0 ** ts)
+            else:
+                # Degenerate mult/shift — treat as zero scale.
+                scale = 0.0
+
+            # Recompute with Nordic's optimal algorithm
+            m, s = _optimized_scaling_shift(scale, output_zp)
+            new_mult[ch] = m
+            new_shift[ch] = s
+
+        if num_channels == 1:
+            logger.debug(f"    RESCALE recomputed: tosa_mult={int(tosa_mult[0])}/tosa_shift={int(tosa_shift[0])} "
+                         f"→ mult={int(new_mult[0])}/shift={int(new_shift[0])} "
+                         f"(scale={tosa_mult[0]/(2.0**tosa_shift[0]):.8f})")
+
+        multiplier_data = new_mult.tobytes()
+        shift_data = new_shift.tobytes()
+        scale_shift_cnt = num_channels
+
+    elif mult_tensor.raw_bytes:
+        # Multiplier only — pass TOSA's values through unchanged.
+        mult_int32 = np.frombuffer(mult_tensor.raw_bytes, dtype=np.int32)
+        multiplier_data = mult_int32.tobytes()
+        scale_shift_cnt = len(mult_int32)
+    elif shift_tensor.data is not None:
+        # Shift only — int8 elements are one byte each, so len(shift_data)
+        # equals the element count.
+        shift_data = shift_tensor.data.astype(np.int8).tobytes()
+        scale_shift_cnt = len(shift_data)
+
+    return output_zp, multiplier_data, shift_data, scale_shift_cnt
+
+
+def _resolve_input_id(tensor: TosaTensor, tensor_to_layer: dict[str, int]) -> int:
+    """Look up which AXON layer produced a tensor, or -1 for graph input."""
+    return tensor_to_layer.get(tensor.name, -1)
+
+
+def _validate_axon_layer(layer: AxonLayer, layer_idx: int) -> list[str]:
+    """Validate an AXON layer against hardware constraints.
+
+    Returns list of warning strings. Empty = all good.
+    """
+    warnings = []
+    op = layer.operation
+
+    # Check tensor dimension limits
+    for i, dim in enumerate(layer.input_dimensions):
+        if dim.height > AXON_MAX_TENSOR_DIM:
+            warnings.append(f"Layer {layer_idx}: input[{i}] height {dim.height} > max {AXON_MAX_TENSOR_DIM}")
+        if dim.width > AXON_MAX_TENSOR_DIM:
+            warnings.append(f"Layer {layer_idx}: input[{i}] width {dim.width} > max {AXON_MAX_TENSOR_DIM}")
+        if dim.channel_cnt > AXON_MAX_TENSOR_DIM:
+            warnings.append(f"Layer {layer_idx}: input[{i}] channels {dim.channel_cnt} > max {AXON_MAX_TENSOR_DIM}")
+
+    od = layer.output_dimensions
+    if od.height > AXON_MAX_TENSOR_DIM:
+        warnings.append(f"Layer {layer_idx}: output height {od.height} > max {AXON_MAX_TENSOR_DIM}")
+    if od.width > AXON_MAX_TENSOR_DIM:
+        warnings.append(f"Layer {layer_idx}: output width {od.width} > max {AXON_MAX_TENSOR_DIM}")
+    if od.channel_cnt > AXON_MAX_TENSOR_DIM:
+        warnings.append(f"Layer {layer_idx}: output channels {od.channel_cnt} > max {AXON_MAX_TENSOR_DIM}")
+
+    # Op-specific constraints
+    if op == AxonOp.FULLY_CONNECTED:
+        fd = layer.filter_dimensions
+        in_size = fd.width  # width = in_features for FC
+        out_size = fd.height  # height = out_features for FC
+        if in_size > AXON_MAX_FC_INPUT:
+            warnings.append(f"Layer {layer_idx}: FC input size {in_size} > max {AXON_MAX_FC_INPUT}")
+        if out_size > AXON_MAX_FC_OUTPUT:
+            warnings.append(f"Layer {layer_idx}: FC output size {out_size} > max {AXON_MAX_FC_OUTPUT}")
+
+    elif op in (AxonOp.CONV2D,
AxonOp.DEPTHWISE_CONV2D, AxonOp.POINTWISE_CONV2D): + fd = layer.filter_dimensions + if fd.height > AXON_MAX_CONV2D_FILTER: + warnings.append(f"Layer {layer_idx}: conv filter height {fd.height} > max {AXON_MAX_CONV2D_FILTER}") + if fd.width > AXON_MAX_CONV2D_FILTER: + warnings.append(f"Layer {layer_idx}: conv filter width {fd.width} > max {AXON_MAX_CONV2D_FILTER}") + if layer.stride_x > AXON_MAX_CONV_STRIDE: + warnings.append(f"Layer {layer_idx}: conv stride_x {layer.stride_x} > max {AXON_MAX_CONV_STRIDE}") + if layer.stride_y > AXON_MAX_CONV_STRIDE: + warnings.append(f"Layer {layer_idx}: conv stride_y {layer.stride_y} > max {AXON_MAX_CONV_STRIDE}") + + elif op in (AxonOp.AVERAGE_POOLING, AxonOp.MAX_POOLING): + fd = layer.filter_dimensions + if fd.height > AXON_MAX_POOL_FILTER: + warnings.append(f"Layer {layer_idx}: pool filter height {fd.height} > max {AXON_MAX_POOL_FILTER}") + if fd.width > AXON_MAX_POOL_FILTER: + warnings.append(f"Layer {layer_idx}: pool filter width {fd.width} > max {AXON_MAX_POOL_FILTER}") + + return warnings + + +def _create_rescale_layer( + rescale_op: TosaOperator, + tensor_to_layer: dict[str, int], + layer_idx: int, +) -> AxonLayer: + """Create an AXON layer for a standalone RESCALE operation. + + Standalone RESCALEs requantize tensors between different quantization + domains (e.g., before ADD inputs need matching scales). We implement + this as a POINTWISE_CONV2D with identity weights — each output channel + copies the corresponding input channel, with the RESCALE mult/shift + applied as the output requantization. 
+ + RESCALE inputs: [data, multiplier, shift, input_zp, output_zp] + """ + input_tensor = rescale_op.input_tensors[0] + in_shape = input_tensor.shape # [N, H, W, C] or [N, C] + + # Determine spatial dims and channels + if len(in_shape) == 4: + h, w, c = in_shape[1], in_shape[2], in_shape[3] + elif len(in_shape) == 3: + h, w, c = in_shape[0], in_shape[1], in_shape[2] + elif len(in_shape) == 2: + h, w, c = 1, 1, in_shape[1] + else: + h, w, c = 1, 1, in_shape[0] if in_shape else 1 + + # Identity weights: C→C pointwise conv (1x1 kernel, identity per channel) + identity = np.eye(c, dtype=np.int8) # [C, C] + # AXON expects OIHW: [out_channels, in_channels, 1, 1] + identity = identity.reshape(c, c, 1, 1) + filter_data = identity.tobytes() + + # Input zero point from RESCALE + input_zp = 0 + if len(rescale_op.input_tensors) > 3 and rescale_op.input_tensors[3].data is not None: + input_zp = int(rescale_op.input_tensors[3].data.flat[0]) + + # Bias prime: identity conv weights have sum=1 per channel + # b_prime[ch] = 0 + (-1 * input_zp) = -input_zp + if input_zp != 0: + bias_int32 = np.full(c, -input_zp, dtype=np.int32) + else: + bias_int32 = np.zeros(c, dtype=np.int32) + bias_data = bias_int32.tobytes() + + # Rescale params (recomputed from TOSA's mult/shift using Nordic's algorithm) + output_zp, multiplier_data, shift_data, scale_shift_cnt = _extract_rescale_params(rescale_op) + + input_id = _resolve_input_id(input_tensor, tensor_to_layer) + + # Use FC for 1x1 spatial, POINTWISE_CONV2D for spatial + if h == 1 and w == 1: + axon_op = AxonOp.FULLY_CONNECTED + input_dims = AxonDimensions(height=1, width=c, channel_cnt=1) + filter_dims = AxonDimensions(height=c, width=c, channel_cnt=1) + output_dims = AxonDimensions(height=1, width=c, channel_cnt=1) + else: + axon_op = AxonOp.POINTWISE_CONV2D + input_dims = AxonDimensions(height=h, width=w, channel_cnt=c) + filter_dims = AxonDimensions(height=1, width=1, channel_cnt=c) + output_dims = AxonDimensions(height=h, width=w, 
channel_cnt=c) + + logger.debug(f" Creating rescale layer: {axon_op} {h}x{w}x{c} (identity conv)") + + return AxonLayer( + input_ids=[input_id], + operation=axon_op, + input_dimensions=[input_dims], + filter_dimensions=filter_dims, + output_dimensions=output_dims, + input_zero_point=0, # b_prime handles input zp correction + output_zero_point=output_zp, + filter_data=filter_data, + bias_data=bias_data, + multiplier_data=multiplier_data, + shift_data=shift_data, + scale_shift_cnt=scale_shift_cnt, + ) + + +def _convert_conv2d( + conv_op: TosaOperator, + rescale_op: TosaOperator | None, + tensor_to_layer: dict[str, int], + layer_idx: int, + pre_flatten_shape: tuple[int, int, int] | None = None, +) -> AxonLayer: + """Convert a TOSA CONV2D (+ optional RESCALE) to an AXON layer.""" + + # CONV2D inputs: [input, weights, bias, input_zp, weight_zp] + input_tensor = conv_op.input_tensors[0] + weight_tensor = conv_op.input_tensors[1] + bias_tensor = conv_op.input_tensors[2] + input_zp_tensor = conv_op.input_tensors[3] + weight_zp_tensor = conv_op.input_tensors[4] + output_tensor = conv_op.output_tensors[0] + + # Determine AXON operation type + weight_shape = weight_tensor.shape # [O, H, W, I] for CONV2D + in_shape = input_tensor.shape # [N, H, W, C] + in_h = in_shape[1] if len(in_shape) > 1 else 1 + in_w = in_shape[2] if len(in_shape) > 2 else 1 + + if len(weight_shape) == 4 and weight_shape[1] == 1 and weight_shape[2] == 1: + if in_h == 1 and in_w == 1: + # True FC: 1×1 conv on 1×1 spatial input + axon_op = AxonOp.FULLY_CONNECTED + else: + # Pointwise conv on spatial input + axon_op = AxonOp.POINTWISE_CONV2D + else: + axon_op = AxonOp.CONV2D + + # Dimension mapping for AXON: + # AXON uses (height, width, channel_cnt) where for FC: + # input: height=1, width=input_features, channel_cnt=1 + # filter: height=1, width=output_features, channel_cnt=1 + # output: height=1, width=output_features, channel_cnt=1 + # This matches TFLite's 2D tensor convention where shape=[batch, 
features] + # maps to TensorShape(height=batch, width=features, depth=1). + + in_shape = input_tensor.shape # TOSA: [N, H, W, C] or [N, features] + + if axon_op == AxonOp.FULLY_CONNECTED: + # FC dimension mapping (verified from Nordic's TensorShape class): + # TFLite FC input [batch, features] has shape.size==2 → + # height=batch(1), width=features, depth=1 + # Filter [outputs, inputs] → + # height=outputs, width=inputs, depth=1 + in_features = in_shape[-1] + out_features = weight_shape[0] + + input_dims = AxonDimensions( + height=1, + width=in_features, + channel_cnt=1, + byte_width=AxonByteWidth.INT8, + ) + filter_dims = AxonDimensions( + height=out_features, + width=in_features, + channel_cnt=1, + byte_width=AxonByteWidth.INT8, + ) + output_dims = AxonDimensions( + height=1, + width=out_features, + channel_cnt=1, + byte_width=AxonByteWidth.INT8, + ) + else: + # Conv2D: standard NHWC → AXON HWC + input_dims = AxonDimensions( + height=in_shape[1] if len(in_shape) > 1 else 1, + width=in_shape[2] if len(in_shape) > 2 else 1, + channel_cnt=in_shape[3] if len(in_shape) > 3 else in_shape[-1], + byte_width=AxonByteWidth.INT8, + ) + filter_dims = AxonDimensions( + height=weight_shape[1], + width=weight_shape[2], + channel_cnt=weight_shape[0], + byte_width=AxonByteWidth.INT8, + ) + out_shape = output_tensor.shape + output_dims = AxonDimensions( + height=out_shape[1] if len(out_shape) > 1 else 1, + width=out_shape[2] if len(out_shape) > 2 else 1, + channel_cnt=out_shape[3] if len(out_shape) > 3 else out_shape[-1], + byte_width=AxonByteWidth.INT8, + ) + + # Zero points + input_zp = int(input_zp_tensor.data.flat[0]) if input_zp_tensor.data is not None else 0 + weight_zp = int(weight_zp_tensor.data.flat[0]) if weight_zp_tensor.data is not None else 0 + + # Padding, stride, dilation from TOSA attributes + attrs = conv_op.attributes + pad = attrs.get("pad", [0, 0, 0, 0]) # [top, bottom, left, right] + stride = attrs.get("stride", [1, 1]) # [height, width] + dilation = 
attrs.get("dilation", [1, 1]) # [height, width] + + # FC layers don't use stride/dilation — Nordic sets them to 0 + if axon_op == AxonOp.FULLY_CONNECTED: + stride = [0, 0] + dilation = [0, 0] + + # Weights: TOSA stores as [O, H, W, I], AXON expects [O, I, H, W] (NCHW/OIHW) + if weight_tensor.data is not None: + weights = weight_tensor.data.astype(np.int8) + if weights.ndim == 4: + weights = weights.transpose(0, 3, 1, 2) # OHWI → OIHW + elif weights.ndim == 3: + weights = weights.transpose(2, 0, 1) + + filter_data = weights.tobytes() + else: + filter_data = b"" + + # Bias prime for ALL layer types: b_prime = bias + (-sum(weight[ch]) * input_zp) + # Nordic uses BOTH b_prime AND input_zp in the struct simultaneously. + # Verified from Nordic's intermediate binary for FC, Conv, and Pool layers. + if bias_tensor.raw_bytes: + bias_int32 = np.frombuffer(bias_tensor.raw_bytes, dtype=np.int32).copy() + if weight_tensor.data is not None and input_zp != 0: + weights_orig = weight_tensor.data.astype(np.int32) + for ch in range(min(weights_orig.shape[0], len(bias_int32))): + kernel_sum = int(np.sum(weights_orig[ch])) + bias_int32[ch] += -kernel_sum * input_zp + bias_data = bias_int32.tobytes() + else: + bias_data = b"" + + # Quantization from RESCALE op + output_zp = 0 + multiplier_data = b"" + shift_data = b"" + scale_shift_cnt = 0 + activation = AxonActivation.DISABLED + + if rescale_op: + output_zp, multiplier_data, shift_data, scale_shift_cnt = _extract_rescale_params(rescale_op) + + input_id = _resolve_input_id(input_tensor, tensor_to_layer) + + return AxonLayer( + input_ids=[input_id], + operation=axon_op, + input_dimensions=[input_dims], + filter_dimensions=filter_dims, + output_dimensions=output_dims, + stride_x=stride[1] if len(stride) > 1 else 1, + stride_y=stride[0] if len(stride) > 0 else 1, + dilation_x=dilation[1] if len(dilation) > 1 else 1, + dilation_y=dilation[0] if len(dilation) > 0 else 1, + input_zero_point=input_zp, # Nordic uses both b_prime AND input_zp 
+ output_zero_point=output_zp, + pad_top=pad[0] if len(pad) > 0 else 0, + pad_bottom=pad[1] if len(pad) > 1 else 0, + pad_left=pad[2] if len(pad) > 2 else 0, + pad_right=pad[3] if len(pad) > 3 else 0, + filter_data=filter_data, + bias_data=bias_data, + multiplier_data=multiplier_data, + shift_data=shift_data, + scale_shift_cnt=scale_shift_cnt, + activation=activation, + ) + + +def _convert_depthwise_conv2d( + conv_op: TosaOperator, + rescale_op: TosaOperator | None, + tensor_to_layer: dict[str, int], + layer_idx: int, + as_conv2d: bool = False, +) -> AxonLayer: + """Convert TOSA DEPTHWISE_CONV2D (+ optional RESCALE) to AXON layer. + + TOSA DEPTHWISE_CONV2D inputs: [input, weights, bias, input_zp, weight_zp] + Weights shape: [KH, KW, C_in, M] where C_out = C_in * M. + """ + input_tensor = conv_op.input_tensors[0] + weight_tensor = conv_op.input_tensors[1] + bias_tensor = conv_op.input_tensors[2] + input_zp_tensor = conv_op.input_tensors[3] + output_tensor = conv_op.output_tensors[0] + + in_shape = input_tensor.shape # [N, H, W, C_in] + out_shape = output_tensor.shape # [N, OH, OW, C_out] + weight_shape = weight_tensor.shape # [KH, KW, C_in, M] + + kh, kw = weight_shape[0], weight_shape[1] + c_in = weight_shape[2] + depth_mult = weight_shape[3] + out_channels = c_in * depth_mult + + attrs = conv_op.attributes + pad = attrs.get("pad", [0, 0, 0, 0]) + stride = attrs.get("stride", [1, 1]) + dilation = attrs.get("dilation", [1, 1]) + + input_zp = int(input_zp_tensor.data.flat[0]) if input_zp_tensor.data is not None else 0 + + # Weights transpose + if weight_tensor.data is not None: + weights = weight_tensor.data.astype(np.int8) + if as_conv2d: + # Convert DW format [KH, KW, C_in=1, M] to CONV2D format [O, I, KH, KW] + # [KH, KW, 1, M] → [M, KH, KW, 1] → [O, I, H, W] with I=1 + weights = weights.reshape(kh, kw, out_channels) # squeeze c_in=1 + weights = weights.transpose(2, 0, 1) # [O, KH, KW] + weights = weights.reshape(out_channels, 1, kh, kw) # [O, I=1, KH, KW] + # 
Then OHWI→OIHW: already in OIHW format
+ else:
+ # DW: TOSA [KH, KW, C_in, M] → AXON [C_out, 1, KH, KW]
+ weights = weights.reshape(kh, kw, out_channels)
+ weights = weights.transpose(2, 0, 1) # [C_out, KH, KW]
+ weights = weights.reshape(out_channels, 1, kh, kw)
+ filter_data = weights.tobytes()
+ else:
+ filter_data = b""
+
+ # Bias prime for DW conv (same as Conv2D). NOTE(review): this indexing assumes out_ch = m*c_in + c, but the filter reshape above flattens [C_in, M] C-order as out_ch = c*M + m — confirm channel ordering when depth_mult > 1.
+ if bias_tensor.raw_bytes:
+ bias_int32 = np.frombuffer(bias_tensor.raw_bytes, dtype=np.int32).copy()
+ if weight_tensor.data is not None and input_zp != 0:
+ weights_orig = weight_tensor.data.astype(np.int32)
+ for ch in range(min(out_channels, len(bias_int32))):
+ kernel_sum = int(np.sum(weights_orig[:, :, ch % c_in, ch // c_in]))
+ bias_int32[ch] += -kernel_sum * input_zp
+ bias_data = bias_int32.tobytes()
+ else:
+ bias_data = b""
+
+ # Rescale
+ output_zp = 0
+ multiplier_data = b""
+ shift_data = b""
+ scale_shift_cnt = 0
+ if rescale_op:
+ output_zp, multiplier_data, shift_data, scale_shift_cnt = _extract_rescale_params(rescale_op)
+
+ input_dims = AxonDimensions(
+ height=in_shape[1], width=in_shape[2],
+ channel_cnt=in_shape[3] if len(in_shape) > 3 else in_shape[-1],
+ byte_width=AxonByteWidth.INT8,
+ )
+ # For as_conv2d: filter_dims.c = input_channels (1), not output_channels
+ # Nordic Conv2D: filter c = input channels per filter
+ filter_dims = AxonDimensions(
+ height=kh, width=kw,
+ channel_cnt=c_in if as_conv2d else out_channels,
+ byte_width=AxonByteWidth.INT8,
+ )
+ output_dims = AxonDimensions(
+ height=out_shape[1], width=out_shape[2],
+ channel_cnt=out_shape[3] if len(out_shape) > 3 else out_shape[-1],
+ byte_width=AxonByteWidth.INT8,
+ )
+
+ return AxonLayer(
+ input_ids=[_resolve_input_id(input_tensor, tensor_to_layer)],
+ operation=AxonOp.CONV2D if as_conv2d else AxonOp.DEPTHWISE_CONV2D,
+ input_dimensions=[input_dims],
+ filter_dimensions=filter_dims,
+ output_dimensions=output_dims,
+ stride_x=stride[1] if len(stride) > 1 else 1,
+ 
stride_y=stride[0] if len(stride) > 0 else 1, + dilation_x=dilation[1] if len(dilation) > 1 else 1, + dilation_y=dilation[0] if len(dilation) > 0 else 1, + input_zero_point=input_zp, # Nordic uses both b_prime AND input_zp + output_zero_point=output_zp, + pad_top=pad[0] if len(pad) > 0 else 0, + pad_bottom=pad[1] if len(pad) > 1 else 0, + pad_left=pad[2] if len(pad) > 2 else 0, + pad_right=pad[3] if len(pad) > 3 else 0, + filter_data=filter_data, + bias_data=bias_data, + multiplier_data=multiplier_data, + shift_data=shift_data, + scale_shift_cnt=scale_shift_cnt, + ) + + +def _convert_elementwise( + op: TosaOperator, + rescale_op: TosaOperator | None, + tensor_to_layer: dict[str, int], + layer_idx: int, + axon_op: int, + graph: TosaGraph | None = None, + tensor_to_zp: dict[str, int] | None = None, + tensor_rescale_info: dict | None = None, +) -> AxonLayer: + """Convert TOSA ADD or MUL (+ optional RESCALE) to AXON ADD2 or MULTIPLY. + + TOSA ADD inputs: [input1, input2] + TOSA MUL inputs: [input1, input2, shift] + Both support broadcasting on H/W dimensions. + + For MULTIPLY with a constant second input (scalar multiply), the constant + is stored as filter_data. Nordic's AXON MUL reads the multiplier constant + from the filter data, not from a layer input. 
+ """ + input1 = op.input_tensors[0] + input2 = op.input_tensors[1] + output_tensor = op.output_tensors[0] + + def _dims_from_shape(shape: list[int]) -> AxonDimensions: + if len(shape) == 4: + return AxonDimensions(height=shape[1], width=shape[2], channel_cnt=shape[3]) + elif len(shape) == 3: + return AxonDimensions(height=shape[0], width=shape[1], channel_cnt=shape[2]) + elif len(shape) == 2: + # 2D tensor [batch, features] → h=1, w=features, c=1 (matches FC convention) + return AxonDimensions(height=1, width=shape[1], channel_cnt=1) + else: + return AxonDimensions(height=1, width=shape[0] if shape else 1, channel_cnt=1) + + input1_dims = _dims_from_shape(input1.shape) + input2_dims = _dims_from_shape(input2.shape) + output_dims = _dims_from_shape(output_tensor.shape) + + output_zp = 0 + multiplier_data = b"" + shift_data = b"" + scale_shift_cnt = 0 + bias_data = b"" + if rescale_op: + output_zp, multiplier_data, shift_data, scale_shift_cnt = _extract_rescale_params(rescale_op) + + input1_id = _resolve_input_id(input1, tensor_to_layer) + input2_id = _resolve_input_id(input2, tensor_to_layer) + + # For MUL with a constant second input, store the constant as filter_data. + # Nordic's AXON MUL expects the multiplier constant in the filter data. + # The filter_dimensions should be 0x0x0 (Nordic convention for MUL). + # + # The TOSA graph may have RESCALE(const) → MUL, where the constant comes + # through a RESCALE op. In that case, input2.data is None but we can + # compute the rescaled value from the RESCALE's input constant. + filter_data = b"" + filter_dims = AxonDimensions() # default 1x1x1 + if axon_op == AxonOp.MULTIPLY: + mul_const = None + if input2.data is not None: + mul_const = input2.data.astype(np.int8) + elif input2_id == -1: + # Input2 might come from a RESCALE of a constant (not tracked as a layer). + # Search the graph for a RESCALE that outputs this tensor with a constant input. 
+ for prev_op in graph.get_non_const_operators(): + if (prev_op.op_name == "RESCALE" and prev_op.output_tensors + and prev_op.output_tensors[0].name == input2.name): + # Found the RESCALE that produces input2 + const_input = prev_op.input_tensors[0] + if const_input.data is not None: + # Compute the rescaled constant value + in_val = int(const_input.data.flat[0]) + r_mult = prev_op.input_tensors[1] + r_shift = prev_op.input_tensors[2] + r_in_zp = prev_op.input_tensors[3] + r_out_zp = prev_op.input_tensors[4] + t_mult = int(np.frombuffer(r_mult.raw_bytes, dtype=np.int32)[0]) if r_mult.raw_bytes else 0 + t_shift = int(r_shift.data.flat[0]) if r_shift.data is not None else 0 + t_in_zp = int(r_in_zp.data.flat[0]) if r_in_zp.data is not None else 0 + t_out_zp = int(r_out_zp.data.flat[0]) if r_out_zp.data is not None else 0 + if t_shift > 0 and t_mult != 0: + rescaled = round((in_val - t_in_zp) * t_mult / (2 ** t_shift)) + t_out_zp + else: + rescaled = in_val + rescaled = max(-128, min(127, rescaled)) + mul_const = np.array([rescaled], dtype=np.int8) + mul_const_zp = t_in_zp # constant's quantization zero point + logger.debug(f" MUL: computed constant from RESCALE: " + f"input={in_val} → rescaled={rescaled} (const_zp={t_in_zp})") + break + + if mul_const is not None: + filter_data = mul_const.tobytes() + # For 1D input (h=1), Nordic uses flt=0x0x0. For spatial input, use 1x1x1 (broadcast). + if input1_dims.height <= 1: + filter_dims = AxonDimensions(height=0, width=0, channel_cnt=0) + else: + filter_dims = AxonDimensions(height=1, width=1, channel_cnt=1) + # Nordic MUL with constant: input_id_cnt=1 (only activation input). + # The constant is stored as filter_data, NOT as a second layer input. 
+ input2_id = None # will be excluded from input_ids. NOTE(review): the 'mul_const_zp' in dir() probe on the next line is a fragile local-scope check — prefer initializing mul_const_zp = 0 before the branch.
+ input_zp_for_mul = mul_const_zp if 'mul_const_zp' in dir() else 0
+ logger.debug(f" MUL constant: {mul_const.flat[:4]} zp={input_zp_for_mul} stored as filter_data")
+ else:
+ logger.warning(f" MUL: no constant found for input2 — "
+ f"AXON MUL requires a constant multiplier")
+ elif axon_op == AxonOp.ADD2:
+ # ADD: filter_dims = input_dims (Nordic convention)
+ filter_dims = input1_dims
+
+ # Nordic ADD uses TWO multipliers (one per input) + bias:
+ # acc = in1 * mult_a + in2 * mult_b + bias
+ # out = (acc >> shift) + out_zp
+ #
+ # mult_a = round(s_in1/s_out * 2^shift)
+ # mult_b = round(s_in2/s_out * 2^shift)
+ # bias = round(-(s_in1*zp_in1 + s_in2*zp_in2) / s_out * 2^shift)
+ #
+ # The scale ratios come from the skipped standalone RESCALEs.
+ # tensor_rescale_info[name] = (scale_ratio, input_zp)
+ _zp_map = tensor_to_zp or {}
+ _ri_map = tensor_rescale_info or {}
+
+ zp1 = _zp_map.get(input1.name, 0)
+ zp2 = _zp_map.get(input2.name, 0)
+ ri1 = _ri_map.get(input1.name) # (scale_ratio, in_zp) or None
+ ri2 = _ri_map.get(input2.name)
+
+ if ri1 and ri2 and multiplier_data:
+ # Both inputs have skipped rescale info — compute proper two multipliers.
+ #
+ # The skipped RESCALEs have scale_ratio = s_in / s_add_internal.
+ # The ADD's fused RESCALE has scale = s_add_internal / s_out.
+ # Net scale per input: s_in / s_out = scale_ratio * add_output_scale.
+ add_mult = np.frombuffer(multiplier_data, dtype=np.int32)
+ add_shift = np.frombuffer(shift_data, dtype=np.int8)
+ add_output_scale = float(add_mult[0]) / (2.0 ** int(add_shift[0]))
+
+ net_s1 = ri1[0] * add_output_scale # s_in1 / s_out
+ net_s2 = ri2[0] * add_output_scale # s_in2 / s_out
+
+ # Use Nordic's formula with bit_limit=15
+ best_shift = 8
+ for s in range(8, 31):
+ m1 = abs(int(round(net_s1 * (2 ** s))))
+ m2 = abs(int(round(net_s2 * (2 ** s))))
+ zp_check = abs(int(round(output_zp * (2 ** s))))
+ if m1 < (1 << 15) and m2 < (1 << 15) and zp_check < (1 << 31):
+ best_shift = s
+ else:
+ break
+
+ mult_a = abs(int(round(net_s1 * (2 ** best_shift))))
+ mult_b = abs(int(round(net_s2 * (2 ** best_shift))))
+
+ # Bias: use exact float scales (Nordic reference implementation formula):
+ # bias = round(-(s1*zp1 + s2*zp2) / s_out * 2^shift)
+ # Using net_s1/net_s2 directly (which ARE s_in/s_out already):
+ add_bias = int(round(-(net_s1 * zp1 + net_s2 * zp2) * (2 ** best_shift)))
+
+ multiplier_data = np.array([mult_a, mult_b], dtype=np.int32).tobytes()
+ shift_data = np.array([best_shift], dtype=np.int8).tobytes()
+ scale_shift_cnt = 1 # cnt=1 in Nordic (the shift count, not mult count)
+ bias_data = np.array([add_bias], dtype=np.int32).tobytes()
+
+ logger.info(f" ADD: two multipliers: mult_a={mult_a} mult_b={mult_b} "
+ f"shift={best_shift} bias={add_bias} (zp1={zp1} zp2={zp2})")
+ else:
+ # No rescale info — use simple bias and scale=1.0
+ add_bias = -(zp1 + zp2)
+ bias_data = np.array([add_bias], dtype=np.int32).tobytes()
+
+ # Override tiny TOSA scale with 1.0
+ if multiplier_data:
+ mult_arr = np.frombuffer(multiplier_data, dtype=np.int32)
+ shift_arr = np.frombuffer(shift_data, dtype=np.int8)
+ scale = float(mult_arr[0]) / (2.0 ** int(shift_arr[0]))
+ if scale < 0.01:
+ multiplier_data = np.array([16384], dtype=np.int32).tobytes()
+ shift_data = np.array([14], dtype=np.int8).tobytes()
+
+ logger.debug(f" ADD: simple bias={add_bias} (zp1={zp1} zp2={zp2})")
+
+ # For
MUL with constant: input_zero_point = constant's quantization zero point + # For ADD: input_zero_point = 0 (Nordic convention) + input_zp = input_zp_for_mul if (axon_op == AxonOp.MULTIPLY and filter_data) else 0 + + # Build input_ids and dims: MUL constant is in filter_data, not a layer input + if input2_id is not None: + input_ids = [input1_id, input2_id] + input_dims_list = [input1_dims, input2_dims] + else: + input_ids = [input1_id] + input_dims_list = [input1_dims] + + return AxonLayer( + input_ids=input_ids, + operation=axon_op, + input_dimensions=input_dims_list, + filter_dimensions=filter_dims, + output_dimensions=output_dims, + stride_x=0, # Nordic uses 0 for elementwise ops + stride_y=0, + dilation_x=0, + dilation_y=0, + input_zero_point=input_zp, + output_zero_point=output_zp, + filter_data=filter_data, + bias_data=bias_data, + multiplier_data=multiplier_data, + shift_data=shift_data, + scale_shift_cnt=scale_shift_cnt, + ) + + +def _convert_pool2d( + op: TosaOperator, + rescale_op: TosaOperator | None, + tensor_to_layer: dict[str, int], + layer_idx: int, + axon_op: int, + input_zp_from_graph: int = 0, +) -> AxonLayer: + """Convert TOSA AVG_POOL2D or MAX_POOL2D to AXON pooling layer. + + AVG_POOL2D inputs: [input, input_zp, output_zp] + MAX_POOL2D inputs: [input] + Both have kernel, pad, stride attributes. 
+ """ + input_tensor = op.input_tensors[0] + output_tensor = op.output_tensors[0] + in_shape = input_tensor.shape # [N, H, W, C] + out_shape = output_tensor.shape + + attrs = op.attributes + kernel = attrs.get("kernel", [1, 1]) # [KH, KW] + pad = attrs.get("pad", [0, 0, 0, 0]) + stride = attrs.get("stride", [1, 1]) + + # Zero points (AVG_POOL2D has input_zp/output_zp as tensor inputs) + # MAX_POOL2D has no zp tensors — use the propagated input_zp from the graph + input_zp = input_zp_from_graph if axon_op == AxonOp.MAX_POOLING else 0 + output_zp = 0 + if axon_op == AxonOp.AVERAGE_POOLING and len(op.input_tensors) >= 3: + zp_in = op.input_tensors[1] + zp_out = op.input_tensors[2] + if zp_in.data is not None: + input_zp = int(zp_in.data.flat[0]) + if zp_out.data is not None: + output_zp = int(zp_out.data.flat[0]) + + multiplier_data = b"" + shift_data = b"" + scale_shift_cnt = 0 + if rescale_op: + output_zp, multiplier_data, shift_data, scale_shift_cnt = _extract_rescale_params(rescale_op) + + # Pooling layers need valid rescale params even without explicit RESCALE. + # Nordic's approach: compute b_prime (bias) and mult/shift for the pool. + kh = kernel[0] if len(kernel) > 0 else 1 + kw = kernel[1] if len(kernel) > 1 else 1 + kernel_area = kh * kw + bias_data = b"" + + if scale_shift_cnt == 0: + if axon_op == AxonOp.MAX_POOLING: + # MAX_POOL: identity rescale. Nordic uses mult=1, shift=0 but + # shift=0 may be out of AXON's valid range [1,32]. + # Use mult=2, shift=1 as safe identity: (val * 2 + 1) >> 1 = val + multiplier_data = np.array([2], dtype=np.int32).tobytes() + shift_data = np.array([1], dtype=np.int8).tobytes() + scale_shift_cnt = 1 + # MAX_POOL preserves quantization: Nordic sets output_zp=0 + # and input_zp=actual (the preceding layer's output_zp) + output_zp = 0 + + elif axon_op == AxonOp.AVERAGE_POOLING and kernel_area > 1: + # AVG_POOL: AXON sums values, rescale divides by kernel area. 
+ # Following Nordic's CalculateMultiplierandScaleshift: + # scale = 1/kernel_area (when input/output scales match) + # b_prime = -input_zp * kernel_area (zero-point correction for sum) + # mult/shift chosen so mult * max_accumulator < 2^31 + # Also: output_zp * 2^shift must fit in 2^bit_limit + scale = 1.0 / kernel_area + bit_limit = 25 # Nordic uses 25 for same-scale pool + + # Search for optimal shift (Nordic's optimized_ip_scaling_shift) + best_shift = 8 + for s in range(8, 31): + m = abs(round(scale * (1 << s))) + zp_scaled = abs(round(output_zp * (1 << s))) + if m < (1 << bit_limit) and zp_scaled < (1 << bit_limit): + best_shift = s + else: + break + mult_val = abs(round(scale * (1 << best_shift))) + shift_val = best_shift + + # b_prime: zero-point correction (Nordic: -input_zp * kernel_area) + b_prime = round(-input_zp * kernel_area) + bias_data = np.array([b_prime], dtype=np.int32).tobytes() + + multiplier_data = np.array([mult_val], dtype=np.int32).tobytes() + shift_data = np.array([shift_val], dtype=np.int8).tobytes() + scale_shift_cnt = 1 + + logger.debug(f" AVG_POOL: area={kernel_area} b_prime={b_prime} " + f"mult={mult_val} shift={shift_val}") + else: + # Fallback: identity rescale + multiplier_data = np.array([2], dtype=np.int32).tobytes() + shift_data = np.array([1], dtype=np.int8).tobytes() + scale_shift_cnt = 1 + + in_channels = in_shape[3] if len(in_shape) > 3 else in_shape[-1] + + input_dims = AxonDimensions( + height=in_shape[1] if len(in_shape) > 1 else 1, + width=in_shape[2] if len(in_shape) > 2 else 1, + channel_cnt=in_channels, + byte_width=AxonByteWidth.INT8, + ) + # Filter dimensions encode the kernel size for pooling + # Nordic uses channel_cnt=0 for MAX_POOL, channel_cnt for AVG_POOL/MEAN + filter_ch = 0 if axon_op == AxonOp.MAX_POOLING else in_channels + filter_dims = AxonDimensions( + height=kernel[0] if len(kernel) > 0 else 1, + width=kernel[1] if len(kernel) > 1 else 1, + channel_cnt=filter_ch, + byte_width=AxonByteWidth.INT8, + ) + 
output_dims = AxonDimensions( + height=out_shape[1] if len(out_shape) > 1 else 1, + width=out_shape[2] if len(out_shape) > 2 else 1, + channel_cnt=out_shape[3] if len(out_shape) > 3 else out_shape[-1], + byte_width=AxonByteWidth.INT8, + ) + + return AxonLayer( + input_ids=[_resolve_input_id(input_tensor, tensor_to_layer)], + operation=axon_op, + input_dimensions=[input_dims], + filter_dimensions=filter_dims, + output_dimensions=output_dims, + stride_x=stride[1] if len(stride) > 1 else 1, + stride_y=stride[0] if len(stride) > 0 else 1, + dilation_x=0, # Nordic uses 0 for pool dilation (not 1) + dilation_y=0, + input_zero_point=input_zp, + output_zero_point=output_zp, + pad_top=pad[0] if len(pad) > 0 else 0, + pad_bottom=pad[1] if len(pad) > 1 else 0, + pad_left=pad[2] if len(pad) > 2 else 0, + pad_right=pad[3] if len(pad) > 3 else 0, + bias_data=bias_data, + multiplier_data=multiplier_data, + shift_data=shift_data, + scale_shift_cnt=scale_shift_cnt, + ) + + +def _convert_reduce_sum( + op: TosaOperator, + rescale_op: TosaOperator | None, + tensor_to_layer: dict[str, int], + layer_idx: int, +) -> AxonLayer: + """Convert TOSA REDUCE_SUM to AXON MEAN (global average pooling). + + TOSA REDUCE_SUM inputs: [input] + Attribute: axis (which dimension to reduce). + AXON MEAN uses concatenate_axis field for the reduction axis. 
+ """ + input_tensor = op.input_tensors[0] + output_tensor = op.output_tensors[0] + in_shape = input_tensor.shape + out_shape = output_tensor.shape + + attrs = op.attributes + tosa_axis = attrs.get("axis", 1) + + # Map TOSA NHWC axis to AXON axis enum: 0=CHANNEL, 1=HEIGHT, 2=WIDTH + # TOSA axis: 0=N, 1=H, 2=W, 3=C + axon_axis_map = {1: 1, 2: 2, 3: 0} # H→HEIGHT, W→WIDTH, C→CHANNEL + axon_axis = axon_axis_map.get(tosa_axis, 1) + + output_zp = 0 + multiplier_data = b"" + shift_data = b"" + scale_shift_cnt = 0 + if rescale_op: + output_zp, multiplier_data, shift_data, scale_shift_cnt = _extract_rescale_params(rescale_op) + + in_channels = in_shape[3] if len(in_shape) > 3 else in_shape[-1] + input_dims = AxonDimensions( + height=in_shape[1] if len(in_shape) > 1 else 1, + width=in_shape[2] if len(in_shape) > 2 else 1, + channel_cnt=in_channels, + byte_width=AxonByteWidth.INT8, + ) + output_dims = AxonDimensions( + height=out_shape[1] if len(out_shape) > 1 else 1, + width=out_shape[2] if len(out_shape) > 2 else 1, + channel_cnt=out_shape[3] if len(out_shape) > 3 else out_shape[-1], + byte_width=AxonByteWidth.INT8, + ) + + return AxonLayer( + input_ids=[_resolve_input_id(input_tensor, tensor_to_layer)], + operation=AxonOp.MEAN, + concatenate_axis=axon_axis, + input_dimensions=[input_dims], + output_dimensions=output_dims, + output_zero_point=output_zp, + multiplier_data=multiplier_data, + shift_data=shift_data, + scale_shift_cnt=scale_shift_cnt, + ) + + +def _convert_concat( + op: TosaOperator, + tensor_to_layer: dict[str, int], + layer_idx: int, +) -> AxonLayer: + """Convert TOSA CONCAT to AXON CONCATENATE layer. + + TOSA CONCAT inputs: [tensor1, tensor2, ...] (2+ tensors to concatenate) + Attribute: axis (NHWC dimension to concatenate along). + AXON supports up to 4 inputs via input_ids. 
+ """ + output_tensor = op.output_tensors[0] + out_shape = output_tensor.shape + + attrs = op.attributes + tosa_axis = attrs.get("axis", 3) # Default: channel axis (last dim in NHWC) + + # Map TOSA NHWC axis to AXON axis enum: 0=CHANNEL, 1=HEIGHT, 2=WIDTH + # TOSA axis: 0=N, 1=H, 2=W, 3=C + axon_axis_map = {1: 1, 2: 2, 3: 0} + axon_axis = axon_axis_map.get(tosa_axis, 0) + + # Collect input IDs and dimensions + input_ids = [] + input_dims = [] + for inp in op.input_tensors: + input_ids.append(_resolve_input_id(inp, tensor_to_layer)) + shape = inp.shape + if len(shape) == 4: + input_dims.append(AxonDimensions(height=shape[1], width=shape[2], channel_cnt=shape[3])) + elif len(shape) == 3: + input_dims.append(AxonDimensions(height=shape[0], width=shape[1], channel_cnt=shape[2])) + else: + input_dims.append(AxonDimensions(height=1, width=shape[0] if shape else 1, channel_cnt=1)) + + # Output dimensions + if len(out_shape) == 4: + output_dims = AxonDimensions(height=out_shape[1], width=out_shape[2], channel_cnt=out_shape[3]) + elif len(out_shape) == 3: + output_dims = AxonDimensions(height=out_shape[0], width=out_shape[1], channel_cnt=out_shape[2]) + else: + output_dims = AxonDimensions(height=1, width=out_shape[0] if out_shape else 1, channel_cnt=1) + + return AxonLayer( + input_ids=input_ids, + operation=AxonOp.CONCATENATE, + concatenate_axis=axon_axis, + input_dimensions=input_dims, + output_dimensions=output_dims, + ) + + +def _convert_slice( + op: TosaOperator, + tensor_to_layer: dict[str, int], + layer_idx: int, +) -> AxonLayer: + """Convert TOSA SLICE to AXON STRIDED_SLICE layer. + + TOSA SLICE inputs: [input, start_tensor, size_tensor] + start and size are constant tensors with indices per dimension. + AXON STRIDED_SLICE uses begin/end/strides arrays in CHW order. + TOSA SLICE has no strides (all = 1), so we set strides to 1. 
+ + The strided slice parameters are packed as 9 × int32 in the filter_data + field: [begin_c, begin_h, begin_w, end_c, end_h, end_w, stride_c, stride_h, stride_w]. + """ + input_tensor = op.input_tensors[0] + output_tensor = op.output_tensors[0] + in_shape = input_tensor.shape # [N, H, W, C] + out_shape = output_tensor.shape + + # Extract start and size from constant input tensors + start_tensor = op.input_tensors[1] if len(op.input_tensors) > 1 else None + size_tensor = op.input_tensors[2] if len(op.input_tensors) > 2 else None + + if start_tensor is not None and start_tensor.data is not None: + start = start_tensor.data.flatten().tolist() + else: + start = [0] * len(in_shape) + + if size_tensor is not None and size_tensor.data is not None: + size = size_tensor.data.flatten().tolist() + else: + size = list(out_shape) + + # Convert NHWC start/size to AXON CHW begin/end/strides + # TOSA: [N, H, W, C], AXON: [C, H, W] + if len(start) == 4: + begin = [int(start[3]), int(start[1]), int(start[2])] # C, H, W + end = [int(start[3] + size[3]), int(start[1] + size[1]), int(start[2] + size[2])] + elif len(start) == 3: + begin = [int(start[2]), int(start[0]), int(start[1])] + end = [int(start[2] + size[2]), int(start[0] + size[0]), int(start[1] + size[1])] + else: + begin = [0, 0, 0] + end = [int(size[0]) if size else 1, 1, 1] + + strides = [1, 1, 1] # TOSA SLICE has no strides + + # Pack as 9 × int32: begin[3] + end[3] + strides[3] + params = np.array(begin + end + strides, dtype=np.int32) + filter_data = params.tobytes() + + # Dimensions + if len(in_shape) == 4: + input_dims = AxonDimensions(height=in_shape[1], width=in_shape[2], channel_cnt=in_shape[3]) + output_dims = AxonDimensions(height=out_shape[1], width=out_shape[2], channel_cnt=out_shape[3]) + else: + input_dims = AxonDimensions(height=1, width=in_shape[0] if in_shape else 1, channel_cnt=1) + output_dims = AxonDimensions(height=1, width=out_shape[0] if out_shape else 1, channel_cnt=1) + + return AxonLayer( + 
input_ids=[_resolve_input_id(input_tensor, tensor_to_layer)], + operation=AxonOp.STRIDED_SLICE, + input_dimensions=[input_dims], + output_dimensions=output_dims, + filter_data=filter_data, + ) + + +def _convert_pad( + op: TosaOperator, + tensor_to_layer: dict[str, int], + layer_idx: int, +) -> AxonLayer: + """Convert TOSA PAD to AXON CHANNEL_PADDING layer. + + TOSA PAD inputs: [input, padding_tensor, pad_const_tensor] + padding_tensor shape: [N_dims, 2] with [before, after] per dimension. + AXON CHANNEL_PADDING uses the standard padding fields and only supports + channel-dimension padding (top/bottom must be 0). + """ + input_tensor = op.input_tensors[0] + output_tensor = op.output_tensors[0] + in_shape = input_tensor.shape + out_shape = output_tensor.shape + + # Extract padding values from constant tensor + pad_tensor = op.input_tensors[1] if len(op.input_tensors) > 1 else None + padding = [[0, 0]] * len(in_shape) + if pad_tensor is not None and pad_tensor.data is not None: + pad_data = pad_tensor.data.reshape(-1, 2).tolist() + for i, (before, after) in enumerate(pad_data): + if i < len(padding): + padding[i] = [int(before), int(after)] + + # NHWC: padding[0]=N, [1]=H, [2]=W, [3]=C + if len(padding) >= 4: + pad_top, pad_bottom = padding[1] + pad_left, pad_right = padding[2] + # Channel padding stored as pad_top/pad_bottom of the channel axis + # AXON expects front/back channel padding + elif len(padding) >= 3: + pad_top, pad_bottom = padding[0] + pad_left, pad_right = padding[1] + else: + pad_top = pad_bottom = pad_left = pad_right = 0 + + if len(in_shape) == 4: + input_dims = AxonDimensions(height=in_shape[1], width=in_shape[2], channel_cnt=in_shape[3]) + output_dims = AxonDimensions(height=out_shape[1], width=out_shape[2], channel_cnt=out_shape[3]) + else: + input_dims = AxonDimensions(height=1, width=in_shape[0] if in_shape else 1, channel_cnt=1) + output_dims = AxonDimensions(height=1, width=out_shape[0] if out_shape else 1, channel_cnt=1) + + return 
AxonLayer( + input_ids=[_resolve_input_id(input_tensor, tensor_to_layer)], + operation=AxonOp.CHANNEL_PADDING, + input_dimensions=[input_dims], + output_dimensions=output_dims, + pad_top=pad_top, + pad_bottom=pad_bottom, + pad_left=pad_left, + pad_right=pad_right, + ) + + +def _convert_persistent_var( + op: TosaOperator, + tensor_to_layer: dict[str, int], + layer_idx: int, +) -> AxonLayer: + """Convert TOSA VARIABLE/VARIABLE_READ/VARIABLE_WRITE to AXON PERSISTENT_VAR. + + Used in streaming/stateful models to persist intermediate results between + inference calls. The persistent buffer is allocated separately from the + interlayer buffer. + + TOSA VARIABLE declares the var, VARIABLE_READ reads, VARIABLE_WRITE writes. + """ + output_tensor = op.output_tensors[0] if op.output_tensors else op.input_tensors[0] + shape = output_tensor.shape + + if len(shape) == 4: + dims = AxonDimensions(height=shape[1], width=shape[2], channel_cnt=shape[3]) + elif len(shape) == 3: + dims = AxonDimensions(height=shape[0], width=shape[1], channel_cnt=shape[2]) + else: + total = 1 + for s in shape: + total *= s + dims = AxonDimensions(height=1, width=total, channel_cnt=1) + + input_id = -1 + if op.input_tensors: + input_id = _resolve_input_id(op.input_tensors[0], tensor_to_layer) + + return AxonLayer( + input_ids=[input_id], + operation=AxonOp.PERSISTENT_VAR, + input_dimensions=[dims], + output_dimensions=dims, + ) + diff --git a/backends/nordic/axon_types.py b/backends/nordic/axon_types.py new file mode 100644 index 00000000000..06cb2398d59 --- /dev/null +++ b/backends/nordic/axon_types.py @@ -0,0 +1,110 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""AXON NPU type definitions and enums. + +These mirror the C types in ``nrf_axon_nn_compiler_types.h`` from +Nordic's sdk-edge-ai. 
Used by both the compiler bridge +(``axon_compiler.py``) and the binary builder (``axon_binary.py``). +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import IntEnum + + +class AxonOp(IntEnum): + """AXON operation codes (from nrf_axon_nn_op_e).""" + FULLY_CONNECTED = 0 + CONV2D = 1 + DEPTHWISE_CONV2D = 2 + POINTWISE_CONV2D = 3 + AVERAGE_POOLING = 4 + MAX_POOLING = 5 + ADD2 = 6 + CHANNEL_PADDING = 7 + PERSISTENT_VAR = 8 + CONCATENATE = 9 + STRIDED_SLICE = 10 + MULTIPLY = 11 + MEAN = 12 + # Op extensions — CPU+AXON hybrid ops + SOFTMAX = 100 + SIGMOID = 101 + TANH = 102 + + +class AxonActivation(IntEnum): + """AXON activation function codes (from nrf_axon_nn_activation_function_e).""" + DISABLED = 0 + RELU = 1 + PREPARE_SOFTMAX = 2 # Preceding layer outputs q11.12 INT32 for softmax + LEAKY_RELU = 3 + + +class AxonByteWidth(IntEnum): + """AXON byte width codes (from nrf_axon_nn_byte_width_e).""" + INT8 = 1 + INT16 = 2 + INT32 = 4 + + +@dataclass +class AxonDimensions: + """Mirrors nrf_axon_nn_compiler_model_layer_dimensions_s.""" + height: int = 1 + width: int = 1 + channel_cnt: int = 1 + byte_width: int = AxonByteWidth.INT8 + + +@dataclass +class ActivationQuantInfo: + """Quantization parameters for an activation op (sigmoid/tanh/softmax). + + Extracted from the PyTorch FX graph BEFORE TOSA lowering, because + TOSA bakes scales into TABLE ops and the original values are lost. + """ + op_type: str # "sigmoid", "tanh", or "softmax" + input_scale: float + input_zp: int + output_scale: float + output_zp: int + + +@dataclass +class AxonLayer: + """An AXON layer descriptor, mirroring nrf_axon_nn_model_layer_desc_s. + + Each layer represents one operation in the AXON command buffer. + The ``filter_data``, ``bias_data``, ``multiplier_data``, and + ``shift_data`` fields contain raw bytes that are packed into the + constants section of the AXON intermediate binary. 
+ """ + # Input layer indices (-1 = graph input, 0+ = preceding layer output) + input_ids: list[int] = field(default_factory=lambda: [-1]) + operation: int = AxonOp.FULLY_CONNECTED + concatenate_axis: int = 0 + input_dimensions: list[AxonDimensions] = field(default_factory=list) + filter_dimensions: AxonDimensions = field(default_factory=AxonDimensions) + output_dimensions: AxonDimensions = field(default_factory=AxonDimensions) + stride_x: int = 1 + stride_y: int = 1 + dilation_x: int = 1 + dilation_y: int = 1 + input_zero_point: int = 0 + output_zero_point: int = 0 + activation: int = AxonActivation.DISABLED + pad_left: int = 0 + pad_right: int = 0 + pad_top: int = 0 + pad_bottom: int = 0 + # Constant data (raw bytes, packed into the binary's constants section) + filter_data: bytes = b"" # INT8 weight data + bias_data: bytes = b"" # INT32 bias_prime values + multiplier_data: bytes = b"" # INT32 output rescale multipliers + shift_data: bytes = b"" # INT8 rescale shifts + scale_shift_cnt: int = 0 # Number of shifts (1 = shared, N = per-channel) + cpu_op_attributes: bytes = b"" # Op extension parameters (softmax/sigmoid/tanh) diff --git a/backends/nordic/operator_support/__init__.py b/backends/nordic/operator_support/__init__.py new file mode 100644 index 00000000000..bdc4d4b1b81 --- /dev/null +++ b/backends/nordic/operator_support/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+"""AXON NPU operator support checks.""" + +from .axon_support import ( + AXON_CPU_ONLY_OPS, + AXON_FUSED_ACTIVATIONS, + AXON_OP_EXTENSIONS, + AXON_SUPPORTED_OPS, + check_conv2d, + check_fully_connected, + check_input_count, + check_pooling, + check_tensor_dimensions, +) + +__all__ = [ + "AXON_SUPPORTED_OPS", + "AXON_FUSED_ACTIVATIONS", + "AXON_OP_EXTENSIONS", + "AXON_CPU_ONLY_OPS", + "check_fully_connected", + "check_conv2d", + "check_pooling", + "check_tensor_dimensions", + "check_input_count", +] diff --git a/backends/nordic/operator_support/axon_constraints.py b/backends/nordic/operator_support/axon_constraints.py new file mode 100644 index 00000000000..f0722087c42 --- /dev/null +++ b/backends/nordic/operator_support/axon_constraints.py @@ -0,0 +1,171 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""AXON NPU hardware constraint checks for the ExecuTorch partitioner. + +These ``OperatorSupportBase`` subclasses are passed as ``additional_checks`` +to the ``TOSAPartitioner``. They reject nodes that TOSA accepts but AXON +cannot execute due to hardware limits (tensor size, input count, filter +dimensions, etc.). + +Without these checks, over-sized operations would be delegated to AXON +and fail at compile time instead of falling back to CPU gracefully. 
+""" +from __future__ import annotations + +import typing + +import torch +from torch.fx import Node +from torch.fx.passes.operator_support import OperatorSupportBase + +from executorch.backends.nordic.axon.compile_spec import ( + AXON_MAX_CONV2D_FILTER, + AXON_MAX_CONV_STRIDE, + AXON_MAX_FC_INPUT, + AXON_MAX_FC_OUTPUT, + AXON_MAX_INPUTS_PER_NODE, + AXON_MAX_POOL_FILTER, + AXON_MAX_TENSOR_DIM, +) + + +def _get_tensor_shape(node: Node) -> list[int] | None: + """Extract the output tensor shape from a node's metadata.""" + val = node.meta.get("val") + if val is None: + return None + if isinstance(val, torch.Tensor): + return list(val.shape) + if hasattr(val, "shape"): + return list(val.shape) + return None + + +class AxonTensorDimensionCheck(OperatorSupportBase): + """Reject nodes whose output tensors exceed AXON's max dimensions. + + AXON supports max 1024 for height, width, and channels. + """ + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: Node + ) -> bool: + if node.op != "call_function": + return True + shape = _get_tensor_shape(node) + if shape is None: + return True # Can't determine shape; let TOSA checks handle it + for dim in shape: + if dim > AXON_MAX_TENSOR_DIM: + return False + return True + + +class AxonInputCountCheck(OperatorSupportBase): + """Reject nodes with more than 2 activation tensor inputs. + + AXON allows a maximum of 2 inputs per node. This counts only + activation (non-constant) tensor inputs — weight tensors and + scalar parameters don't count toward this limit. + """ + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: Node + ) -> bool: + if node.op != "call_function": + return True + # Count Node args that are tensor-producing (activation inputs). + # Skip scalar args, None args, and list/tuple args. 
+ tensor_inputs = 0 + for arg in node.args: + if isinstance(arg, Node) and arg.op == "call_function": + tensor_inputs += 1 + # Only reject if clearly over the limit + if tensor_inputs > AXON_MAX_INPUTS_PER_NODE: + return False + return True + + +class AxonConvConstraintCheck(OperatorSupportBase): + """Reject convolution nodes that exceed AXON's filter/stride limits. + + Conv2D: max 16x16 filter, max stride 31. + Only rejects when we can definitively determine the constraint is + violated. Returns True (allow) when metadata is unavailable. + """ + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: Node + ) -> bool: + if node.op != "call_function": + return True + + target_name = str(node.target) + + # Check convolution constraints + if any(op in target_name for op in ("conv2d", "convolution")): + # Weight tensor is typically args[1] + if len(node.args) >= 2 and isinstance(node.args[1], Node): + weight_shape = _get_tensor_shape(node.args[1]) + if weight_shape and len(weight_shape) == 4: + kH, kW = weight_shape[2], weight_shape[3] + if kH > AXON_MAX_CONV2D_FILTER or kW > AXON_MAX_CONV2D_FILTER: + return False + + # Check stride if available + if len(node.args) >= 4: + stride = node.args[3] + if isinstance(stride, (list, tuple)) and len(stride) >= 2: + if stride[0] > AXON_MAX_CONV_STRIDE or stride[1] > AXON_MAX_CONV_STRIDE: + return False + + return True + + +class AxonFCConstraintCheck(OperatorSupportBase): + """Reject fully connected nodes that exceed AXON's size limits. + + FC max input: 2048 elements, max output: 2048 elements. + Only rejects when we can definitively determine the constraint is + violated. Returns True (allow) when metadata is unavailable. 
+ """ + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: Node + ) -> bool: + if node.op != "call_function": + return True + + target_name = str(node.target) + if "linear" not in target_name and "addmm" not in target_name: + return True + + # Check input shape (last dim is feature size) + if node.args and isinstance(node.args[0], Node): + input_shape = _get_tensor_shape(node.args[0]) + if input_shape and input_shape[-1] > AXON_MAX_FC_INPUT: + return False + + # Check output shape + output_shape = _get_tensor_shape(node) + if output_shape and output_shape[-1] > AXON_MAX_FC_OUTPUT: + return False + + return True + + +def get_axon_constraint_checks() -> list[OperatorSupportBase]: + """Return all AXON hardware constraint checks. + + Pass these as ``additional_checks`` to ``AxonPartitioner`` to ensure + that only AXON-compatible operations are delegated. + """ + return [ + AxonTensorDimensionCheck(), + AxonInputCountCheck(), + AxonConvConstraintCheck(), + AxonFCConstraintCheck(), + ] diff --git a/backends/nordic/operator_support/axon_support.py b/backends/nordic/operator_support/axon_support.py new file mode 100644 index 00000000000..0dce5fc9317 --- /dev/null +++ b/backends/nordic/operator_support/axon_support.py @@ -0,0 +1,148 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""AXON NPU operator support checks. + +Defines which operations the AXON NPU can accelerate and their constraints. +Used by the partitioner to decide what gets delegated to AXON vs CPU. + +Operations fall into three categories: + +1. **AXON-accelerated**: Run entirely on the AXON NPU hardware. +2. **Fused activations**: Fused into the preceding compute layer at zero cost. +3. 
**Op extensions**: Hybrid AXON+CPU — the preceding layer runs on AXON with + higher-precision output (INT16 q3.12 or INT32 q11.12), then a CPU callback + completes the non-linear function. + +AXON hardware constraints (from Nordic documentation): + +- Tensors stored as ``int8_t tensor[channels][height][width]`` (CHW layout). + The ARM TOSA pass pipeline handles NCHW → NHWC, and Nordic's compiler + handles NHWC → CHW internally. +- Maximum tensor dimensions: 1024 height, 1024 width, 1024 channels. +- Maximum 2 inputs per node. +- Output rows aligned to 32-bit boundary (compiler handles padding). +- Most reshape operations are transparent to AXON (no data movement). + Non-transparent reshapes fall back to CPU. +""" + +from executorch.backends.nordic.axon.compile_spec import ( + AXON_MAX_CONV2D_FILTER, + AXON_MAX_CONV_STRIDE, + AXON_MAX_FC_INPUT, + AXON_MAX_FC_OUTPUT, + AXON_MAX_INPUTS_PER_NODE, + AXON_MAX_POOL_FILTER, + AXON_MAX_TENSOR_DIM, +) + +# ── Category 1: AXON-accelerated operations ────────────────────── +# These run entirely on the AXON NPU hardware via command buffers. +AXON_SUPPORTED_OPS = { + "fully_connected", + "conv2d", + "depthwise_conv2d", + "pointwise_conv2d", + "avg_pool2d", + "max_pool2d", + "add", + "mul", + "mean", + "concatenate", + "strided_slice", + "channel_padding", +} + +# ── Category 2: Fused activations (zero overhead) ──────────────── +# These are fused into the preceding compute layer's command buffer. +# They don't create separate AXON layers — the activation is encoded +# in the preceding layer's configuration word. +AXON_FUSED_ACTIVATIONS = {"relu", "relu6", "leaky_relu"} + +# ── Category 3: Op extensions (AXON+CPU hybrid) ────────────────── +# The preceding AXON layer outputs higher-precision data, then a CPU +# callback function completes the non-linear computation. 
+# sigmoid (101): preceding layer → INT16 q3.12 → CPU expf sigmoid +# tanh (102): preceding layer → INT16 q3.12 → CPU expf tanh +# softmax (100): preceding layer → INT32 q11.12 → CPU softmax +AXON_OP_EXTENSIONS = {"sigmoid", "tanh", "softmax"} + +# ── Operations that fall back to ExecuTorch portable CPU kernels ── +# These are not accelerated by AXON in any way. +# Note: most reshape operations are transparent to AXON (the compiler +# handles them without data movement). Only non-transparent reshapes +# that require actual memory reorganization fall back to CPU. +AXON_CPU_ONLY_OPS = { + "reshape", +} + + +# ── Global constraint checks ───────────────────────────────────── +# These apply to ALL AXON operations regardless of type. + +def check_tensor_dimensions( + height: int, width: int, channels: int +) -> tuple[bool, str]: + """Check if tensor dimensions fit AXON global constraints. + + AXON supports max 1024 for height, width, and channels. + """ + if height > AXON_MAX_TENSOR_DIM: + return False, f"height {height} > max {AXON_MAX_TENSOR_DIM}" + if width > AXON_MAX_TENSOR_DIM: + return False, f"width {width} > max {AXON_MAX_TENSOR_DIM}" + if channels > AXON_MAX_TENSOR_DIM: + return False, f"channels {channels} > max {AXON_MAX_TENSOR_DIM}" + return True, "" + + +def check_input_count(num_inputs: int) -> tuple[bool, str]: + """Check if the number of inputs fits AXON constraints. + + AXON allows a maximum of 2 inputs per node. + """ + if num_inputs > AXON_MAX_INPUTS_PER_NODE: + return False, f"inputs {num_inputs} > max {AXON_MAX_INPUTS_PER_NODE}" + return True, "" + + +# ── Per-operation constraint checks ────────────────────────────── + +def check_fully_connected(input_size: int, output_size: int) -> tuple[bool, str]: + """Check if a fully connected layer fits AXON constraints. + + FC max input/output: 2048 elements each. 
+ """ + if input_size > AXON_MAX_FC_INPUT: + return False, f"FC input size {input_size} > max {AXON_MAX_FC_INPUT}" + if output_size > AXON_MAX_FC_OUTPUT: + return False, f"FC output size {output_size} > max {AXON_MAX_FC_OUTPUT}" + return True, "" + + +def check_conv2d( + filter_h: int, filter_w: int, stride_h: int, stride_w: int, channels: int +) -> tuple[bool, str]: + """Check if a conv2d layer fits AXON constraints. + + Conv2D max filter: 16x16, max stride: 31, max channels: 1024. + """ + if filter_h > AXON_MAX_CONV2D_FILTER or filter_w > AXON_MAX_CONV2D_FILTER: + return False, f"Conv2d filter {filter_h}x{filter_w} > max {AXON_MAX_CONV2D_FILTER}" + if stride_h > AXON_MAX_CONV_STRIDE or stride_w > AXON_MAX_CONV_STRIDE: + return False, f"Conv2d stride > max {AXON_MAX_CONV_STRIDE}" + if channels > AXON_MAX_TENSOR_DIM: + return False, f"Conv2d channels {channels} > max {AXON_MAX_TENSOR_DIM}" + return True, "" + + +def check_pooling(filter_h: int, filter_w: int) -> tuple[bool, str]: + """Check if a pooling layer fits AXON constraints. + + Pooling max filter: 32x32. + """ + if filter_h > AXON_MAX_POOL_FILTER or filter_w > AXON_MAX_POOL_FILTER: + return False, f"Pool filter {filter_h}x{filter_w} > max {AXON_MAX_POOL_FILTER}" + return True, "" diff --git a/backends/nordic/requirements.txt b/backends/nordic/requirements.txt new file mode 100644 index 00000000000..ad1e64ba853 --- /dev/null +++ b/backends/nordic/requirements.txt @@ -0,0 +1,13 @@ +# Python dependencies for the Nordic AXON backend. 
+# Install with: pip install -r backends/nordic/requirements.txt + +# Backend core +cffi>=1.15 +numpy>=2.0 +pyyaml + +# TOSA flatbuffer support (for model compilation) +tosa-tools + +# Quantization (used by AxonQuantizer and tests) +torchao diff --git a/backends/nordic/runtime/AxonBackend.cpp b/backends/nordic/runtime/AxonBackend.cpp new file mode 100644 index 00000000000..3ffdf683d86 --- /dev/null +++ b/backends/nordic/runtime/AxonBackend.cpp @@ -0,0 +1,353 @@ +/* + * Copyright (c) 2026 iote.ai + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * ExecuTorch AXON delegate backend (C++ runtime). + * + * Registers as "AxonBackend" with the ExecuTorch runtime. When the .pte + * contains TOSA-delegated subgraphs tagged for AxonBackend, ExecuTorch + * calls our init/execute/destroy methods. One delegate handle per + * subgraph in the .pte; a single .pte may contain many delegated + * subgraphs. + * + * Multi-subgraph wiring + * --------------------- + * The Python side (AxonBackend.preprocess) writes one Nordic-compiled C + * header per delegated subgraph into a generated directory, plus a master + * table axon_subgraphs_table.h that #includes them all and exposes a + * const array of {name, &model_} entries. The Python side returns + * a small marker as the .pte's processed_bytes: + * + * offset size field + * ------ ---- ----- + * 0 4 magic "AXNG" + * 4 4 version little-endian uint32 = 1 + * 8 4 name_len little-endian uint32 + * 12 N name ASCII subgraph name (no NUL) + * + * We parse the marker at init() time, look the matching + * nrf_axon_nn_compiled_model_s up by name in axon_subgraphs[], and + * stash a pointer in the per-handle state. execute() then runs + * nrf_axon_nn_model_infer_sync() with that model. 
+ */ + +#if defined(CONFIG_NRF_AXON) && CONFIG_NRF_AXON + +#include +#include +#include + +#include +#include +#include +#include + +/* AXON driver */ +#include "axon/nrf_axon_platform.h" +#include "drivers/axon/nrf_axon_driver.h" +#include "drivers/axon/nrf_axon_nn_infer.h" + +/* Auto-generated subgraph table from the AXON backend export pipeline. + * Brings in axon_subgraphs[] and AXON_SUBGRAPHS_COUNT. */ +#include "generated/axon_subgraphs_table.h" + +namespace { + +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::BackendInterface; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; +using executorch::runtime::Span; +using exec_aten::ScalarType; +using exec_aten::Tensor; + +/* Maximum simultaneously-loaded delegate handles. Each handle carries a + * MAX_PACKED_OUTPUT_BYTES scratch region, so the per-handle cost is + * dominated by that scratch. */ +#define MAX_AXON_DELEGATES 16 + +/* Maximum int8 packed output bytes for any one delegated subgraph. */ +#define MAX_PACKED_OUTPUT_BYTES 1024 + +struct AxonDelegateHandle { + const nrf_axon_nn_compiled_model_s *model; + int8_t packed_output[MAX_PACKED_OUTPUT_BYTES]; + bool initialized; + /* Profiling: per-handle cumulative cycles spent inside + * nrf_axon_nn_model_infer_sync() across the entire program run. */ + uint64_t total_infer_cycles; + uint32_t total_infer_calls; +}; + +static bool s_platform_initialized = false; +static AxonDelegateHandle s_handles[MAX_AXON_DELEGATES]; +static int s_handle_count = 0; + +/* Profiling: aggregate AXON delegate cycles across all handles. 
*/ +extern "C" { + uint64_t axon_delegate_total_cycles = 0; + uint32_t axon_delegate_total_calls = 0; +} + +/* ── Marker format (kept in sync with backends/nordic/axon/codegen.py) */ +static constexpr uint8_t MARKER_MAGIC[4] = {'A', 'X', 'N', 'G'}; +static constexpr uint32_t MARKER_VERSION = 1; +static constexpr size_t MARKER_HEADER_SIZE = 12; /* magic + version + name_len */ + +static const nrf_axon_nn_compiled_model_s * +parse_marker_and_lookup(const uint8_t *bytes, size_t len, char *out_name, size_t out_name_cap) +{ + if (len < MARKER_HEADER_SIZE) { + ET_LOG(Error, "AxonBackend: processed_bytes too short (%zu < %zu)", + len, MARKER_HEADER_SIZE); + return nullptr; + } + if (memcmp(bytes, MARKER_MAGIC, 4) != 0) { + ET_LOG(Error, "AxonBackend: bad marker magic %02x%02x%02x%02x", + bytes[0], bytes[1], bytes[2], bytes[3]); + return nullptr; + } + /* Little-endian uint32 reads. */ + uint32_t version = + (uint32_t)bytes[4] | ((uint32_t)bytes[5] << 8) | + ((uint32_t)bytes[6] << 16) | ((uint32_t)bytes[7] << 24); + uint32_t name_len = + (uint32_t)bytes[8] | ((uint32_t)bytes[9] << 8) | + ((uint32_t)bytes[10] << 16) | ((uint32_t)bytes[11] << 24); + if (version != MARKER_VERSION) { + ET_LOG(Error, "AxonBackend: marker version %u, expected %u", + (unsigned)version, (unsigned)MARKER_VERSION); + return nullptr; + } + if (MARKER_HEADER_SIZE + name_len > len) { + ET_LOG(Error, "AxonBackend: marker name overflow (%u + 12 > %zu)", + (unsigned)name_len, len); + return nullptr; + } + if (name_len + 1 > out_name_cap) { + ET_LOG(Error, "AxonBackend: marker name too long (%u >= %zu)", + (unsigned)name_len, out_name_cap); + return nullptr; + } + memcpy(out_name, bytes + MARKER_HEADER_SIZE, name_len); + out_name[name_len] = '\0'; + + /* Linear scan over the generated table — at most ~64 entries. 
*/ + for (size_t i = 0; i < AXON_SUBGRAPHS_COUNT; i++) { + if (strcmp(axon_subgraphs[i].name, out_name) == 0) { + return axon_subgraphs[i].model; + } + } + ET_LOG(Error, "AxonBackend: no subgraph named '%s' in generated table " + "(%d entries) — did you re-run export but forget to " + "rebuild firmware?", + out_name, (int)AXON_SUBGRAPHS_COUNT); + return nullptr; +} + +class AxonBackendImpl final : public BackendInterface { +public: + bool is_available() const override { + return true; + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs + ) const override { + ET_LOG(Info, "AxonBackend::init (delegate %d, processed=%zu bytes)", + s_handle_count, processed->size()); + + if (s_handle_count >= MAX_AXON_DELEGATES) { + ET_LOG(Error, "Too many AXON delegates (max %d)", MAX_AXON_DELEGATES); + return Error::MemoryAllocationFailed; + } + + /* Initialize AXON platform once across all delegate handles. */ + if (!s_platform_initialized) { + nrf_axon_result_e r = nrf_axon_platform_init(); + if (r != NRF_AXON_RESULT_SUCCESS) { + ET_LOG(Error, "AXON platform init failed: %d", (int)r); + return Error::InvalidState; + } + s_platform_initialized = true; + } + + /* Parse the marker, look the model up by name. 
*/ + char name_buf[128]; + const nrf_axon_nn_compiled_model_s *model = parse_marker_and_lookup( + static_cast(processed->data()), + processed->size(), + name_buf, sizeof(name_buf)); + if (!model) { + return Error::InvalidProgram; + } + + nrf_axon_result_e r = nrf_axon_nn_model_validate(model); + if (r != NRF_AXON_RESULT_SUCCESS) { + ET_LOG(Error, "AXON model '%s' validate failed: %d", name_buf, (int)r); + return Error::InvalidProgram; + } + + AxonDelegateHandle *handle = &s_handles[s_handle_count++]; + handle->model = model; + handle->initialized = true; + handle->total_infer_cycles = 0; + handle->total_infer_calls = 0; + memset(handle->packed_output, 0, sizeof(handle->packed_output)); + + ET_LOG(Info, + " AXON model '%s' bound (out: %ux%ux%u byte_width=%u)", + name_buf, + model->output_dimensions.height, + model->output_dimensions.width, + model->output_dimensions.channel_cnt, + (unsigned)model->output_dimensions.byte_width); + + processed->Free(); + return handle; + } + + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle, + Span args + ) const override { + auto *axon_handle = static_cast(handle); + const nrf_axon_nn_compiled_model_s *model = axon_handle->model; + + if (args.size() < 2) { + ET_LOG(Error, "AxonBackend::execute: args=%zu (need >= 2)", args.size()); + return Error::InvalidArgument; + } + + const auto& input_evalue = args[0]; + if (!input_evalue->isTensor()) { + ET_LOG(Error, "AxonBackend: input is not a tensor"); + return Error::InvalidArgument; + } + const Tensor& input_tensor = input_evalue->toTensor(); + if (input_tensor.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "AxonBackend: input dtype %d, expected int8", + (int)input_tensor.scalar_type()); + return Error::InvalidArgument; + } + const int8_t *input_data = input_tensor.const_data_ptr(); + + timing_t t_start = timing_counter_get(); + nrf_axon_result_e r = nrf_axon_nn_model_infer_sync( + model, input_data, axon_handle->packed_output); + timing_t t_end = 
timing_counter_get(); + if (r != NRF_AXON_RESULT_SUCCESS) { + ET_LOG(Error, "AXON inference failed: %d", (int)r); + return Error::InvalidState; + } + uint64_t cyc = timing_cycles_get(&t_start, &t_end); + axon_handle->total_infer_cycles += cyc; + axon_handle->total_infer_calls++; + axon_delegate_total_cycles += cyc; + axon_delegate_total_calls++; + + /* Copy AXON's packed int8 output into ExecuTorch's output tensor. */ + auto& output_evalue = args[1]; + if (!output_evalue->isTensor()) { + ET_LOG(Error, "AxonBackend: output is not a tensor"); + return Error::InvalidArgument; + } + Tensor& output_tensor = output_evalue->toTensor(); + if (output_tensor.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "AxonBackend: output dtype %d, expected int8", + (int)output_tensor.scalar_type()); + return Error::InvalidArgument; + } + int8_t *out_data = output_tensor.mutable_data_ptr(); + size_t copy_bytes = output_tensor.numel(); + if (copy_bytes > sizeof(axon_handle->packed_output)) { + ET_LOG(Error, + "AxonBackend: output tensor (%zu bytes) > packed_output " + "scratch (%zu bytes); bump MAX_PACKED_OUTPUT_BYTES", + copy_bytes, sizeof(axon_handle->packed_output)); + return Error::MemoryAllocationFailed; + } + memcpy(out_data, axon_handle->packed_output, copy_bytes); + return Error::Ok; + } + + void destroy(DelegateHandle* handle) const override { + (void)handle; + } +}; + +/* Register the backend with ExecuTorch's runtime. */ +static AxonBackendImpl s_axon_backend; +static Backend s_backend_id{"AxonBackend", &s_axon_backend}; +static auto s_registered __attribute__((used)) = + executorch::runtime::register_backend(s_backend_id); + +} /* anonymous namespace */ + +/* Profiling API: zero all cycle counters. 
*/ +extern "C" void axon_delegate_reset_profile(void) +{ + axon_delegate_total_cycles = 0; + axon_delegate_total_calls = 0; + for (int i = 0; i < s_handle_count; i++) { + s_handles[i].total_infer_cycles = 0; + s_handles[i].total_infer_calls = 0; + } +} + +/* Profiling API: dump per-handle AXON cycle counts to the log. */ +extern "C" void axon_delegate_dump_profile(void) +{ + ET_LOG(Info, ""); + ET_LOG(Info, "=== AXON delegate profile ==="); + ET_LOG(Info, "handles bound: %d", s_handle_count); + ET_LOG(Info, "total infer cycles: %llu (%lu calls)", + (unsigned long long)axon_delegate_total_cycles, + (unsigned long)axon_delegate_total_calls); + if (axon_delegate_total_calls > 0) { + uint64_t avg = axon_delegate_total_cycles / axon_delegate_total_calls; + ET_LOG(Info, "avg cycles/call: %llu", (unsigned long long)avg); + } + for (int i = 0; i < s_handle_count; i++) { + const AxonDelegateHandle *h = &s_handles[i]; + if (h->total_infer_calls == 0) { + continue; + } + uint64_t avg = h->total_infer_cycles / h->total_infer_calls; + const auto &dim = h->model->output_dimensions; + ET_LOG(Info, " [%2d] %-25s out=%ux%ux%u calls=%lu total=%llu avg=%llu", + i, + h->model->model_name ? h->model->model_name : "(unnamed)", + dim.height, dim.width, dim.channel_cnt, + (unsigned long)h->total_infer_calls, + (unsigned long long)h->total_infer_cycles, + (unsigned long long)avg); + } + ET_LOG(Info, "============================="); +} + +#else /* not CONFIG_NRF_AXON: stub the profiling symbols so firmware + * can link without #ifdefs everywhere. 
*/ + +#include +extern "C" { + uint64_t axon_delegate_total_cycles = 0; + uint32_t axon_delegate_total_calls = 0; + void axon_delegate_dump_profile(void) {} + void axon_delegate_reset_profile(void) {} +} + +#endif /* CONFIG_NRF_AXON */ diff --git a/backends/nordic/runtime/AxonBackend.h b/backends/nordic/runtime/AxonBackend.h new file mode 100644 index 00000000000..8886712227e --- /dev/null +++ b/backends/nordic/runtime/AxonBackend.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2026 iote.ai + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * Nordic AXON NPU delegate — public profiling API. + */ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* Aggregate cycle counters across all AXON delegate handles. */ +extern uint64_t axon_delegate_total_cycles; +extern uint32_t axon_delegate_total_calls; + +/* Reset all per-handle and global cycle counters. */ +void axon_delegate_reset_profile(void); + +/* Dump per-handle cycle counts to the ExecuTorch log. */ +void axon_delegate_dump_profile(void); + +#ifdef __cplusplus +} +#endif diff --git a/backends/nordic/runtime/axon_op_extensions.c b/backends/nordic/runtime/axon_op_extensions.c new file mode 100644 index 00000000000..f59c5b2b7af --- /dev/null +++ b/backends/nordic/runtime/axon_op_extensions.c @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2026 iote.ai + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * AXON op extensions — sigmoid and tanh CPU callbacks. + * + * These are CPU-side implementations of the activation functions that + * the AXON command buffer dispatches to when it encounters op extension + * segments (op code 101=sigmoid, 102=tanh). The preceding AXON layer + * outputs INT16 q3.12 data, and these functions convert it to the + * final INT8 output. 
+ * + * These replace Nordic's nrf_axon_nn_op_extension_sigmoid/_tanh from + * sdk-edge-ai/drivers/axon/nrf_axon_nn_op_extensions.c. The AXON + * backend's codegen step rewrites the generated model headers to + * reference axon_op_extension_sigmoid/_tanh instead. + * + * Why custom implementations + * -------------------------- + * Nordic's stock sigmoid uses double-precision libm exp() per element, + * which on the Cortex-M33's single-precision FPU is software-emulated + * (~2,800 cycles per element). Using single-precision expf() instead + * gives ~1.5x speedup with identical quantized output. + */ + +#include +#include +#include +#include +#include "axon/nrf_axon_platform.h" +#include "drivers/axon/nrf_axon_nn_op_extensions.h" + +/* Quantise a sigmoid result in [0,1] to int8. */ +static inline int8_t axon_quantize_sigmoid(float v) +{ + float q = roundf(v * 256.0f) - 128.0f; + if (q > 127.0f) return 127; + if (q < -128.0f) return -128; + return (int8_t)q; +} + +/* Quantise a tanh result in [-1,1] to int8. */ +static inline int8_t axon_quantize_tanh(float v) +{ + float q = roundf(v * 128.0f); + if (q > 127.0f) return 127; + if (q < -128.0f) return -128; + return (int8_t)q; +} + +nrf_axon_result_e axon_op_extension_sigmoid( + uint16_t argc, NRF_AXON_PLATFORM_BITWIDTH_UNSIGNED_TYPE *args) +{ + if (args == NULL || + (argc * sizeof(NRF_AXON_PLATFORM_BITWIDTH_UNSIGNED_TYPE)) + < sizeof(nrf_axon_nn_op_extension_base1_args_s)) { + return NRF_AXON_RESULT_FAILURE; + } + nrf_axon_nn_op_extension_base1_args_s *base1_args = + (nrf_axon_nn_op_extension_base1_args_s *)args; + + if (base1_args->remaining_args.output_bytewidth != 1) { + return NRF_AXON_RESULT_FAILURE; + } + + const uint8_t input_extra_stride = + (!base1_args->remaining_args.input_is_packed + && (base1_args->remaining_args.width & 1)) + ? 
1 : 0; + + int16_t *input_ptr = (int16_t *)base1_args->ptr_args.input; + int8_t *output_ptr = (int8_t *)base1_args->ptr_args.output; + const uint16_t channels = base1_args->remaining_args.channel_cnt; + const uint16_t height = base1_args->remaining_args.height; + const uint16_t width = base1_args->remaining_args.width; + + for (uint16_t ch = 0; ch < channels; ch++) { + for (uint16_t row = 0; row < height; row++) { + for (uint16_t col = 0; col < width; col++, input_ptr++, output_ptr++) { + /* Input is q3.12 INT16; convert to float. */ + float x = (float)*input_ptr * (1.0f / 4096.0f); + /* Single-precision expf instead of double-precision exp. */ + float e = expf(-x); + *output_ptr = axon_quantize_sigmoid(1.0f / (1.0f + e)); + } + input_ptr += input_extra_stride; + } + } + return NRF_AXON_RESULT_SUCCESS; +} + +nrf_axon_result_e axon_op_extension_tanh( + uint16_t argc, NRF_AXON_PLATFORM_BITWIDTH_UNSIGNED_TYPE *args) +{ + if (args == NULL || + (argc * sizeof(NRF_AXON_PLATFORM_BITWIDTH_UNSIGNED_TYPE)) + < sizeof(nrf_axon_nn_op_extension_base1_args_s)) { + return NRF_AXON_RESULT_FAILURE; + } + nrf_axon_nn_op_extension_base1_args_s *base1_args = + (nrf_axon_nn_op_extension_base1_args_s *)args; + + if (base1_args->remaining_args.output_bytewidth != 1) { + return NRF_AXON_RESULT_FAILURE; + } + + const uint8_t input_extra_stride = + (!base1_args->remaining_args.input_is_packed + && (base1_args->remaining_args.width & 1)) + ? 1 : 0; + + int16_t *input_ptr = (int16_t *)base1_args->ptr_args.input; + int8_t *output_ptr = (int8_t *)base1_args->ptr_args.output; + const uint16_t channels = base1_args->remaining_args.channel_cnt; + const uint16_t height = base1_args->remaining_args.height; + const uint16_t width = base1_args->remaining_args.width; + + for (uint16_t ch = 0; ch < channels; ch++) { + for (uint16_t row = 0; row < height; row++) { + for (uint16_t col = 0; col < width; col++, input_ptr++, output_ptr++) { + /* q3.12 with 2x factor folded into divisor (1<<11). 
*/ + float two_x = (float)*input_ptr * (1.0f / 2048.0f); + float e = expf(two_x); + *output_ptr = axon_quantize_tanh((e - 1.0f) / (e + 1.0f)); + } + input_ptr += input_extra_stride; + } + } + return NRF_AXON_RESULT_SUCCESS; +} diff --git a/backends/nordic/scripts/run.sh b/backends/nordic/scripts/run.sh new file mode 100755 index 00000000000..4c3c8faf764 --- /dev/null +++ b/backends/nordic/scripts/run.sh @@ -0,0 +1,172 @@ +#!/bin/bash +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Export a model and build firmware for the AXON NPU. +# +# Usage: +# # Source NCS toolchain first +# source ~/ncs-workspace/nrf-connect-sdk-env.sh +# +# # Run with defaults (hello_axon sample) +# ./backends/nordic/scripts/run.sh +# +# # Or specify a custom export script +# ./backends/nordic/scripts/run.sh \ +# --sample=examples/nordic/hello_axon \ +# --board=nrf54lm20dk/nrf54lm20b/cpuapp +# +# Environment variables: +# SDK_EDGE_AI_PATH - Path to Nordic sdk-edge-ai (required for AXON compilation) +# PYTHON - Python interpreter for model export (default: python3) + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +BACKEND_DIR="$(dirname "$SCRIPT_DIR")" +ET_ROOT="$(cd "$BACKEND_DIR/../.." 
&& pwd)"

# Defaults
SAMPLE="${ET_ROOT}/examples/nordic/hello_axon"
BOARD="nrf54lm20dk/nrf54lm20b/cpuapp"
BUILD_DIR=""
PYTHON="${PYTHON:-python3}"
SKIP_EXPORT=0
BUILD_ONLY=0

# Parse arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --sample=*) SAMPLE="${1#*=}" ;;
    --board=*) BOARD="${1#*=}" ;;
    --build-dir=*) BUILD_DIR="${1#*=}" ;;
    --python=*) PYTHON="${1#*=}" ;;
    --skip-export) SKIP_EXPORT=1 ;;
    --build-only) BUILD_ONLY=1 ;;
    -h|--help)
      echo "Usage: $0 [options]"
      echo ""
      echo "Options:"
      echo "  --sample=PATH     Sample directory (default: examples/nordic/hello_axon)"
      echo "  --board=BOARD     Zephyr board target (default: nrf54lm20dk/nrf54lm20b/cpuapp)"
      # Fix: the default shown here was garbled/wrong; it must match the
      # actual fallback computed below: <sample>/build/<sample-name>.
      echo "  --build-dir=PATH  Build output directory (default: <sample>/build/<sample-name>)"
      echo "  --python=PYTHON   Python for model export (default: python3)"
      echo "  --skip-export     Skip model export, use existing model_pte.h"
      echo "  --build-only      Build firmware but don't flash"
      echo ""
      echo "Environment:"
      echo "  SDK_EDGE_AI_PATH  Path to Nordic sdk-edge-ai (required)"
      echo "                    Source nrf-connect-sdk-env.sh before running"
      exit 0
      ;;
    *) echo "Unknown option: $1"; exit 1 ;;
  esac
  shift
done

# Normalize SAMPLE to an absolute path (set -e aborts if it doesn't exist).
SAMPLE="$(cd "$SAMPLE" && pwd)"
SAMPLE_NAME="$(basename "$SAMPLE")"
BUILD_DIR="${BUILD_DIR:-${SAMPLE}/build/${SAMPLE_NAME}}"

echo "=== Nordic AXON: Export and Build ==="
echo "  Sample:    $SAMPLE"
echo "  Board:     $BOARD"
echo "  Build dir: $BUILD_DIR"
echo "  SDK:       ${SDK_EDGE_AI_PATH:-NOT SET}"
echo ""

# ── Step 0: Validate environment ──────────────────────────────────
echo "--- Checking environment ---"
# setup.sh exits non-zero on problems; warn but keep going so partial
# workflows (e.g. --skip-export) still work.
if ! bash "$SCRIPT_DIR/setup.sh" 2>/dev/null; then
  echo ""
  echo "WARNING: Environment check found issues (see above)."
  echo "         Continuing anyway — some steps may fail."
  echo ""
fi

# Check west is available (needed for firmware build)
if ! command -v west &>/dev/null; then
  echo "ERROR: 'west' not found."
  echo "  Source your NCS toolchain environment first:"
  echo "    source ~/ncs-workspace/nrf-connect-sdk-env.sh"
  exit 1
fi

# ── Step 1: Export model ──────────────────────────────────────────
if [ "$SKIP_EXPORT" -eq 0 ] && [ -f "$SAMPLE/export_model.py" ]; then
  echo ""
  echo "--- Exporting model ---"

  # Set up the export venv if it doesn't exist yet
  if [ ! -d "$SAMPLE/.venv" ]; then
    if [ -f "$SAMPLE/setup_export_env.sh" ]; then
      echo "First run — setting up export environment..."
      bash "$SAMPLE/setup_export_env.sh"
    fi
  fi

  # Export using uv (isolated from NCS Python).
  # Must unset PYTHONHOME/PYTHONPATH to avoid NCS toolchain conflicts.
  if command -v uv &>/dev/null && [ -d "$SAMPLE/.venv" ]; then
    PYTHONHOME= PYTHONPATH= SDK_EDGE_AI_PATH="${SDK_EDGE_AI_PATH}" \
      uv run --directory "$SAMPLE" python export_model.py
  else
    # Fallback: use whatever python is available
    SDK_EDGE_AI_PATH="${SDK_EDGE_AI_PATH}" \
      "$PYTHON" "$SAMPLE/export_model.py"
  fi
else
  if [ "$SKIP_EXPORT" -eq 0 ]; then
    echo "No export_model.py found in $SAMPLE — skipping export."
  else
    echo "Skipping export (--skip-export)."
  fi
fi

# Check model_pte.h exists
if [ ! -f "$SAMPLE/src/model_pte.h" ]; then
  echo "ERROR: $SAMPLE/src/model_pte.h not found."
  echo "  Run the export step first. If using uv:"
  echo "    cd $SAMPLE && ./setup_export_env.sh"
  echo "    PYTHONHOME= SDK_EDGE_AI_PATH=~/sdk-edge-ai uv run python export_model.py"
  exit 1
fi

# ── Step 2: Build firmware ────────────────────────────────────────
echo ""
echo "--- Building firmware ---"

# ExecuTorch is always an extra Zephyr module; sdk-edge-ai is appended
# only when available so the AXON driver sources get picked up.
EXTRA_MODULES="${ET_ROOT}"
if [ -n "$SDK_EDGE_AI_PATH" ] && [ -d "$SDK_EDGE_AI_PATH" ]; then
  EXTRA_MODULES="${ET_ROOT};${SDK_EDGE_AI_PATH}"
fi

west build \
  -b "$BOARD" \
  "$SAMPLE" \
  -d "$BUILD_DIR" \
  --no-sysbuild \
  -p \
  -- \
  -DZEPHYR_EXTRA_MODULES="$EXTRA_MODULES"

echo ""
echo "=== Build complete ==="
echo "  Hex: $BUILD_DIR/zephyr/zephyr.hex"
echo "  ELF: $BUILD_DIR/zephyr/zephyr.elf"

# ── Step 3: Flash ─────────────────────────────────────────────────
if [ "$BUILD_ONLY" -eq 0 ]; then
  echo ""
  echo "--- Flashing ---"
  west flash --build-dir "$BUILD_DIR"
  echo ""
  echo "Flash complete. Open serial console (115200 baud) to see output."
else
  echo ""
  echo "Build only — skipping flash."
  echo "  Flash with: west flash --build-dir $BUILD_DIR"
fi
diff --git a/backends/nordic/scripts/setup.sh b/backends/nordic/scripts/setup.sh
new file mode 100755
index 00000000000..1f74fde5236
--- /dev/null
+++ b/backends/nordic/scripts/setup.sh
@@ -0,0 +1,104 @@
#!/bin/bash
# Copyright (c) 2026 iote.ai
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Setup script for the Nordic AXON backend.
# Validates the environment and prints diagnostic information.

set -e

ERRORS=0

echo "=== Nordic AXON Backend Setup ==="
echo ""

# Check SDK_EDGE_AI_PATH
if [ -z "$SDK_EDGE_AI_PATH" ]; then
  echo "WARNING: SDK_EDGE_AI_PATH is not set."
  echo "  Set it to your Nordic sdk-edge-ai directory:"
  echo "    export SDK_EDGE_AI_PATH=/path/to/sdk-edge-ai"
  echo ""
  echo "  Without the SDK, TOSA lowering works but AXON compilation"
  echo "  (producing command buffer headers) will be skipped."
  echo ""
  SDK_STATUS="NOT SET"
  # Missing SDK counts toward the exit code (see `exit $ERRORS` below).
  ERRORS=$((ERRORS + 1))
else
  if [ -d "$SDK_EDGE_AI_PATH" ]; then
    SDK_STATUS="OK ($SDK_EDGE_AI_PATH)"
    # Check for compiler lib
    SYSTEM=$(uname -s)
    case "$SYSTEM" in
      Linux) LIB_NAME="libnrf-axon-nn-compiler-lib-amd64.so" ;;
      Darwin) LIB_NAME="libnrf-axon-nn-compiler-lib-arm64.dylib" ;;
      # NOTE(review): any other OS falls through to the Windows DLL name,
      # and $SYSTEM is also used as the directory name below — confirm
      # this matches the sdk-edge-ai layout on Windows.
      *) LIB_NAME="nrf-axon-nn-compiler-lib-amd64.dll" ;;
    esac
    COMPILER_LIB="$SDK_EDGE_AI_PATH/tools/axon/compiler/bin/$SYSTEM/$LIB_NAME"
    if [ -f "$COMPILER_LIB" ]; then
      echo "  Compiler lib: $COMPILER_LIB"
    else
      echo "  WARNING: Compiler lib not found at: $COMPILER_LIB"
      SDK_STATUS="INCOMPLETE (missing compiler lib)"
    fi
  else
    echo "  WARNING: SDK_EDGE_AI_PATH directory does not exist: $SDK_EDGE_AI_PATH"
    SDK_STATUS="INVALID"
  fi
fi

# Check Python packages
echo ""
echo "Checking Python dependencies..."
MISSING=""
# Import names, not pip names ("yaml" is PyYAML, "tosa" is tosa-tools).
for pkg in cffi numpy yaml tosa; do
  if python3 -c "import $pkg" 2>/dev/null; then
    echo "  $pkg: OK"
  else
    echo "  $pkg: MISSING"
    MISSING="$MISSING $pkg"
  fi
done

if [ -n "$MISSING" ]; then
  echo ""
  echo "Install missing packages:"
  echo "  pip install -r backends/nordic/requirements.txt"
fi

# Check ExecuTorch
echo ""
if python3 -c "import executorch" 2>/dev/null; then
  echo "ExecuTorch: OK"
else
  echo "ExecuTorch: NOT FOUND"
  echo "  Install ExecuTorch first — see the root README."
fi

# Check AXON backend
if python3 -c "from executorch.backends.nordic.axon import AxonBackend" 2>/dev/null; then
  echo "AXON backend: OK"
else
  echo "AXON backend: IMPORT FAILED"
fi

# Summary
echo ""
echo "=== Summary ==="
echo "  SDK_EDGE_AI_PATH: $SDK_STATUS"
echo "  Python deps: $([ -z "$MISSING" ] && echo 'OK' || echo "MISSING:$MISSING")"
echo ""

# Run quick test — a failed import also bumps the exit code.
echo "Running quick import test..."
python3 -c "
from executorch.backends.nordic.axon import AxonBackend, AxonCompileSpec, AxonPartitioner
from executorch.backends.nordic.operator_support import AXON_SUPPORTED_OPS
print(f'  AxonBackend: OK')
print(f'  Supported ops: {len(AXON_SUPPORTED_OPS)}')
print('Setup complete.')
" 2>/dev/null || { echo "  Import test failed — check your ExecuTorch installation."; ERRORS=$((ERRORS + 1)); }

# Exit code is the number of problems found; run.sh treats non-zero as
# a warning, not a hard failure.
exit $ERRORS
diff --git a/backends/nordic/test/__init__.py b/backends/nordic/test/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backends/nordic/test/conftest.py b/backends/nordic/test/conftest.py
new file mode 100644
index 00000000000..fe75b628818
--- /dev/null
+++ b/backends/nordic/test/conftest.py
@@ -0,0 +1,41 @@
# Copyright (c) 2026 iote.ai
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+"""Pytest configuration and shared fixtures for the Nordic AXON backend tests.""" + +from __future__ import annotations + +import os + +import pytest + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line( + "markers", "requires_sdk: test requires Nordic sdk-edge-ai" + ) + config.addinivalue_line( + "markers", "requires_hardware: test requires nRF54LM20DK hardware" + ) + + +@pytest.fixture +def sdk_edge_ai_path() -> str | None: + """Path to Nordic sdk-edge-ai, or None if not available.""" + path = os.environ.get("SDK_EDGE_AI_PATH", "") + if path and os.path.isdir(path): + return path + return None + + +@pytest.fixture +def require_sdk(sdk_edge_ai_path): + """Skip the test if Nordic SDK is not available.""" + if sdk_edge_ai_path is None: + pytest.skip( + "Nordic sdk-edge-ai not found. Set SDK_EDGE_AI_PATH to enable." + ) + return sdk_edge_ai_path diff --git a/backends/nordic/test/test_axon_compile.py b/backends/nordic/test/test_axon_compile.py new file mode 100644 index 00000000000..a24f36892cc --- /dev/null +++ b/backends/nordic/test/test_axon_compile.py @@ -0,0 +1,187 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Tests for the Nordic AXON backend — full compilation stage. + +These tests require the Nordic sdk-edge-ai to be installed and +SDK_EDGE_AI_PATH to be set. They validate the complete pipeline from +PyTorch model through TOSA to compiled AXON command buffers. 
+""" +from __future__ import annotations + +import os + +import pytest +import torch +import torch.nn as nn + +_SDK_PATH = os.environ.get("SDK_EDGE_AI_PATH", "") +_HAS_SDK = bool(_SDK_PATH) and os.path.isdir(_SDK_PATH) + +pytestmark = [ + pytest.mark.requires_sdk, + pytest.mark.skipif( + not _HAS_SDK, + reason="Nordic SDK not available (set SDK_EDGE_AI_PATH to enable)", + ), +] + + +@pytest.fixture +def sdk_path(): + """Path to Nordic sdk-edge-ai.""" + return _SDK_PATH + + +class TestAxonCompilation: + """End-to-end compilation tests.""" + + def _compile_model(self, model, example_input, sdk_path, tmp_path): + """Compile a model through the full AXON pipeline.""" + from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner + from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig + + generated_dir = tmp_path / "generated" + generated_dir.mkdir() + + compile_spec = AxonCompileSpec( + sdk_edge_ai_path=sdk_path, + model_name="test_model", + axon_generated_dir=str(generated_dir), + ) + partitioner = AxonPartitioner(compile_spec) + + # Quantize + model.eval() + exported = torch.export.export(model, example_input) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=True) + ) + prepared = prepare_pt2e(exported, quantizer) + prepared(*example_input) + quantized = convert_pt2e(prepared) + + # Edge lower with AXON partitioner + edge = to_edge_transform_and_lower( + quantized, + partitioner=[partitioner], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + return edge, generated_dir + + def test_linear_compiles(self, sdk_path, tmp_path): + """A simple linear model compiles to AXON command buffers.""" + model = nn.Sequential(nn.Linear(16, 8)) + example_input = (torch.randn(1, 16),) 
+ edge, gen_dir = self._compile_model(model, example_input, sdk_path, tmp_path) + + # Check generated headers exist + headers = list(gen_dir.glob("axon_subgraph_*.h")) + assert len(headers) >= 1, f"No subgraph headers generated in {gen_dir}" + + # Check table exists + table = gen_dir / "axon_subgraphs_table.h" + assert table.exists(), "axon_subgraphs_table.h not generated" + content = table.read_text() + assert "AXON_SUBGRAPHS_COUNT" in content + + def test_conv_relu_compiles(self, sdk_path, tmp_path): + """Conv2d + ReLU compiles with fused activation.""" + model = nn.Sequential( + nn.Conv2d(1, 4, kernel_size=3, padding=1), + nn.ReLU(), + ) + example_input = (torch.randn(1, 1, 8, 8),) + edge, gen_dir = self._compile_model(model, example_input, sdk_path, tmp_path) + + headers = list(gen_dir.glob("axon_subgraph_*.h")) + assert len(headers) >= 1 + + def test_multi_layer_produces_unique_names(self, sdk_path, tmp_path): + """A multi-layer model produces distinct subgraph names.""" + model = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 8), + ) + example_input = (torch.randn(1, 16),) + edge, gen_dir = self._compile_model(model, example_input, sdk_path, tmp_path) + + headers = list(gen_dir.glob("axon_subgraph_*.h")) + # Should have at least 2 distinct subgraphs (2 linears) + names = [h.stem for h in headers] + assert len(names) == len(set(names)), f"Duplicate subgraph names: {names}" + + def test_compiled_header_has_cmd_buffer(self, sdk_path, tmp_path): + """The compiled header contains a command buffer array.""" + model = nn.Sequential(nn.Linear(16, 8)) + example_input = (torch.randn(1, 16),) + edge, gen_dir = self._compile_model(model, example_input, sdk_path, tmp_path) + + headers = list(gen_dir.glob("axon_subgraph_*.h")) + assert len(headers) >= 1 + content = headers[0].read_text() + assert "cmd_buffer_" in content + assert "nrf_axon_nn_compiled_model_s" in content + + def test_no_nordic_symbols_leaked(self, sdk_path, tmp_path): + """Op extension 
symbols are rewritten to axon_op_extension_*.""" + model = nn.Sequential(nn.Linear(16, 8)) + example_input = (torch.randn(1, 16),) + edge, gen_dir = self._compile_model(model, example_input, sdk_path, tmp_path) + + # Check no nrf_axon_nn_op_extension_ symbols remain + for header in gen_dir.glob("axon_subgraph_*.h"): + content = header.read_text() + assert "nrf_axon_nn_op_extension_sigmoid" not in content + assert "nrf_axon_nn_op_extension_tanh" not in content + + def test_pte_export(self, sdk_path, tmp_path): + """Full .pte export succeeds.""" + from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner + from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig + + model = nn.Sequential(nn.Linear(16, 8)) + example_input = (torch.randn(1, 16),) + + generated_dir = tmp_path / "generated" + generated_dir.mkdir() + + compile_spec = AxonCompileSpec( + sdk_edge_ai_path=sdk_path, + model_name="pte_test", + axon_generated_dir=str(generated_dir), + ) + partitioner = AxonPartitioner(compile_spec) + + model.eval() + exported = torch.export.export(model, example_input) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=True) + ) + prepared = prepare_pt2e(exported, quantizer) + prepared(*example_input) + quantized = convert_pt2e(prepared) + + edge = to_edge_transform_and_lower( + quantized, + partitioner=[partitioner], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + pte_path = tmp_path / "test.pte" + edge.to_executorch().save(str(pte_path)) + assert pte_path.exists() + assert pte_path.stat().st_size > 0 diff --git a/backends/nordic/test/test_operators.py b/backends/nordic/test/test_operators.py new file mode 100644 index 00000000000..6adf1181a54 --- /dev/null +++ 
b/backends/nordic/test/test_operators.py
@@ -0,0 +1,255 @@
# Copyright (c) 2026 iote.ai
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Per-operator tests for the Nordic AXON backend.

Validates that each supported operation type can be:
1. Exported from a PyTorch model
2. Lowered through the TOSA pipeline
3. Converted to AXON layer descriptors
4. Packed into an intermediate binary

These tests run without the Nordic SDK (no compilation to command
buffers — just TOSA lowering and AXON layer conversion).
"""
from __future__ import annotations

import pytest
import torch
import torch.nn as nn

from executorch.backends.nordic.axon_types import AxonOp


# Monotonic counter so every test gets a unique model name, and hence a
# unique TOSA debug file in the temp dir (see _lower_to_axon_layers).
_test_counter = 0


def _lower_to_axon_layers(model, example_input):
    """Export, quantize, and lower a model to AXON layers.

    Returns the list of AXON layer descriptors parsed back from the
    TOSA debug file the partitioner writes, or skips the test if the
    debug file was not produced.
    """
    global _test_counter
    _test_counter += 1
    model_name = f"optest_{_test_counter}"

    from executorch.backends.arm.tosa.specification import TosaSpecification
    from executorch.backends.arm.quantizer import (
        EthosUQuantizer,
        get_symmetric_quantization_config,
    )
    from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
    from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    import tempfile, os

    model.eval()
    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    exported = torch.export.export(model, example_input, strict=False)
    captured = exported.module()

    # Strip torch 2.11 _guards_fn nodes
    # NOTE(review): assumes guard nodes have no users that matter;
    # replace_all_uses_with(None) nulls any remaining references.
    guard_nodes = [
        n for n in captured.graph.nodes
        if n.op == "call_module" and "_guards" in str(n.target)
    ]
    for n in guard_nodes:
        n.replace_all_uses_with(None)
        captured.graph.erase_node(n)
    for name in list(captured._modules.keys()):
        if "_guards" in name:
            delattr(captured, name)
    captured.graph.lint()
    captured.recompile()

    # Quantize on the GraphModule, calibrate once, then re-export so
    # the lowering step receives an ExportedProgram.
    quantizer = EthosUQuantizer(tosa_spec).set_global(
        get_symmetric_quantization_config(is_per_channel=True)
    )
    prepared = prepare_pt2e(captured, quantizer)
    prepared(*example_input)
    quantized = convert_pt2e(prepared)
    re_exported = torch.export.export(quantized, example_input, strict=False)

    compile_spec = AxonCompileSpec(model_name=model_name)
    partitioner = AxonPartitioner(compile_spec)
    to_edge_transform_and_lower(
        re_exported,
        partitioner=[partitioner],
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    # Read the TOSA debug file (unique per test via model_name)
    tosa_path = os.path.join(tempfile.gettempdir(), f"axon_tosa_debug_{model_name}.tosa")
    if not os.path.exists(tosa_path):
        pytest.skip("TOSA debug file not generated")

    from executorch.backends.nordic.tosa_reader import parse_tosa_flatbuffer
    from executorch.backends.nordic.axon_compiler import tosa_to_axon_layers

    with open(tosa_path, "rb") as f:
        tosa_bytes = f.read()
    graph = parse_tosa_flatbuffer(tosa_bytes)
    return tosa_to_axon_layers(graph)


class TestLinearOp:
    """Test FC (fully connected) layer delegation."""

    def test_simple_linear(self):
        model = nn.Sequential(nn.Linear(16, 8))
        layers = _lower_to_axon_layers(model, (torch.randn(1, 16),))
        assert len(layers) >= 1
        # TOSA lowers Linear to CONV2D
        compute = [l for l in layers if l.operation in (AxonOp.FULLY_CONNECTED, AxonOp.CONV2D)]
        assert len(compute) >= 1

    def test_linear_with_relu(self):
        model = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
        layers = _lower_to_axon_layers(model, (torch.randn(1, 16),))
        compute = [l for l in layers if l.operation in (AxonOp.FULLY_CONNECTED, AxonOp.CONV2D)]
        assert len(compute) >= 1

    def test_multi_linear(self):
        model = nn.Sequential(
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 8),
        )
        layers = _lower_to_axon_layers(model, (torch.randn(1, 16),))
        # Multiple linears should produce multiple compute layers
        assert len(layers) >= 2


class TestConv2dOp:
    """Test Conv2D layer delegation."""

    def test_simple_conv2d(self):
        model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1))
        layers = _lower_to_axon_layers(model, (torch.randn(1, 1, 8, 8),))
        conv_layers = [l for l in layers if l.operation == AxonOp.CONV2D]
        assert len(conv_layers) >= 1

    def test_conv2d_relu(self):
        model = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        layers = _lower_to_axon_layers(model, (torch.randn(1, 1, 8, 8),))
        # ReLU gets fused into the conv, so still just conv layers
        assert len(layers) >= 1

    def test_conv2d_different_filters(self):
        """Test various filter sizes within AXON limits."""
        # Odd sizes with matching padding keep the output shape stable.
        for k in [1, 3, 5, 7]:
            model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=k, padding=k // 2))
            layers = _lower_to_axon_layers(model, (torch.randn(1, 1, 8, 8),))
            assert len(layers) >= 1, f"Failed for kernel_size={k}"


class TestPoolOp:
    """Test pooling layer delegation."""

    def test_avg_pool2d(self):
        model = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(2),
        )
        layers = _lower_to_axon_layers(model, (torch.randn(1, 1, 8, 8),))
        pool_layers = [l for l in layers if l.operation == AxonOp.AVERAGE_POOLING]
        assert len(pool_layers) >= 1

    def test_max_pool2d(self):
        model = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        layers = _lower_to_axon_layers(model, (torch.randn(1, 1, 8, 8),))
        pool_layers = [l for l in layers if l.operation == AxonOp.MAX_POOLING]
        assert len(pool_layers) >= 1


class TestElementwiseOps:
    """Test element-wise operations (add, multiply)."""

    def test_add(self):
        class AddModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 4, 3, padding=1)
                self.conv2 = nn.Conv2d(1, 4, 3, padding=1)

            def forward(self, x):
                return self.conv1(x) + self.conv2(x)

        layers = _lower_to_axon_layers(AddModel(), (torch.randn(1, 1, 8, 8),))
        add_layers = [l for l in layers if l.operation == AxonOp.ADD2]
        assert len(add_layers) >= 1

    def test_multiply(self):
        class MulModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 4, 3, padding=1)
                self.conv2 = nn.Conv2d(1, 4, 3, padding=1)

            def forward(self, x):
                return self.conv1(x) * self.conv2(x)

        layers = _lower_to_axon_layers(MulModel(), (torch.randn(1, 1, 8, 8),))
        mul_layers = [l for l in layers if l.operation == AxonOp.MULTIPLY]
        assert len(mul_layers) >= 1


class TestBinaryBuilder:
    """Test that AXON layers pack into valid intermediate binaries."""

    def test_linear_binary(self):
        from executorch.backends.nordic.axon_binary import AxonBinaryBuilder

        model = nn.Sequential(nn.Linear(16, 8))
        layers = _lower_to_axon_layers(model, (torch.randn(1, 16),))
        builder = AxonBinaryBuilder()
        binary = builder.build(layers, model_name="test_linear")
        # The magic marker identifies the intermediate representation.
        assert len(binary) > 100
        assert b"AXON_INTERMEDIATE_REPRESENTATION_FILE" in binary

    def test_conv_binary(self):
        from executorch.backends.nordic.axon_binary import AxonBinaryBuilder

        model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU())
        layers = _lower_to_axon_layers(model, (torch.randn(1, 1, 8, 8),))
        builder = AxonBinaryBuilder()
        binary = builder.build(layers, model_name="test_conv")
        assert len(binary) > 100


class TestConstraintChecks:
    """Test that AXON constraint checks work via the partitioner."""

    def test_axon_constraints_importable(self):
        from executorch.backends.nordic.operator_support.axon_constraints import (
            AxonTensorDimensionCheck,
            AxonInputCountCheck,
            AxonConvConstraintCheck,
            AxonFCConstraintCheck,
            get_axon_constraint_checks,
        )
        checks = get_axon_constraint_checks()
        assert len(checks) == 4

    def test_partitioner_with_constraints(self):
        from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner
        from executorch.backends.nordic.operator_support.axon_constraints import (
            get_axon_constraint_checks,
        )
        spec = AxonCompileSpec(model_name="test")
        # Constraints are opt-in via additional_checks
        partitioner = AxonPartitioner(spec, additional_checks=get_axon_constraint_checks())
        assert len(partitioner.additional_checks) >= 4

    def test_partitioner_default_no_constraints(self):
        from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner
        spec = AxonCompileSpec(model_name="test")
        partitioner = AxonPartitioner(spec)
        assert len(partitioner.additional_checks) == 0
diff --git a/backends/nordic/test/test_tosa_lowering.py b/backends/nordic/test/test_tosa_lowering.py
new file mode 100644
index 00000000000..96e2e932aac
--- /dev/null
+++ b/backends/nordic/test/test_tosa_lowering.py
@@ -0,0 +1,377 @@
# Copyright (c) 2026 iote.ai
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for the Nordic AXON backend — TOSA lowering stage.

These tests validate that PyTorch models lower correctly through the
TOSA pipeline and produce valid AXON layer descriptors. No Nordic SDK
is required — these run on any machine with ExecuTorch and PyTorch.
+""" +from __future__ import annotations + +import pytest +import torch +import torch.nn as nn + + +class TestAxonImports: + """Verify the backend package imports correctly.""" + + def test_lazy_imports(self): + from executorch.backends.nordic import ( + AxonBackend, + AxonCompileSpec, + AxonPartitioner, + ) + assert AxonBackend is not None + assert AxonCompileSpec is not None + assert AxonPartitioner is not None + + def test_direct_imports(self): + from executorch.backends.nordic.axon import AxonBackend + from executorch.backends.nordic.axon.compile_spec import AxonCompileSpec + from executorch.backends.nordic.axon.partitioner import AxonPartitioner + assert AxonBackend.__name__ == "AxonBackend" + + def test_operator_support(self): + from executorch.backends.nordic.operator_support import ( + AXON_SUPPORTED_OPS, + AXON_FUSED_ACTIVATIONS, + AXON_OP_EXTENSIONS, + check_fully_connected, + check_conv2d, + check_pooling, + ) + assert "fully_connected" in AXON_SUPPORTED_OPS + assert "conv2d" in AXON_SUPPORTED_OPS + assert "pointwise_conv2d" in AXON_SUPPORTED_OPS + assert "channel_padding" in AXON_SUPPORTED_OPS + assert len(AXON_SUPPORTED_OPS) >= 12 + assert "relu" in AXON_FUSED_ACTIVATIONS + assert "sigmoid" in AXON_OP_EXTENSIONS + assert "tanh" in AXON_OP_EXTENSIONS + assert "softmax" in AXON_OP_EXTENSIONS + + def test_quantizer_import(self): + from executorch.backends.nordic.axon import AxonQuantizer + q = AxonQuantizer() + assert q is not None + + def test_quantizer_lazy_import(self): + from executorch.backends.nordic import AxonQuantizer + assert AxonQuantizer is not None + + +class TestOperatorConstraints: + """Validate AXON hardware constraint checks.""" + + def test_fc_within_limits(self): + from executorch.backends.nordic.operator_support import check_fully_connected + ok, msg = check_fully_connected(128, 64) + assert ok is True + + def test_fc_max_input(self): + from executorch.backends.nordic.operator_support import check_fully_connected + ok, msg = 
check_fully_connected(2048, 64) + assert ok is True + + def test_fc_exceeds_input(self): + from executorch.backends.nordic.operator_support import check_fully_connected + ok, msg = check_fully_connected(4096, 64) + assert ok is False + assert "4096" in msg + + def test_fc_exceeds_output(self): + from executorch.backends.nordic.operator_support import check_fully_connected + ok, msg = check_fully_connected(128, 4096) + assert ok is False + + def test_conv2d_within_limits(self): + from executorch.backends.nordic.operator_support import check_conv2d + ok, msg = check_conv2d(3, 3, 1, 1, 32) + assert ok is True + + def test_conv2d_exceeds_filter(self): + from executorch.backends.nordic.operator_support import check_conv2d + ok, msg = check_conv2d(32, 32, 1, 1, 32) + assert ok is False + + def test_pooling_within_limits(self): + from executorch.backends.nordic.operator_support import check_pooling + ok, msg = check_pooling(2, 2) + assert ok is True + + def test_pooling_exceeds_filter(self): + from executorch.backends.nordic.operator_support import check_pooling + ok, msg = check_pooling(64, 64) + assert ok is False + + def test_tensor_dims_within_limits(self): + from executorch.backends.nordic.operator_support import check_tensor_dimensions + ok, msg = check_tensor_dimensions(512, 512, 64) + assert ok is True + + def test_tensor_dims_at_max(self): + from executorch.backends.nordic.operator_support import check_tensor_dimensions + ok, msg = check_tensor_dimensions(1024, 1024, 1024) + assert ok is True + + def test_tensor_dims_exceeds_height(self): + from executorch.backends.nordic.operator_support import check_tensor_dimensions + ok, msg = check_tensor_dimensions(2048, 512, 64) + assert ok is False + assert "height" in msg + + def test_tensor_dims_exceeds_channels(self): + from executorch.backends.nordic.operator_support import check_tensor_dimensions + ok, msg = check_tensor_dimensions(8, 8, 2048) + assert ok is False + assert "channels" in msg + + def 
test_input_count_valid(self): + from executorch.backends.nordic.operator_support import check_input_count + ok, msg = check_input_count(1) + assert ok is True + ok, msg = check_input_count(2) + assert ok is True + + def test_input_count_exceeds(self): + from executorch.backends.nordic.operator_support import check_input_count + ok, msg = check_input_count(3) + assert ok is False + assert "3" in msg + + def test_conv2d_exceeds_stride(self): + from executorch.backends.nordic.operator_support import check_conv2d + ok, msg = check_conv2d(3, 3, 32, 32, 16) + assert ok is False + assert "stride" in msg + + def test_conv2d_exceeds_channels(self): + from executorch.backends.nordic.operator_support import check_conv2d + ok, msg = check_conv2d(3, 3, 1, 1, 2048) + assert ok is False + assert "channels" in msg + + +class TestCompileSpec: + """Validate AxonCompileSpec serialization.""" + + def test_default_spec(self): + from executorch.backends.nordic.axon import AxonCompileSpec + spec = AxonCompileSpec() + compile_specs = spec.to_compile_specs() + keys = {s.key for s in compile_specs} + assert "tosa_spec" in keys + assert "output_format" in keys + assert "model_name" in keys + + def test_custom_spec(self): + from executorch.backends.nordic.axon import AxonCompileSpec + spec = AxonCompileSpec( + sdk_edge_ai_path="/opt/sdk-edge-ai", + model_name="test_model", + axon_generated_dir="/tmp/generated", + ) + compile_specs = spec.to_compile_specs() + keys = {s.key for s in compile_specs} + assert "sdk_edge_ai_path" in keys + assert "axon_generated_dir" in keys + + def test_spec_without_sdk(self): + from executorch.backends.nordic.axon import AxonCompileSpec + spec = AxonCompileSpec(model_name="no_sdk") + compile_specs = spec.to_compile_specs() + keys = {s.key for s in compile_specs} + assert "sdk_edge_ai_path" not in keys + + +class TestCodegen: + """Validate codegen utilities.""" + + def test_make_marker(self): + from executorch.backends.nordic.axon.codegen import make_marker + 
marker = make_marker("test_model_abc123") + assert marker[:4] == b"AXNG" + assert len(marker) % 4 == 0 + + def test_derive_subgraph_name(self): + from executorch.backends.nordic.axon.codegen import derive_subgraph_name + name1 = derive_subgraph_name("model", b"binary_data_1") + name2 = derive_subgraph_name("model", b"binary_data_2") + name3 = derive_subgraph_name("model", b"binary_data_1") + # Different content → different names + assert name1 != name2 + # Same content → same name + assert name1 == name3 + # Starts with prefix + assert name1.startswith("model_") + + def test_rewrite_header_symbols(self): + from executorch.backends.nordic.axon.codegen import rewrite_header_symbols + header = 'const int model_old_name = 1;\n.model_name = "old_name"' + result = rewrite_header_symbols(header, "old_name", "new_name") + assert "model_new_name" in result + assert '.model_name = "new_name"' in result + + def test_rewrite_op_extension_symbols(self): + from executorch.backends.nordic.axon.codegen import rewrite_op_extension_symbols + header = "extern void nrf_axon_nn_op_extension_sigmoid(void);" + result = rewrite_op_extension_symbols(header) + assert "axon_op_extension_sigmoid" in result + assert "nrf_axon_nn_op_extension_sigmoid" not in result + + def test_write_and_regenerate(self, tmp_path): + from executorch.backends.nordic.axon.codegen import ( + write_subgraph_header, + regenerate_table, + clean_generated_dir, + ) + # Write two subgraph headers + write_subgraph_header(tmp_path, "sub_aaa", "/* header A */\n") + write_subgraph_header(tmp_path, "sub_bbb", "/* header B */\n") + # Regenerate table + table_path = regenerate_table(tmp_path) + assert table_path.exists() + content = table_path.read_text() + assert "AXON_SUBGRAPHS_COUNT 2" in content + assert '"sub_aaa"' in content + assert '"sub_bbb"' in content + # Clean + removed = clean_generated_dir(tmp_path) + assert removed == 3 # 2 subgraph headers + 1 table + + +class TestTosaLowering: + """Test TOSA lowering for 
simple models. + + These tests export a simple PyTorch model through ExecuTorch's + edge lowering and TOSA conversion pipeline. They validate that + the AXON backend can process the TOSA flatbuffer into AXON layer + descriptors without requiring the Nordic compiler. + """ + + def _export_to_tosa(self, model, example_input): + """Export a model through the AXON backend, returning the TOSA + flatbuffer from the first delegated subgraph. + + Uses the full edge-lower pipeline (same as real deployment) so + the ARM pass pipeline handles quantized weight decomposition. + """ + import tempfile + from executorch.backends.arm.tosa.specification import TosaSpecification + from executorch.backends.arm.quantizer import ( + EthosUQuantizer, + get_symmetric_quantization_config, + ) + from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig + + model.eval() + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + + # Quantize + exported = torch.export.export(model, example_input) + quantizer = EthosUQuantizer(tosa_spec).set_global( + get_symmetric_quantization_config(is_per_channel=True) + ) + prepared = prepare_pt2e(exported.module(), quantizer) + prepared(*example_input) + quantized = convert_pt2e(prepared) + re_exported = torch.export.export(quantized, example_input) + + # Edge lower with AXON partitioner (no SDK needed — returns marker only) + # Use unique model name to avoid TOSA debug file collisions between tests + if not hasattr(self, '_tosa_test_counter'): + type(self)._tosa_test_counter = 0 + type(self)._tosa_test_counter += 1 + model_name = f"tosatest_{self._tosa_test_counter}" + + compile_spec = AxonCompileSpec(model_name=model_name) + partitioner = AxonPartitioner(compile_spec) + edge = to_edge_transform_and_lower( + re_exported, + partitioner=[partitioner], + 
compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + # Read the debug TOSA flatbuffer that AxonBackend.preprocess writes + import os + tosa_path = os.path.join(tempfile.gettempdir(), f"axon_tosa_debug_{model_name}.tosa") + if os.path.exists(tosa_path): + with open(tosa_path, "rb") as f: + return f.read() + + pytest.skip("TOSA debug file not generated — backend may have skipped") + + def test_simple_linear_to_tosa(self): + """A simple linear layer lowers to TOSA successfully.""" + model = nn.Sequential(nn.Linear(16, 8)) + example_input = (torch.randn(1, 16),) + tosa_bytes = self._export_to_tosa(model, example_input) + assert len(tosa_bytes) > 0 + + # Parse the TOSA flatbuffer + from executorch.backends.nordic.tosa_reader import parse_tosa_flatbuffer + graph = parse_tosa_flatbuffer(tosa_bytes) + assert len(graph.operators) > 0 + op_names = [op.op_name for op in graph.get_non_const_operators()] + # TOSA represents FC as CONV2D with reshapes + assert "CONV2D" in op_names or "FULLY_CONNECTED" in op_names + + def test_linear_to_axon_layers(self): + """A simple linear layer converts to AXON layers.""" + model = nn.Sequential(nn.Linear(16, 8)) + example_input = (torch.randn(1, 16),) + tosa_bytes = self._export_to_tosa(model, example_input) + + from executorch.backends.nordic.tosa_reader import parse_tosa_flatbuffer + from executorch.backends.nordic.axon_compiler import tosa_to_axon_layers + + graph = parse_tosa_flatbuffer(tosa_bytes) + layers = tosa_to_axon_layers(graph) + assert len(layers) >= 1 + # AXON compiler converts TOSA ops to AXON layer descriptors + # FC may appear as FULLY_CONNECTED (0) or CONV2D (1) depending on TOSA lowering + from executorch.backends.nordic.axon_compiler import AxonOp + compute_layers = [l for l in layers if l.operation in ( + AxonOp.FULLY_CONNECTED, AxonOp.CONV2D, AxonOp.POINTWISE_CONV2D, + )] + assert len(compute_layers) >= 1 + + def test_conv2d_to_axon_layers(self): + """A Conv2d layer converts to AXON layers.""" + model = 
nn.Sequential( + nn.Conv2d(1, 4, kernel_size=3, padding=1), + nn.ReLU(), + ) + example_input = (torch.randn(1, 1, 8, 8),) + tosa_bytes = self._export_to_tosa(model, example_input) + + from executorch.backends.nordic.tosa_reader import parse_tosa_flatbuffer + from executorch.backends.nordic.axon_compiler import tosa_to_axon_layers + + graph = parse_tosa_flatbuffer(tosa_bytes) + layers = tosa_to_axon_layers(graph) + assert len(layers) >= 1 + + def test_binary_builder(self): + """AXON binary builder produces non-empty output.""" + model = nn.Sequential(nn.Linear(16, 8)) + example_input = (torch.randn(1, 16),) + tosa_bytes = self._export_to_tosa(model, example_input) + + from executorch.backends.nordic.tosa_reader import parse_tosa_flatbuffer + from executorch.backends.nordic.axon_compiler import tosa_to_axon_layers + from executorch.backends.nordic.axon_binary import AxonBinaryBuilder + + graph = parse_tosa_flatbuffer(tosa_bytes) + layers = tosa_to_axon_layers(graph) + builder = AxonBinaryBuilder() + binary = builder.build(layers, model_name="test_linear") + assert len(binary) > 100 # Header alone is ~100 bytes + # Verify it contains the title string + assert b"AXON_INTERMEDIATE_REPRESENTATION_FILE" in binary diff --git a/backends/nordic/tosa_reader.py b/backends/nordic/tosa_reader.py new file mode 100644 index 00000000000..ff35ff449cf --- /dev/null +++ b/backends/nordic/tosa_reader.py @@ -0,0 +1,367 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Based on analysis of Vela's tosa_reader.py (ARM, Apache 2.0). +"""TOSA flatbuffer reader for the AXON backend. + +Parses a TOSA flatbuffer (as produced by ExecuTorch's TOSABackend._preprocess) +and extracts operators, tensors, weights, and quantization parameters into +a simple graph representation that can be converted to AXON layer descriptors. 
+""" + +from __future__ import annotations + +import logging +import numpy as np +from dataclasses import dataclass, field +from enum import IntEnum + +logger = logging.getLogger(__name__) +from typing import Any + +from tosa import TosaGraph as TG, Op +from tosa import Attribute as TosaAttribute +from tosa import ( + ClampAttribute, + ConcatAttribute, + Conv2dAttribute, + DepthwiseConv2dAttribute, + AvgPool2dAttribute, + MaxPool2dAttribute, + MulAttribute, + ReduceSumAttribute, + RescaleAttribute, +) + + +# TOSA DType enum → numpy dtype +TOSA_DTYPE_TO_NUMPY = { + 2: np.uint8, # UINT8 + 3: np.int8, # INT8 + 4: np.int8, # INT8 (alternate) + 5: np.int16, # INT16 + 6: np.int32, # INT32 + 7: np.int64, # INT48 (stored as int64) + 12: np.float32, # FP32 +} + +TOSA_DTYPE_NAMES = { + 2: "uint8", 3: "int8", 4: "int8", 5: "int16", 6: "int32", + 7: "int48", 10: "fp16", 12: "fp32", 14: "bool", +} + +# Build reverse map: Op enum value → name +TOSA_OP_NAMES = {} +for _attr in dir(Op.Op): + _val = getattr(Op.Op, _attr) + if isinstance(_val, int) and not _attr.startswith("_"): + TOSA_OP_NAMES[_val] = _attr + + +@dataclass +class TosaTensor: + """A tensor in the TOSA graph.""" + index: int + name: str + shape: list[int] + dtype: int # TOSA DType enum value + data: np.ndarray | None = None # Constant data (interpreted per dtype) + raw_bytes: bytes = b"" # Raw constant bytes (for multi-byte reinterpretation) + + @property + def dtype_name(self) -> str: + return TOSA_DTYPE_NAMES.get(self.dtype, f"unknown({self.dtype})") + + @property + def has_data(self) -> bool: + return self.data is not None + + @property + def numel(self) -> int: + result = 1 + for s in self.shape: + result *= s + return result + + def __repr__(self): + data_str = f", data={self.data.shape}" if self.has_data else "" + return f"TosaTensor({self.name}, shape={self.shape}, dtype={self.dtype_name}{data_str})" + + +@dataclass +class TosaOperator: + """An operator in the TOSA graph.""" + index: int + op_type: int # TOSA 
Op enum value + input_tensors: list[TosaTensor] + output_tensors: list[TosaTensor] + attributes: dict[str, Any] = field(default_factory=dict) + + @property + def op_name(self) -> str: + return TOSA_OP_NAMES.get(self.op_type, f"Unknown({self.op_type})") + + def __repr__(self): + ins = [t.name.split("/")[-1][:25] for t in self.input_tensors] + outs = [t.name.split("/")[-1][:25] for t in self.output_tensors] + return f"TosaOperator({self.op_name}, in={ins}, out={outs})" + + +@dataclass +class TosaGraph: + """Parsed TOSA graph.""" + tensors: list[TosaTensor] + operators: list[TosaOperator] + input_tensor_indices: list[int] # Indices into tensors[] for graph inputs + output_tensor_indices: list[int] # Indices into tensors[] for graph outputs + + def get_non_const_operators(self) -> list[TosaOperator]: + """Return operators that aren't CONST or CONST_SHAPE.""" + return [ + op for op in self.operators + if op.op_name not in ("CONST", "CONST_SHAPE") + ] + + def print_summary(self): + """Log a human-readable summary of the graph.""" + logger.info("TOSA Graph: %d tensors, %d operators", + len(self.tensors), len(self.operators)) + logger.info(" Inputs: %s", + [self.tensors[i].name for i in self.input_tensor_indices]) + logger.info(" Outputs: %s", + [self.tensors[i].name for i in self.output_tensor_indices]) + for op in self.operators: + if op.op_name in ("CONST", "CONST_SHAPE"): + continue + logger.info(" %s:", op.op_name) + for t in op.input_tensors: + prefix = " [const]" if t.has_data else " " + logger.info(" in: %s %-40s %-15s %s", + prefix, t.name.split("/")[-1][:40], + str(t.shape), t.dtype_name) + for t in op.output_tensors: + logger.info(" out: %-40s %-15s %s", + t.name.split("/")[-1][:40], + str(t.shape), t.dtype_name) + + +def _parse_conv2d_attrs(fb_op) -> dict[str, Any]: + """Extract pad, stride, dilation from TOSA Conv2dAttribute.""" + attr = Conv2dAttribute.Conv2dAttribute() + attr.Init(fb_op.Attribute().Bytes, fb_op.Attribute().Pos) + pad = [attr.Pad(i) for i in 
range(attr.PadLength())] if attr.PadLength() else [] + stride = [attr.Stride(i) for i in range(attr.StrideLength())] if attr.StrideLength() else [1, 1] + dilation = [attr.Dilation(i) for i in range(attr.DilationLength())] if attr.DilationLength() else [1, 1] + return {"pad": pad, "stride": stride, "dilation": dilation} + + +def _parse_depthwise_conv2d_attrs(fb_op) -> dict[str, Any]: + """Extract pad, stride, dilation from TOSA DepthwiseConv2dAttribute.""" + attr = DepthwiseConv2dAttribute.DepthwiseConv2dAttribute() + attr.Init(fb_op.Attribute().Bytes, fb_op.Attribute().Pos) + pad = [attr.Pad(i) for i in range(attr.PadLength())] if attr.PadLength() else [] + stride = [attr.Stride(i) for i in range(attr.StrideLength())] if attr.StrideLength() else [1, 1] + dilation = [attr.Dilation(i) for i in range(attr.DilationLength())] if attr.DilationLength() else [1, 1] + return {"pad": pad, "stride": stride, "dilation": dilation} + + +def _parse_avg_pool2d_attrs(fb_op) -> dict[str, Any]: + """Extract kernel, pad, stride from TOSA AvgPool2dAttribute.""" + attr = AvgPool2dAttribute.AvgPool2dAttribute() + attr.Init(fb_op.Attribute().Bytes, fb_op.Attribute().Pos) + kernel = [attr.Kernel(i) for i in range(attr.KernelLength())] if attr.KernelLength() else [] + pad = [attr.Pad(i) for i in range(attr.PadLength())] if attr.PadLength() else [] + stride = [attr.Stride(i) for i in range(attr.StrideLength())] if attr.StrideLength() else [1, 1] + return {"kernel": kernel, "pad": pad, "stride": stride} + + +def _parse_max_pool2d_attrs(fb_op) -> dict[str, Any]: + """Extract kernel, pad, stride from TOSA MaxPool2dAttribute.""" + attr = MaxPool2dAttribute.MaxPool2dAttribute() + attr.Init(fb_op.Attribute().Bytes, fb_op.Attribute().Pos) + kernel = [attr.Kernel(i) for i in range(attr.KernelLength())] if attr.KernelLength() else [] + pad = [attr.Pad(i) for i in range(attr.PadLength())] if attr.PadLength() else [] + stride = [attr.Stride(i) for i in range(attr.StrideLength())] if attr.StrideLength() 
else [1, 1] + return {"kernel": kernel, "pad": pad, "stride": stride} + + +def _parse_reduce_sum_attrs(fb_op) -> dict[str, Any]: + """Extract axis from TOSA ReduceSumAttribute.""" + attr = ReduceSumAttribute.ReduceSumAttribute() + attr.Init(fb_op.Attribute().Bytes, fb_op.Attribute().Pos) + return {"axis": attr.Axis()} + + +def _parse_concat_attrs(fb_op) -> dict[str, Any]: + """Extract axis from TOSA ConcatAttribute.""" + attr = ConcatAttribute.ConcatAttribute() + attr.Init(fb_op.Attribute().Bytes, fb_op.Attribute().Pos) + return {"axis": attr.Axis()} + + +def _parse_clamp_attrs(fb_op) -> dict[str, Any]: + """Extract min/max values from TOSA ClampAttribute. + + MinVal/MaxVal are stored as raw bytes; for INT8 quantized models + they represent the integer clamp bounds. + """ + attr = ClampAttribute.ClampAttribute() + attr.Init(fb_op.Attribute().Bytes, fb_op.Attribute().Pos) + min_val = list(attr.MinValAsNumpy()) if not attr.MinValIsNone() and attr.MinValLength() > 0 else [] + max_val = list(attr.MaxValAsNumpy()) if not attr.MaxValIsNone() and attr.MaxValLength() > 0 else [] + # Interpret as int8 for quantized models + min_int = int(np.frombuffer(bytes(min_val), dtype=np.int8)[0]) if len(min_val) >= 1 else -128 + max_int = int(np.frombuffer(bytes(max_val), dtype=np.int8)[0]) if len(max_val) >= 1 else 127 + return {"min_int": min_int, "max_int": max_int} + + +# Map TOSA attribute type enum → parser function +_ATTR_PARSERS: dict[int, Any] = { + TosaAttribute.Attribute.Conv2dAttribute: _parse_conv2d_attrs, + TosaAttribute.Attribute.DepthwiseConv2dAttribute: _parse_depthwise_conv2d_attrs, + TosaAttribute.Attribute.AvgPool2dAttribute: _parse_avg_pool2d_attrs, + TosaAttribute.Attribute.MaxPool2dAttribute: _parse_max_pool2d_attrs, + TosaAttribute.Attribute.ReduceSumAttribute: _parse_reduce_sum_attrs, + TosaAttribute.Attribute.ClampAttribute: _parse_clamp_attrs, + TosaAttribute.Attribute.ConcatAttribute: _parse_concat_attrs, +} + + +def _parse_operator_attributes(fb_op) -> 
dict[str, Any]: + """Parse operator-specific attributes from the TOSA flatbuffer.""" + attr_type = fb_op.AttributeType() + if attr_type == TosaAttribute.Attribute.NONE or fb_op.Attribute() is None: + return {} + parser = _ATTR_PARSERS.get(attr_type) + if parser is not None: + return parser(fb_op) + return {} + + +def parse_tosa_flatbuffer(tosa_bytes: bytes) -> TosaGraph: + """Parse a TOSA flatbuffer into a TosaGraph. + + Args: + tosa_bytes: Raw TOSA flatbuffer bytes (from TOSABackend._preprocess). + + Returns: + TosaGraph with tensors and operators. + """ + graph = TG.TosaGraph.GetRootAs(tosa_bytes, 0) + + if graph.RegionsLength() == 0: + raise ValueError("TOSA graph has no regions") + + region = graph.Regions(0) + if region.BlocksLength() == 0: + raise ValueError("TOSA region has no blocks") + + block = region.Blocks(0) + + # Parse tensors + tensors = [] + for t in range(block.TensorsLength()): + fb_tensor = block.Tensors(t) + name = fb_tensor.Name().decode() if fb_tensor.Name() else f"tensor_{t}" + shape = [fb_tensor.Shape(i) for i in range(fb_tensor.ShapeLength())] + dtype = fb_tensor.Type() + + # Extract constant data + data = None + raw_bytes = b"" + data_len = fb_tensor.DataLength() + if data_len > 0: + raw = fb_tensor.DataAsNumpy() + raw_bytes = bytes(raw) + np_dtype = TOSA_DTYPE_TO_NUMPY.get(dtype) + if np_dtype is not None: + values = np.frombuffer(raw_bytes, dtype=np_dtype) + numel = 1 + for s in shape: + numel *= s + if len(values) >= numel: + data = values[:numel].reshape(shape) + else: + data = values # Can't reshape, store flat + + tensors.append(TosaTensor( + index=t, + name=name, + shape=shape, + dtype=dtype, + data=data, + raw_bytes=raw_bytes, + )) + + # Build tensor lookup by name (for resolving operator inputs/outputs) + tensor_by_name = {t.name: t for t in tensors} + + # Parse operators + operators = [] + for o in range(block.OperatorsLength()): + fb_op = block.Operators(o) + op_type = fb_op.Op() + + # Resolve input/output tensors + # TOSA 
flatbuffer stores tensor references as names (bytes) + input_tensors = [] + for i in range(fb_op.InputsLength()): + tensor_name = fb_op.Inputs(i) + if isinstance(tensor_name, bytes): + tensor_name = tensor_name.decode() + if tensor_name in tensor_by_name: + input_tensors.append(tensor_by_name[tensor_name]) + + output_tensors = [] + for i in range(fb_op.OutputsLength()): + tensor_name = fb_op.Outputs(i) + if isinstance(tensor_name, bytes): + tensor_name = tensor_name.decode() + if tensor_name in tensor_by_name: + output_tensors.append(tensor_by_name[tensor_name]) + + # Deserialize operator attributes (padding, stride, dilation, kernel, etc.) + attributes = _parse_operator_attributes(fb_op) + + operators.append(TosaOperator( + index=o, + op_type=op_type, + input_tensors=input_tensors, + output_tensors=output_tensors, + attributes=attributes, + )) + + # Identify graph inputs (tensors with no data and no producing CONST op) + const_output_names = set() + for op in operators: + if TOSA_OP_NAMES.get(op.op_type) in ("CONST", "CONST_SHAPE"): + for t in op.output_tensors: + const_output_names.add(t.name) + + input_indices = [] + output_indices = [] + for t in tensors: + if not t.has_data and t.name not in const_output_names: + # Check if this tensor is an input to any operator but not an output + is_op_output = any( + t in op.output_tensors + for op in operators + if TOSA_OP_NAMES.get(op.op_type) not in ("CONST", "CONST_SHAPE") + ) + if not is_op_output: + input_indices.append(t.index) + + # Graph outputs are the outputs of the last non-CONST operator + if operators: + last_op = operators[-1] + output_indices = [t.index for t in last_op.output_tensors] + + return TosaGraph( + tensors=tensors, + operators=operators, + input_tensor_indices=input_indices, + output_tensor_indices=output_indices, + ) diff --git a/examples/nordic/README.md b/examples/nordic/README.md new file mode 100644 index 00000000000..8185fcccaa2 --- /dev/null +++ b/examples/nordic/README.md @@ -0,0 +1,48 @@ 
+# Nordic AXON NPU Examples
+
+Examples demonstrating ExecuTorch deployment on Nordic Semiconductor's
+AXON NPU (nRF54LM20B). Each example builds progressively on the
+previous one.
+
+## Examples
+
+| Example | What it demonstrates | AXON subgraphs |
+|---------|---------------------|----------------|
+| [hello_axon](hello_axon/) | Basic AXON delegation: single FC model, export, build, flash | 1 |
+| [multi_layer](multi_layer/) | Layer chaining: AXON compiler combines multiple layers into one command buffer | 1 |
+| [simple_rnn](simple_rnn/) | Multi-subgraph delegation: FC layers separated by CPU ops (tanh) produce separate command buffers | 2 |
+
+**Start with `hello_axon`** — it has the most detailed README with
+setup instructions, Python environment explanation, and step-by-step
+walkthrough.
+
+## Prerequisites
+
+- **nRF54LM20DK** development kit
+- **nRF Connect SDK (NCS)** with Zephyr
+- **Nordic sdk-edge-ai** — set `SDK_EDGE_AI_PATH`
+- **uv** — Python package manager (`pip install uv`)
+
+## General workflow
+
+Each example follows the same pattern (replace `<example>` with the
+example directory, e.g. `hello_axon`):
+
+```bash
+cd examples/nordic/<example>
+
+# 1. Set up Python export environment (one-time)
+./setup_export_env.sh
+
+# 2. Export model (trains, quantizes, compiles to AXON command buffers)
+SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh
+
+# 3. Build firmware (in a new terminal with NCS toolchain,
+#    from the executorch root directory)
+source ~/ncs-workspace/nrf-connect-sdk-env.sh
+cd <executorch-root>
+west build -b nrf54lm20dk/nrf54lm20b/cpuapp examples/nordic/<example> \
+  --no-sysbuild -- \
+  -DZEPHYR_EXTRA_MODULES="$(pwd);$SDK_EDGE_AI_PATH"
+
+# 4.
Flash and check serial output +west flash +``` diff --git a/examples/nordic/hello_axon/.gitignore b/examples/nordic/hello_axon/.gitignore new file mode 100644 index 00000000000..066c18181ff --- /dev/null +++ b/examples/nordic/hello_axon/.gitignore @@ -0,0 +1,21 @@ +# Python venv (created by setup_export_env.sh) +.venv/ + +# Build artifacts +build/ + +# Generated model header (created by export_model.py) +src/model_pte.h + +# Generated AXON headers (created by export_model.py) +src/generated/axon_subgraph_*.h +src/generated/axon_subgraphs_table.h + +# uv lock file +uv.lock + +# Generated export wrapper +run_export.sh + +__pycache__/ +*.pyc diff --git a/examples/nordic/hello_axon/CMakeLists.txt b/examples/nordic/hello_axon/CMakeLists.txt new file mode 100644 index 00000000000..75e13b48e57 --- /dev/null +++ b/examples/nordic/hello_axon/CMakeLists.txt @@ -0,0 +1,97 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Hello AXON — minimal ExecuTorch + AXON NPU inference. 
+# +# Build: +# source nrf-connect-sdk-env.sh +# west build -b nrf54lm20dk/nrf54lm20b/cpuapp examples/nordic/hello_axon \ +# --no-sysbuild -- \ +# -DZEPHYR_EXTRA_MODULES=";" + +cmake_minimum_required(VERSION 3.24) + +# Skip install rules — avoids ExecuTorch export dependency issues with Zephyr +set(CMAKE_SKIP_INSTALL_RULES ON CACHE BOOL "" FORCE) + +find_package(Zephyr REQUIRED HINTS $ENV{ZEPHYR_BASE}) +project(hello_axon) + +# Source files: main.c (pure C entry) + inference.cpp (C++ ExecuTorch runner) +target_sources(app PRIVATE + src/main.c + src/inference.cpp +) + +# AXON delegate (from our ExecuTorch fork's backends/nordic/runtime/) +if(CONFIG_NRF_AXON) + if(DEFINED ZEPHYR_EXECUTORCH_MODULE_DIR) + target_sources(app PRIVATE + ${ZEPHYR_EXECUTORCH_MODULE_DIR}/backends/nordic/runtime/AxonBackend.cpp + ${ZEPHYR_EXECUTORCH_MODULE_DIR}/backends/nordic/runtime/axon_op_extensions.c + ) + endif() +endif() + +# Include app source dirs + generated headers +target_include_directories(app PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CMAKE_CURRENT_SOURCE_DIR}/src/generated +) + +# Stub axon_subgraphs_table.h if not generated yet +file(MAKE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/generated) +if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/src/generated/axon_subgraphs_table.h) + file(WRITE ${CMAKE_CURRENT_SOURCE_DIR}/src/generated/axon_subgraphs_table.h +"/* Stub — run export_model.py to generate real AXON subgraph headers. 
*/\n" +"#pragma once\n" +"#include \"axon/nrf_axon_platform.h\"\n" +"#include \"drivers/axon/nrf_axon_nn_infer.h\"\n" +"#define NRF_AXON_MODEL_ALLOCATE_PACKED_OUTPUT_BUFFER 1\n" +"typedef struct {\n" +" const char *name;\n" +" const nrf_axon_nn_compiled_model_s *model;\n" +"} axon_subgraph_entry_t;\n" +"#define AXON_SUBGRAPHS_COUNT 0\n" +"static const axon_subgraph_entry_t axon_subgraphs[1] = {{0}};\n") +endif() + +# ExecuTorch setup +if(CONFIG_EXECUTORCH) + if(NOT DEFINED EXECUTORCH_DIR) + if(DEFINED ZEPHYR_EXECUTORCH_MODULE_DIR) + set(EXECUTORCH_DIR ${ZEPHYR_EXECUTORCH_MODULE_DIR}) + endif() + endif() + + set(EXECUTORCH_ROOT ${EXECUTORCH_DIR}) + include(${EXECUTORCH_DIR}/tools/cmake/Utils.cmake) + + # Kernel registry: 27 prim + 2 selective + 19 quantized = 48 → use 64 + target_compile_definitions(app PRIVATE MAX_KERNEL_NUM=64) + if(TARGET executorch_core) + target_compile_definitions(executorch_core PRIVATE MAX_KERNEL_NUM=64) + endif() + + # Link ExecuTorch (provides include paths + core runtime) + target_link_libraries(app PRIVATE libexecutorch) + + # Portable kernels (for un-delegated ops) + if(TARGET portable_kernels) + executorch_target_link_options_shared_lib(portable_kernels) + target_link_libraries(app PRIVATE portable_kernels) + endif() + + # Quantized ops (for q/dq at AXON delegation boundaries) + if(TARGET quantized_ops_lib) + executorch_target_link_options_shared_lib(quantized_ops_lib) + target_link_libraries(app PRIVATE quantized_ops_lib) + endif() + if(TARGET quantized_kernels) + executorch_target_link_options_shared_lib(quantized_kernels) + target_link_libraries(app PRIVATE quantized_kernels) + endif() +endif() diff --git a/examples/nordic/hello_axon/README.md b/examples/nordic/hello_axon/README.md new file mode 100644 index 00000000000..aba93b1e1ab --- /dev/null +++ b/examples/nordic/hello_axon/README.md @@ -0,0 +1,168 @@ +# Hello AXON — ExecuTorch + Nordic AXON NPU + +Minimal example: train a PyTorch model, compile it for the AXON NPU, +and run 
inference on the nRF54LM20DK. + +## What it does + +1. Trains a small 3-layer FC model to approximate sin(x) +2. Quantizes to INT8 via ExecuTorch PT2E +3. Delegates FC layers to the AXON NPU backend +4. Exports as `.pte` + AXON command buffer headers +5. Builds Zephyr firmware with the model embedded +6. Runs inference on the nRF54LM20DK — AXON NPU executes the FC layers + +## Two Python environments + +This example uses **two separate Python environments** for different +stages. This is necessary because the nRF Connect SDK (NCS) ships its +own Python (3.12) with its own `PYTHONHOME` and `PYTHONPATH`, which +conflict with PyTorch and ExecuTorch's Python packages. + +| Stage | Python | Why | +|-------|--------|-----| +| **Model export** (`export_model.py`) | Your own Python (3.10+) via `uv` | Needs PyTorch, ExecuTorch, tosa-tools — packages that don't exist in the NCS Python | +| **Firmware build** (`west build`) | NCS toolchain Python (3.12) | Needs Zephyr's cmake modules and the NCS build system | + +The `setup_export_env.sh` script creates an isolated `.venv/` in this +directory with all export dependencies. It uses +[uv](https://docs.astral.sh/uv/) to manage the environment. The NCS +toolchain Python is used only by `west build` and is activated by +sourcing `nrf-connect-sdk-env.sh`. + +**Important:** Do not source `nrf-connect-sdk-env.sh` in the same +terminal where you run `export_model.py`. The NCS environment sets +`PYTHONHOME` which overrides Python's standard library path and causes +import errors in any non-NCS Python. The `run_export.sh` wrapper +handles this automatically by unsetting `PYTHONHOME` before invoking +`uv`. + +## Prerequisites + +- **nRF54LM20DK** development kit +- **nRF Connect SDK (NCS)** installed — provides `west`, Zephyr, and + the ARM cross-compiler. See [Nordic's install guide](https://docs.nordicsemi.com/bundle/ncs-latest/page/nrf/installation.html). +- **Nordic sdk-edge-ai** — contains the AXON compiler library. 
+ Set `SDK_EDGE_AI_PATH` to its location. +- **uv** — Python package manager. Install with `pip install uv`. + +## Step-by-step + +### 1. Set up the export environment (one time) + +```bash +cd examples/nordic/hello_axon +./setup_export_env.sh +``` + +This creates `.venv/` with PyTorch (CPU), ExecuTorch, tosa-tools, and +torchao. It also generates `run_export.sh` — a wrapper that sets the +correct `PYTHONPATH` for ExecuTorch. + +### 2. Export the model + +```bash +SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh +``` + +This trains sin(x), quantizes to INT8, compiles FC layers to AXON +command buffers, and produces: + +| Output | Description | +|--------|-------------| +| `build/hello_axon.pte` | ExecuTorch program file | +| `src/model_pte.h` | Model embedded as a C array (16-byte aligned) | +| `src/generated/axon_subgraph_*.h` | AXON command buffers per layer | +| `src/generated/axon_subgraphs_table.h` | Delegate lookup table | + +### 3. Build firmware + +Open a **new terminal** (or one unrelated to the export step), then: + +```bash +# Activate the NCS toolchain (provides west, arm-zephyr-eabi-gcc, cmake) +source ~/ncs-workspace/nrf-connect-sdk-env.sh + +# Build from the executorch root directory +cd <executorch-root> +west build -b nrf54lm20dk/nrf54lm20b/cpuapp examples/nordic/hello_axon \ + --no-sysbuild -- \ + -DZEPHYR_EXTRA_MODULES="$(pwd);$SDK_EDGE_AI_PATH" +``` + +### 4. Flash and verify + +```bash +west flash + +# Serial console (115200 baud): +# Linux: screen /dev/ttyACM0 115200 +# macOS: screen /dev/cu.usbmodem* 115200 +``` + +### Expected output + +``` +Hello AXON - ExecuTorch + Nordic AXON NPU +Board: nrf54lm20dk/nrf54lm20b/cpuapp +AXON NPU: enabled +Loading model (2084 bytes)... +Program loaded, 1 method(s) +Method: forward +AxonBackend::init (delegate 0, processed=36 bytes) + AXON model 'hello_axon_...' bound (out: 1x1x1 byte_width=1) +Method loaded +Running inference (x=1.57, expected sin~1.0)... +Inference: 20876 cycles (163 us @ 128 MHz) + output[0] = 0.987485 +Done. 
+``` + +## Architecture + +``` + setup_export_env.sh (Python venv with PyTorch + ExecuTorch) + | + v + run_export.sh (unsets PYTHONHOME, sets PYTHONPATH) + | + v + export_model.py (train → quantize → AXON compile → .pte) + | + +-- build/hello_axon.pte + +-- src/model_pte.h (embedded C array) + +-- src/generated/axon_subgraph_*.h (AXON command buffers) + +-- src/generated/axon_subgraphs_table.h + + nrf-connect-sdk-env.sh (NCS toolchain Python + west + compiler) + | + v + west build (Zephyr + ExecuTorch + AXON delegate) + | + v + zephyr.hex (firmware with model + runtime) + | + v + nRF54LM20DK (AXON NPU executes FC layers) +``` + +## File structure + +``` +hello_axon/ +├── setup_export_env.sh # One-time: create .venv with export deps +├── run_export.sh # Generated: export wrapper (sets PYTHONPATH) +├── export_model.py # Train, quantize, export sin(x) model +├── pyproject.toml # Python project (base deps for uv) +├── CMakeLists.txt # Zephyr firmware build +├── prj.conf # Zephyr project config +├── boards/ +│ └── nrf54lm20dk_...conf # AXON NPU board config +├── src/ +│ ├── main.c # Entry point (pure C) +│ ├── inference.cpp # ExecuTorch runtime (C++) +│ ├── model_pte.h # Generated: embedded .pte +│ └── generated/ # Generated: AXON command buffers +├── build/ # Build output (gitignored) +└── .venv/ # Export Python env (gitignored) +``` diff --git a/examples/nordic/hello_axon/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf b/examples/nordic/hello_axon/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf new file mode 100644 index 00000000000..ede2123520a --- /dev/null +++ b/examples/nordic/hello_axon/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf @@ -0,0 +1,9 @@ +# Board-specific config for nRF54LM20DK (nRF54LM20B with AXON NPU) + +# AXON NPU +CONFIG_NRF_AXON=y +CONFIG_NRF_AXON_INTERLAYER_BUFFER_SIZE=256 +CONFIG_NRF_AXON_PSUM_BUFFER_SIZE=0 + +# RRAM must stay in standby mode for AXON +CONFIG_MPSL_FORCE_RRAM_ON_ALL_THE_TIME=y diff --git a/examples/nordic/hello_axon/export_model.py 
b/examples/nordic/hello_axon/export_model.py new file mode 100644 index 00000000000..8b13f1a7b4d --- /dev/null +++ b/examples/nordic/hello_axon/export_model.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Export a simple FC model for the AXON NPU. + +Trains a 3-layer FC network on sin(x), quantizes to INT8, +partitions to AXON, and exports as .pte + generated headers. + +Usage: + # After running setup_export_env.sh: + PYTHONHOME= SDK_EDGE_AI_PATH=~/sdk-edge-ai uv run python export_model.py +""" +from __future__ import annotations + +import math +import os +from pathlib import Path + +import torch +import torch.nn as nn + + +class SineModel(nn.Module): + """3-layer FC: 1 -> 16 -> 16 -> 1. Approximates sin(x).""" + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(1, 16) + self.fc2 = nn.Linear(16, 16) + self.fc3 = nn.Linear(16, 1) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + return self.fc3(x) + + +def main(): + script_dir = Path(__file__).parent + build_dir = script_dir / "build" + build_dir.mkdir(exist_ok=True) + generated_dir = script_dir / "src" / "generated" + generated_dir.mkdir(parents=True, exist_ok=True) + + sdk_path = os.environ.get("SDK_EDGE_AI_PATH", os.path.expanduser("~/sdk-edge-ai")) + + # 1. Train + print("Training sine model...") + model = SineModel() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + loss_fn = nn.MSELoss() + x_train = torch.linspace(0, 2 * math.pi, 1000).unsqueeze(1) + y_train = torch.sin(x_train) + + model.train() + for epoch in range(1000): + pred = model(x_train) + loss = loss_fn(pred, y_train) + optimizer.zero_grad() + loss.backward() + optimizer.step() + print(f" Final loss: {loss.item():.6f}") + + # 2. 
Quantize + print("Quantizing to INT8...") + from executorch.backends.arm.tosa.specification import TosaSpecification + from executorch.backends.arm.quantizer import ( + EthosUQuantizer, + get_symmetric_quantization_config, + ) + from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e + + model.eval() + example_input = (torch.randn(1, 1),) + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + + exported = torch.export.export(model, example_input, strict=False) + captured = exported.module() + + # Torch 2.11 quirk: export().module() inserts _guards_fn call_module + # nodes that the ExecuTorch pass manager doesn't handle. Strip them. + guard_nodes = [ + n for n in captured.graph.nodes + if n.op == "call_module" and "_guards" in str(n.target) + ] + for n in guard_nodes: + n.replace_all_uses_with(None) + captured.graph.erase_node(n) + for name in list(captured._modules.keys()): + if "_guards" in name: + delattr(captured, name) + captured.graph.lint() + captured.recompile() + + quantizer = EthosUQuantizer(tosa_spec).set_global( + get_symmetric_quantization_config(is_per_channel=True) + ) + prepared = prepare_pt2e(captured, quantizer) + # Calibrate + for _ in range(100): + prepared(torch.rand(1, 1) * 2 * math.pi) + quantized = convert_pt2e(prepared) + re_exported = torch.export.export(quantized, example_input, strict=False) + + # 3. 
Partition to AXON and export + print("Exporting with AXON backend...") + from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner + from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig + + compile_spec = AxonCompileSpec( + sdk_edge_ai_path=sdk_path, + model_name="hello_axon", + axon_generated_dir=str(generated_dir), + ) + partitioner = AxonPartitioner(compile_spec) + + edge = to_edge_transform_and_lower( + re_exported, + partitioner=[partitioner], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + pte_path = build_dir / "hello_axon.pte" + edge.to_executorch().save(str(pte_path)) + print(f" .pte: {pte_path} ({pte_path.stat().st_size} bytes)") + + # 4. Generate C header from .pte (16-byte aligned for ExecuTorch) + model_pte_h = script_dir / "src" / "model_pte.h" + pte_bytes = pte_path.read_bytes() + with open(model_pte_h, "w") as f: + f.write("/* Auto-generated from hello_axon.pte */\n") + f.write("#include <stdint.h>\n\n") + f.write("static const uint8_t model_pte[] __attribute__((aligned(16))) = {\n") + for i, b in enumerate(pte_bytes): + if i % 16 == 0: + f.write(" ") + f.write(f"0x{b:02x},") + if i % 16 == 15: + f.write("\n") + f.write("\n};\n") + f.write(f"static const uint32_t model_pte_len = {len(pte_bytes)};\n") + print(f" C header: {model_pte_h}") + + # List generated AXON headers + headers = list(generated_dir.glob("*.h")) + print(f" Generated {len(headers)} header(s) in {generated_dir}/") + for h in sorted(headers): + print(f" {h.name}") + + print("\nDone. 
Rebuild firmware to embed the model.") + + +if __name__ == "__main__": + main() diff --git a/examples/nordic/hello_axon/prj.conf b/examples/nordic/hello_axon/prj.conf new file mode 100644 index 00000000000..516019714d8 --- /dev/null +++ b/examples/nordic/hello_axon/prj.conf @@ -0,0 +1,37 @@ +# Copyright (c) 2026 iote.ai +# SPDX-License-Identifier: BSD-3-Clause +# +# hello_axon — ExecuTorch + AXON NPU inference + +# Console / UART +CONFIG_CONSOLE=y +CONFIG_UART_CONSOLE=y +CONFIG_SERIAL=y + +# Logging +CONFIG_LOG=y +CONFIG_LOG_DEFAULT_LEVEL=3 +CONFIG_LOG_BACKEND_UART=y +CONFIG_LOG_PRINTK=y +CONFIG_PRINTK=y + +# Timing (for cycle counting) +CONFIG_TIMING_FUNCTIONS=y + +# Float printing +CONFIG_PICOLIBC_IO_FLOAT=y + +# Memory +CONFIG_HEAP_MEM_POOL_SIZE=32768 +CONFIG_MAIN_STACK_SIZE=16384 + +# C++ (required by ExecuTorch) +CONFIG_CPP=y +CONFIG_STD_CPP17=y +CONFIG_REQUIRES_FULL_LIBCPP=y + +# ExecuTorch +CONFIG_EXECUTORCH=y +CONFIG_EXECUTORCH_ENABLE_LOGGING=y +CONFIG_EXECUTORCH_BUILD_PORTABLE_OPS=n +CONFIG_EXECUTORCH_OPTIMIZE_FOR_SIZE=y diff --git a/examples/nordic/hello_axon/pyproject.toml b/examples/nordic/hello_axon/pyproject.toml new file mode 100644 index 00000000000..cf478c16fb2 --- /dev/null +++ b/examples/nordic/hello_axon/pyproject.toml @@ -0,0 +1,15 @@ +# Copyright (c) 2026 iote.ai +# SPDX-License-Identifier: BSD-3-Clause +# +# Python project for hello_axon model export. +# See setup_export_env.sh for one-shot environment setup. + +[project] +name = "hello-axon" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "cffi>=1.15", + "numpy>=2.0", + "pyyaml", +] diff --git a/examples/nordic/hello_axon/setup_export_env.sh b/examples/nordic/hello_axon/setup_export_env.sh new file mode 100755 index 00000000000..64eed59310c --- /dev/null +++ b/examples/nordic/hello_axon/setup_export_env.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# Copyright (c) 2026 iote.ai +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# One-shot setup of the Python environment for model export. +# +# Model export (PyTorch → ExecuTorch → AXON) and firmware build +# (Zephyr + west) use DIFFERENT Python environments: +# +# - Model export needs PyTorch, ExecuTorch, tosa-tools — large ML +# packages that are not part of the NCS toolchain. +# +# - Firmware build uses the NCS toolchain Python, which provides +# west, Zephyr cmake modules, and the ARM cross-compiler. The NCS +# toolchain sets PYTHONHOME and PYTHONPATH to point at its own +# Python 3.12, which breaks imports for any other Python. +# +# This script creates an isolated .venv/ in this directory with the +# export dependencies. It does NOT affect the NCS toolchain or any +# other Python environment on the system. The generated run_export.sh +# wrapper unsets PYTHONHOME/PYTHONPATH before running, so it works +# even if you previously sourced nrf-connect-sdk-env.sh. +# +# ExecuTorch itself is NOT pip-installed (its setup.py triggers a +# heavy cmake build). Instead, ExecuTorch's Python source tree is +# added to PYTHONPATH at runtime — this is sufficient for the export +# pipeline which only uses the Python backend code, not the C++ runtime. +# +# Prerequisites: uv (install with: pip install uv) +# +# Usage: +# cd examples/nordic/hello_axon +# ./setup_export_env.sh +# SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh + +set -e + +# Unset the NCS toolchain's Python environment variables. The NCS +# toolchain (activated by sourcing nrf-connect-sdk-env.sh) sets these +# to point at its bundled Python 3.12. If left set, uv would try to +# use the NCS Python's stdlib, causing "SRE module mismatch" errors +# because uv's Python (3.13) and NCS's stdlib (3.12) are incompatible. +unset PYTHONHOME +unset PYTHONPATH + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +ET_ROOT="$(cd "$SCRIPT_DIR/../../.." 
&& pwd)" + +echo "=== Setting up hello_axon export environment ===" +echo " Directory: $SCRIPT_DIR" +echo " ExecuTorch: $ET_ROOT" + +# Check uv +if ! command -v uv &>/dev/null; then + echo "ERROR: 'uv' not found. Install with: pip install uv" + exit 1 +fi + +cd "$SCRIPT_DIR" + +# Create venv and install base deps from pyproject.toml +echo "" +echo "--- Creating venv and installing base dependencies ---" +uv sync + +# Install torch (CPU variant) — needed by ExecuTorch and the export pipeline +echo "" +echo "--- Installing PyTorch (CPU) ---" +uv pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Install ExecuTorch Python packages. +# We DON'T do `pip install -e` (which triggers a heavy cmake build). +# Instead, we install the Python-only dependencies and add ExecuTorch +# to the path at runtime via PYTHONPATH. +echo "" +echo "--- Installing ExecuTorch dependencies ---" +uv pip install setuptools flatbuffers packaging "ruamel.yaml" tabulate + +# Install tosa-tools and torchao +echo "" +echo "--- Installing tosa-tools and torchao ---" +uv pip install tosa-tools torchao + +# Create a wrapper script that sets PYTHONPATH for export +cat > "$SCRIPT_DIR/run_export.sh" << 'WRAPPER' +#!/bin/bash +# Auto-generated by setup_export_env.sh +# +# Runs export_model.py in the isolated .venv with the correct PYTHONPATH. +# Safe to run even if nrf-connect-sdk-env.sh was sourced in this shell — +# we unset PYTHONHOME/PYTHONPATH to avoid NCS toolchain conflicts. +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +ET_ROOT="$(cd "$SCRIPT_DIR/../../.." 
&& pwd)" +# Clear NCS Python env (PYTHONHOME would redirect stdlib to NCS's Python 3.12) +unset PYTHONHOME +unset PYTHONPATH +# Add ExecuTorch Python source to the path (not pip-installed, see README) +export PYTHONPATH="${ET_ROOT}/src" +exec uv run --directory "$SCRIPT_DIR" python "$SCRIPT_DIR/export_model.py" "$@" +WRAPPER +chmod +x "$SCRIPT_DIR/run_export.sh" + +# Verify +echo "" +echo "--- Verifying installation ---" +PYTHONPATH="${ET_ROOT}/src" uv run python -c " +from executorch.backends.nordic.axon import AxonBackend, AxonQuantizer +print(' ExecuTorch AXON backend: OK') +import tosa +print(' tosa-tools: OK') +import torch +print(f' PyTorch: {torch.__version__}') +print('Setup complete.') +" + +echo "" +echo "=== Done ===" +echo "" +echo "Export a model with:" +echo " cd $SCRIPT_DIR" +echo " SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh" diff --git a/examples/nordic/hello_axon/src/generated/.gitignore b/examples/nordic/hello_axon/src/generated/.gitignore new file mode 100644 index 00000000000..03f0c1f2a2a --- /dev/null +++ b/examples/nordic/hello_axon/src/generated/.gitignore @@ -0,0 +1,3 @@ +# Auto-generated by AXON backend; do not commit. +* +!.gitignore diff --git a/examples/nordic/hello_axon/src/inference.cpp b/examples/nordic/hello_axon/src/inference.cpp new file mode 100644 index 00000000000..321c52cd2f2 --- /dev/null +++ b/examples/nordic/hello_axon/src/inference.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2026 iote.ai + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * ExecuTorch inference runner for hello_axon. + * + * Loads an embedded .pte model, runs inference, prints output. + * If the model has AXON-delegated subgraphs, those run on the NPU + * automatically via the AxonBackend delegate. 
+ */ + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +/* The .pte model embedded as a C array. + * Generated by export_model.py → src/model_pte.h + */ +#include "model_pte.h" + +namespace et = executorch::runtime; +using et::Error; +using et::EValue; +using et::HierarchicalAllocator; +using et::MemoryAllocator; +using et::MemoryManager; +using et::Method; +using et::Program; +using et::Result; +using et::Span; +using executorch::extension::BufferDataLoader; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using exec_aten::TensorImpl; + +/* Static memory pools */ +static uint8_t method_allocator_pool[16 * 1024]; +static uint8_t planned_memory_pool[16 * 1024]; +static uint8_t temp_allocator_pool[4 * 1024]; + +extern "C" int run_inference(void) +{ + if (model_pte_len == 0) { + ET_LOG(Error, "No model embedded. Run: ./setup_export_env.sh && ./run_export.sh"); + return -1; + } + + ET_LOG(Info, "Loading model (%u bytes)...", model_pte_len); + + BufferDataLoader loader(model_pte, model_pte_len); + Result program = Program::load(&loader); + if (!program.ok()) { + ET_LOG(Error, "Program::load failed: 0x%x", + static_cast(program.error())); + return -1; + } + ET_LOG(Info, "Program loaded, %zu method(s)", program->num_methods()); + + const char *method_name = nullptr; + { + auto name_result = program->get_method_name(0); + if (!name_result.ok()) { + ET_LOG(Error, "No methods in program"); + return -2; + } + method_name = *name_result; + } + ET_LOG(Info, "Method: %s", method_name); + + /* Memory management */ + MemoryAllocator method_allocator( + sizeof(method_allocator_pool), method_allocator_pool); + MemoryAllocator temp_allocator( + sizeof(temp_allocator_pool), temp_allocator_pool); + + auto method_meta = program->method_meta(method_name); + if (!method_meta.ok()) { + ET_LOG(Error, "Failed to get method meta"); + return -3; + } + + Span planned_span(planned_memory_pool, + sizeof(planned_memory_pool)); + 
HierarchicalAllocator planned_allocator({&planned_span, 1}); + MemoryManager memory_manager( + &method_allocator, &planned_allocator, &temp_allocator); + + Result method = program->load_method( + method_name, &memory_manager); + if (!method.ok()) { + ET_LOG(Error, "load_method failed: 0x%x", + static_cast(method.error())); + return -4; + } + ET_LOG(Info, "Method loaded"); + + /* Prepare input: single float value */ + float input_data[1] = {1.57f}; /* pi/2 — sin should be ~1.0 */ + Tensor::SizesType input_sizes[] = {1, 1}; + Tensor::DimOrderType input_dim_order[] = {0, 1}; + TensorImpl input_impl( + ScalarType::Float, 2, input_sizes, input_data, input_dim_order); + Tensor input_tensor(&input_impl); + + Error err = method->set_input(input_tensor, 0); + if (err != Error::Ok) { + ET_LOG(Error, "set_input failed: 0x%x", static_cast(err)); + return -5; + } + + /* Run inference. + * Use timing.h API for true CPU cycles — k_cycle_get_32() on this + * board ticks at 1 MHz (system clock), not 128 MHz (DWT). 
*/ + timing_init(); + timing_start(); + + ET_LOG(Info, "Running inference (x=1.57, expected sin~1.0)..."); + + timing_t t_start = timing_counter_get(); + err = method->execute(); + timing_t t_end = timing_counter_get(); + + if (err != Error::Ok) { + ET_LOG(Error, "execute failed: 0x%x", static_cast(err)); + return -6; + } + + uint64_t cycles = timing_cycles_get(&t_start, &t_end); + uint64_t ns = timing_cycles_to_ns(cycles); + ET_LOG(Info, "Inference: %llu cycles (%llu us @ 128 MHz)", + (unsigned long long)cycles, (unsigned long long)(ns / 1000)); + + /* Read output */ + const EValue &output = method->get_output(0); + if (output.isTensor()) { + const auto &out_tensor = output.toTensor(); + if (out_tensor.scalar_type() == ScalarType::Float) { + const float *data = out_tensor.const_data_ptr(); + for (int i = 0; i < out_tensor.numel() && i < 10; i++) { + ET_LOG(Info, " output[%d] = %f", i, + static_cast(data[i])); + } + } + } + + ET_LOG(Info, "Done."); + return 0; +} diff --git a/examples/nordic/hello_axon/src/main.c b/examples/nordic/hello_axon/src/main.c new file mode 100644 index 00000000000..0f2171ccdfd --- /dev/null +++ b/examples/nordic/hello_axon/src/main.c @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2026 iote.ai + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * Hello AXON — minimal ExecuTorch + AXON NPU example. 
+ */ + +#include <zephyr/kernel.h> +#include <zephyr/logging/log.h> + +LOG_MODULE_REGISTER(hello_axon, LOG_LEVEL_INF); + +/* Inference runner implemented in C++ (inference.cpp) */ +extern int run_inference(void); + +int main(void) +{ + LOG_INF("Hello AXON - ExecuTorch + Nordic AXON NPU"); + LOG_INF("Board: %s", CONFIG_BOARD_TARGET); + +#if defined(CONFIG_NRF_AXON) && CONFIG_NRF_AXON + LOG_INF("AXON NPU: enabled"); +#else + LOG_INF("AXON NPU: not available (CPU only)"); +#endif + + int ret = run_inference(); + if (ret != 0) { + LOG_ERR("Inference failed: %d", ret); + } + + return 0; +} diff --git a/examples/nordic/multi_layer/.gitignore b/examples/nordic/multi_layer/.gitignore new file mode 100644 index 00000000000..066c18181ff --- /dev/null +++ b/examples/nordic/multi_layer/.gitignore @@ -0,0 +1,21 @@ +# Python venv (created by setup_export_env.sh) +.venv/ + +# Build artifacts +build/ + +# Generated model header (created by export_model.py) +src/model_pte.h + +# Generated AXON headers (created by export_model.py) +src/generated/axon_subgraph_*.h +src/generated/axon_subgraphs_table.h + +# uv lock file +uv.lock + +# Generated export wrapper +run_export.sh + +__pycache__/ +*.pyc diff --git a/examples/nordic/multi_layer/CMakeLists.txt b/examples/nordic/multi_layer/CMakeLists.txt new file mode 100644 index 00000000000..42587dc3c50 --- /dev/null +++ b/examples/nordic/multi_layer/CMakeLists.txt @@ -0,0 +1,97 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Multi-layer AXON — ExecuTorch + AXON NPU inference. 
+# +# Build: +# source nrf-connect-sdk-env.sh +# west build -b nrf54lm20dk/nrf54lm20b/cpuapp examples/nordic/multi_layer \ +# --no-sysbuild -- \ +# -DZEPHYR_EXTRA_MODULES="<path-to-executorch>;<path-to-sdk-edge-ai>" + +cmake_minimum_required(VERSION 3.24) + +# Skip install rules — avoids ExecuTorch export dependency issues with Zephyr +set(CMAKE_SKIP_INSTALL_RULES ON CACHE BOOL "" FORCE) + +find_package(Zephyr REQUIRED HINTS $ENV{ZEPHYR_BASE}) +project(multi_layer) + +# Source files: main.c (pure C entry) + inference.cpp (C++ ExecuTorch runner) +target_sources(app PRIVATE + src/main.c + src/inference.cpp +) + +# AXON delegate (from our ExecuTorch fork's backends/nordic/runtime/) +if(CONFIG_NRF_AXON) + if(DEFINED ZEPHYR_EXECUTORCH_MODULE_DIR) + target_sources(app PRIVATE + ${ZEPHYR_EXECUTORCH_MODULE_DIR}/backends/nordic/runtime/AxonBackend.cpp + ${ZEPHYR_EXECUTORCH_MODULE_DIR}/backends/nordic/runtime/axon_op_extensions.c + ) + endif() +endif() + +# Include app source dirs + generated headers +target_include_directories(app PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CMAKE_CURRENT_SOURCE_DIR}/src/generated +) + +# Stub axon_subgraphs_table.h if not generated yet +file(MAKE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/generated) +if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/src/generated/axon_subgraphs_table.h) + file(WRITE ${CMAKE_CURRENT_SOURCE_DIR}/src/generated/axon_subgraphs_table.h +"/* Stub — run export_model.py to generate real AXON subgraph headers. 
*/\n" +"#pragma once\n" +"#include \"axon/nrf_axon_platform.h\"\n" +"#include \"drivers/axon/nrf_axon_nn_infer.h\"\n" +"#define NRF_AXON_MODEL_ALLOCATE_PACKED_OUTPUT_BUFFER 1\n" +"typedef struct {\n" +" const char *name;\n" +" const nrf_axon_nn_compiled_model_s *model;\n" +"} axon_subgraph_entry_t;\n" +"#define AXON_SUBGRAPHS_COUNT 0\n" +"static const axon_subgraph_entry_t axon_subgraphs[1] = {{0}};\n") +endif() + +# ExecuTorch setup +if(CONFIG_EXECUTORCH) + if(NOT DEFINED EXECUTORCH_DIR) + if(DEFINED ZEPHYR_EXECUTORCH_MODULE_DIR) + set(EXECUTORCH_DIR ${ZEPHYR_EXECUTORCH_MODULE_DIR}) + endif() + endif() + + set(EXECUTORCH_ROOT ${EXECUTORCH_DIR}) + include(${EXECUTORCH_DIR}/tools/cmake/Utils.cmake) + + # Kernel registry: 27 prim + 2 selective + 19 quantized = 48 → use 64 + target_compile_definitions(app PRIVATE MAX_KERNEL_NUM=64) + if(TARGET executorch_core) + target_compile_definitions(executorch_core PRIVATE MAX_KERNEL_NUM=64) + endif() + + # Link ExecuTorch (provides include paths + core runtime) + target_link_libraries(app PRIVATE libexecutorch) + + # Portable kernels (for un-delegated ops) + if(TARGET portable_kernels) + executorch_target_link_options_shared_lib(portable_kernels) + target_link_libraries(app PRIVATE portable_kernels) + endif() + + # Quantized ops (for q/dq at AXON delegation boundaries) + if(TARGET quantized_ops_lib) + executorch_target_link_options_shared_lib(quantized_ops_lib) + target_link_libraries(app PRIVATE quantized_ops_lib) + endif() + if(TARGET quantized_kernels) + executorch_target_link_options_shared_lib(quantized_kernels) + target_link_libraries(app PRIVATE quantized_kernels) + endif() +endif() diff --git a/examples/nordic/multi_layer/README.md b/examples/nordic/multi_layer/README.md new file mode 100644 index 00000000000..4e7c06ba674 --- /dev/null +++ b/examples/nordic/multi_layer/README.md @@ -0,0 +1,69 @@ +# Multi-Layer AXON — Chained Layers + +Demonstrates a multi-layer model where the AXON compiler chains +multiple FC layers 
into a single command buffer. Also showcases +the AXON delegate profiling API. + +## Model architecture + +``` +input (8-dim) + | + +----> fc_a (8 -> 16, ReLU) ─┐ + | │ + +----> fc_b (8 -> 16, ReLU) ──── multiply ── fc_head (16 -> 4) ── output +``` + +All operations (FC, ReLU, Multiply) are AXON-supported, so the AXON +compiler chains them into a single command buffer. The entire model +executes in one NPU dispatch call. + +## Generated files + +After export, `src/generated/` contains: + +``` +axon_subgraph_multi_layer_.h ← command buffer +axon_subgraphs_table.h ← lookup: name → compiled model +``` + +## Expected output + +``` +Multi-layer AXON — ExecuTorch multi-subgraph delegation +AXON NPU: enabled +Loading model (2084 bytes)... +AxonBackend::init (delegate 0, processed=36 bytes) + AXON model 'multi_layer_...' bound (out: 1x4x1 byte_width=1) +Method loaded (AXON delegates bound: 0) + input[0]: class=3 (-27.799, -6.318, -22.745, 30.326) 213 us + input[1]: class=1 (-26.535, 22.113, -34.117, -14.531) 211 us + input[2]: class=2 (-22.113, -27.167, 24.008, -17.690) 212 us + input[3]: class=0 (30.958, -19.586, -15.795, -37.908) 209 us +=== AXON delegate profile === +handles bound: 1 +total infer cycles: 56254 (4 calls) +avg cycles/call: 14063 +Done. +``` + +Note: `handles bound: 1` — all layers fit in one AXON command buffer. +The AXON delegate profiling shows 14K cycles per inference call. + +## Build and run + +Same pattern as `hello_axon` — see its README for prerequisites. 
+ +```bash +cd examples/nordic/multi_layer +./setup_export_env.sh # one-time +SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh # export model + +# In a new terminal: +source ~/ncs-workspace/nrf-connect-sdk-env.sh +cd <executorch-root> +west build -b nrf54lm20dk/nrf54lm20b/cpuapp examples/nordic/multi_layer \ + --no-sysbuild -- \ + -DZEPHYR_EXTRA_MODULES="$(pwd);$SDK_EDGE_AI_PATH" +west flash +``` diff --git a/examples/nordic/multi_layer/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf b/examples/nordic/multi_layer/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf new file mode 100644 index 00000000000..ede2123520a --- /dev/null +++ b/examples/nordic/multi_layer/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf @@ -0,0 +1,9 @@ +# Board-specific config for nRF54LM20DK (nRF54LM20B with AXON NPU) + +# AXON NPU +CONFIG_NRF_AXON=y +CONFIG_NRF_AXON_INTERLAYER_BUFFER_SIZE=256 +CONFIG_NRF_AXON_PSUM_BUFFER_SIZE=0 + +# RRAM must stay in standby mode for AXON +CONFIG_MPSL_FORCE_RRAM_ON_ALL_THE_TIME=y diff --git a/examples/nordic/multi_layer/export_model.py b/examples/nordic/multi_layer/export_model.py new file mode 100644 index 00000000000..44d52e6eb30 --- /dev/null +++ b/examples/nordic/multi_layer/export_model.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Export a multi-layer classifier for the AXON NPU. + +Demonstrates multi-subgraph delegation: each Linear layer becomes +its own AXON-compiled subgraph with a separate command buffer header. +The delegate lookup table maps subgraph names to compiled models. 
+ +Model: two branches fc_a/fc_b (8 → 16), element-wise multiply, fc_head (16 → 4 outputs) + +Usage: + SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh +""" +from __future__ import annotations + +import math +import os +from pathlib import Path + +import torch +import torch.nn as nn + + +class MultiLayerClassifier(nn.Module): + """Two-branch classifier with independent AXON subgraphs. + + Branch A: fc_a (8 -> 16) + Branch B: fc_b (8 -> 16) + Merge: element-wise multiply (breaks the delegation chain) + Head: fc_head (16 -> 4) + + The element-wise multiply between the branches forces the + partitioner to create separate AXON subgraphs for each branch + and the head, demonstrating multi-subgraph delegation. + """ + + def __init__(self): + super().__init__() + self.fc_a = nn.Linear(8, 16) + self.fc_b = nn.Linear(8, 16) + self.fc_head = nn.Linear(16, 4) + + def forward(self, x): + a = torch.relu(self.fc_a(x)) + b = torch.relu(self.fc_b(x)) + merged = a * b + return self.fc_head(merged) + + +def main(): + script_dir = Path(__file__).parent + build_dir = script_dir / "build" + build_dir.mkdir(exist_ok=True) + generated_dir = script_dir / "src" / "generated" + generated_dir.mkdir(parents=True, exist_ok=True) + + sdk_path = os.environ.get("SDK_EDGE_AI_PATH", os.path.expanduser("~/sdk-edge-ai")) + + # 1. 
Train on a simple classification task (XOR-like) + print("Training multi-layer classifier...") + model = MultiLayerClassifier() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + loss_fn = nn.CrossEntropyLoss() + + # Generate training data: classify 8-dim input into 4 classes + torch.manual_seed(42) + x_train = torch.randn(500, 8) + # Labels based on which quadrant the first two dims fall into + y_train = ((x_train[:, 0] > 0).long() * 2 + (x_train[:, 1] > 0).long()) + + model.train() + for epoch in range(500): + pred = model(x_train) + loss = loss_fn(pred, y_train) + optimizer.zero_grad() + loss.backward() + optimizer.step() + accuracy = (model(x_train).argmax(dim=1) == y_train).float().mean() + print(f" Final loss: {loss.item():.4f}, accuracy: {accuracy:.1%}") + + # 2. Quantize + print("Quantizing to INT8...") + from executorch.backends.arm.tosa.specification import TosaSpecification + from executorch.backends.arm.quantizer import ( + EthosUQuantizer, + get_symmetric_quantization_config, + ) + from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e + + model.eval() + example_input = (torch.randn(1, 8),) + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + + exported = torch.export.export(model, example_input, strict=False) + captured = exported.module() + + # Strip torch 2.11 _guards_fn nodes + guard_nodes = [ + n for n in captured.graph.nodes + if n.op == "call_module" and "_guards" in str(n.target) + ] + for n in guard_nodes: + n.replace_all_uses_with(None) + captured.graph.erase_node(n) + for name in list(captured._modules.keys()): + if "_guards" in name: + delattr(captured, name) + captured.graph.lint() + captured.recompile() + + quantizer = EthosUQuantizer(tosa_spec).set_global( + get_symmetric_quantization_config(is_per_channel=True) + ) + prepared = prepare_pt2e(captured, quantizer) + # Calibrate + for _ in range(50): + prepared(torch.randn(1, 8)) + quantized = convert_pt2e(prepared) + re_exported = 
torch.export.export(quantized, example_input, strict=False)
+
+    # 3. Partition to AXON and export
+    print("Exporting with AXON backend...")
+    from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner
+    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
+
+    compile_spec = AxonCompileSpec(
+        sdk_edge_ai_path=sdk_path,
+        model_name="multi_layer",
+        axon_generated_dir=str(generated_dir),
+    )
+    partitioner = AxonPartitioner(compile_spec)
+
+    edge = to_edge_transform_and_lower(
+        re_exported,
+        partitioner=[partitioner],
+        compile_config=EdgeCompileConfig(_check_ir_validity=False),
+    )
+
+    pte_path = build_dir / "multi_layer.pte"
+    edge.to_executorch().save(str(pte_path))
+    print(f" .pte: {pte_path} ({pte_path.stat().st_size} bytes)")
+
+    # 4. Generate C header from .pte (16-byte aligned)
+    model_pte_h = script_dir / "src" / "model_pte.h"
+    pte_bytes = pte_path.read_bytes()
+    with open(model_pte_h, "w") as f:
+        f.write("/* Auto-generated from multi_layer.pte */\n")
+        f.write("#include <stdint.h>\n\n")
+        f.write("static const uint8_t model_pte[] __attribute__((aligned(16))) = {\n")
+        for i, b in enumerate(pte_bytes):
+            if i % 16 == 0:
+                f.write(" ")
+            f.write(f"0x{b:02x},")
+            if i % 16 == 15:
+                f.write("\n")
+        f.write("\n};\n")
+        f.write(f"static const uint32_t model_pte_len = {len(pte_bytes)};\n")
+    print(f" C header: {model_pte_h}")
+
+    # List generated AXON headers
+    headers = sorted(generated_dir.glob("*.h"))
+    subgraph_headers = [h for h in headers if h.name.startswith("axon_subgraph_") and h.name != "axon_subgraphs_table.h"]
+    print(f"\n Generated {len(subgraph_headers)} AXON subgraph(s):")
+    for h in subgraph_headers:
+        print(f" {h.name}")
+    print(f" Lookup table: axon_subgraphs_table.h")
+
+    print(f"\n AXON compiler chained all layers into {len(subgraph_headers)} command buffer(s).")
+    print(" (Models with non-delegatable ops between layers produce multiple subgraphs.)")
+
+    print("\nDone. 
Rebuild firmware to embed the model.") + + +if __name__ == "__main__": + main() diff --git a/examples/nordic/multi_layer/prj.conf b/examples/nordic/multi_layer/prj.conf new file mode 100644 index 00000000000..516019714d8 --- /dev/null +++ b/examples/nordic/multi_layer/prj.conf @@ -0,0 +1,37 @@ +# Copyright (c) 2026 iote.ai +# SPDX-License-Identifier: BSD-3-Clause +# +# hello_axon — ExecuTorch + AXON NPU inference + +# Console / UART +CONFIG_CONSOLE=y +CONFIG_UART_CONSOLE=y +CONFIG_SERIAL=y + +# Logging +CONFIG_LOG=y +CONFIG_LOG_DEFAULT_LEVEL=3 +CONFIG_LOG_BACKEND_UART=y +CONFIG_LOG_PRINTK=y +CONFIG_PRINTK=y + +# Timing (for cycle counting) +CONFIG_TIMING_FUNCTIONS=y + +# Float printing +CONFIG_PICOLIBC_IO_FLOAT=y + +# Memory +CONFIG_HEAP_MEM_POOL_SIZE=32768 +CONFIG_MAIN_STACK_SIZE=16384 + +# C++ (required by ExecuTorch) +CONFIG_CPP=y +CONFIG_STD_CPP17=y +CONFIG_REQUIRES_FULL_LIBCPP=y + +# ExecuTorch +CONFIG_EXECUTORCH=y +CONFIG_EXECUTORCH_ENABLE_LOGGING=y +CONFIG_EXECUTORCH_BUILD_PORTABLE_OPS=n +CONFIG_EXECUTORCH_OPTIMIZE_FOR_SIZE=y diff --git a/examples/nordic/multi_layer/pyproject.toml b/examples/nordic/multi_layer/pyproject.toml new file mode 100644 index 00000000000..cf478c16fb2 --- /dev/null +++ b/examples/nordic/multi_layer/pyproject.toml @@ -0,0 +1,15 @@ +# Copyright (c) 2026 iote.ai +# SPDX-License-Identifier: BSD-3-Clause +# +# Python project for hello_axon model export. +# See setup_export_env.sh for one-shot environment setup. + +[project] +name = "hello-axon" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "cffi>=1.15", + "numpy>=2.0", + "pyyaml", +] diff --git a/examples/nordic/multi_layer/setup_export_env.sh b/examples/nordic/multi_layer/setup_export_env.sh new file mode 100755 index 00000000000..64eed59310c --- /dev/null +++ b/examples/nordic/multi_layer/setup_export_env.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# Copyright (c) 2026 iote.ai +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# One-shot setup of the Python environment for model export. +# +# Model export (PyTorch → ExecuTorch → AXON) and firmware build +# (Zephyr + west) use DIFFERENT Python environments: +# +# - Model export needs PyTorch, ExecuTorch, tosa-tools — large ML +# packages that are not part of the NCS toolchain. +# +# - Firmware build uses the NCS toolchain Python, which provides +# west, Zephyr cmake modules, and the ARM cross-compiler. The NCS +# toolchain sets PYTHONHOME and PYTHONPATH to point at its own +# Python 3.12, which breaks imports for any other Python. +# +# This script creates an isolated .venv/ in this directory with the +# export dependencies. It does NOT affect the NCS toolchain or any +# other Python environment on the system. The generated run_export.sh +# wrapper unsets PYTHONHOME/PYTHONPATH before running, so it works +# even if you previously sourced nrf-connect-sdk-env.sh. +# +# ExecuTorch itself is NOT pip-installed (its setup.py triggers a +# heavy cmake build). Instead, ExecuTorch's Python source tree is +# added to PYTHONPATH at runtime — this is sufficient for the export +# pipeline which only uses the Python backend code, not the C++ runtime. +# +# Prerequisites: uv (install with: pip install uv) +# +# Usage: +# cd examples/nordic/hello_axon +# ./setup_export_env.sh +# SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh + +set -e + +# Unset the NCS toolchain's Python environment variables. The NCS +# toolchain (activated by sourcing nrf-connect-sdk-env.sh) sets these +# to point at its bundled Python 3.12. If left set, uv would try to +# use the NCS Python's stdlib, causing "SRE module mismatch" errors +# because uv's Python (3.13) and NCS's stdlib (3.12) are incompatible. +unset PYTHONHOME +unset PYTHONPATH + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +ET_ROOT="$(cd "$SCRIPT_DIR/../../.." 
&& pwd)" + +echo "=== Setting up hello_axon export environment ===" +echo " Directory: $SCRIPT_DIR" +echo " ExecuTorch: $ET_ROOT" + +# Check uv +if ! command -v uv &>/dev/null; then + echo "ERROR: 'uv' not found. Install with: pip install uv" + exit 1 +fi + +cd "$SCRIPT_DIR" + +# Create venv and install base deps from pyproject.toml +echo "" +echo "--- Creating venv and installing base dependencies ---" +uv sync + +# Install torch (CPU variant) — needed by ExecuTorch and the export pipeline +echo "" +echo "--- Installing PyTorch (CPU) ---" +uv pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Install ExecuTorch Python packages. +# We DON'T do `pip install -e` (which triggers a heavy cmake build). +# Instead, we install the Python-only dependencies and add ExecuTorch +# to the path at runtime via PYTHONPATH. +echo "" +echo "--- Installing ExecuTorch dependencies ---" +uv pip install setuptools flatbuffers packaging "ruamel.yaml" tabulate + +# Install tosa-tools and torchao +echo "" +echo "--- Installing tosa-tools and torchao ---" +uv pip install tosa-tools torchao + +# Create a wrapper script that sets PYTHONPATH for export +cat > "$SCRIPT_DIR/run_export.sh" << 'WRAPPER' +#!/bin/bash +# Auto-generated by setup_export_env.sh +# +# Runs export_model.py in the isolated .venv with the correct PYTHONPATH. +# Safe to run even if nrf-connect-sdk-env.sh was sourced in this shell — +# we unset PYTHONHOME/PYTHONPATH to avoid NCS toolchain conflicts. +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +ET_ROOT="$(cd "$SCRIPT_DIR/../../.." 
&& pwd)" +# Clear NCS Python env (PYTHONHOME would redirect stdlib to NCS's Python 3.12) +unset PYTHONHOME +unset PYTHONPATH +# Add ExecuTorch Python source to the path (not pip-installed, see README) +export PYTHONPATH="${ET_ROOT}/src" +exec uv run --directory "$SCRIPT_DIR" python "$SCRIPT_DIR/export_model.py" "$@" +WRAPPER +chmod +x "$SCRIPT_DIR/run_export.sh" + +# Verify +echo "" +echo "--- Verifying installation ---" +PYTHONPATH="${ET_ROOT}/src" uv run python -c " +from executorch.backends.nordic.axon import AxonBackend, AxonQuantizer +print(' ExecuTorch AXON backend: OK') +import tosa +print(' tosa-tools: OK') +import torch +print(f' PyTorch: {torch.__version__}') +print('Setup complete.') +" + +echo "" +echo "=== Done ===" +echo "" +echo "Export a model with:" +echo " cd $SCRIPT_DIR" +echo " SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh" diff --git a/examples/nordic/multi_layer/src/generated/.gitignore b/examples/nordic/multi_layer/src/generated/.gitignore new file mode 100644 index 00000000000..03f0c1f2a2a --- /dev/null +++ b/examples/nordic/multi_layer/src/generated/.gitignore @@ -0,0 +1,3 @@ +# Auto-generated by AXON backend; do not commit. +* +!.gitignore diff --git a/examples/nordic/multi_layer/src/inference.cpp b/examples/nordic/multi_layer/src/inference.cpp new file mode 100644 index 00000000000..986ff6d9d24 --- /dev/null +++ b/examples/nordic/multi_layer/src/inference.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2026 iote.ai + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * Multi-layer inference runner. + * + * Loads a .pte with multiple AXON-delegated subgraphs (one per FC layer). + * Each subgraph has its own compiled command buffer. The AXON delegate + * binds all subgraphs at init() and dispatches them at execute() time. 
+ */ + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "model_pte.h" + +/* Profiling API from AxonBackend.h */ +extern "C" { + extern uint64_t axon_delegate_total_cycles; + extern uint32_t axon_delegate_total_calls; + void axon_delegate_dump_profile(void); +} + +namespace et = executorch::runtime; +using et::Error; +using et::EValue; +using et::HierarchicalAllocator; +using et::MemoryAllocator; +using et::MemoryManager; +using et::Method; +using et::Program; +using et::Result; +using et::Span; +using executorch::extension::BufferDataLoader; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using exec_aten::TensorImpl; + +static uint8_t method_allocator_pool[32 * 1024]; +static uint8_t planned_memory_pool[32 * 1024]; +static uint8_t temp_allocator_pool[8 * 1024]; + +extern "C" int run_inference(void) +{ + if (model_pte_len == 0) { + ET_LOG(Error, "No model embedded. Run: ./setup_export_env.sh && ./run_export.sh"); + return -1; + } + + ET_LOG(Info, "Loading model (%u bytes)...", model_pte_len); + + BufferDataLoader loader(model_pte, model_pte_len); + Result program = Program::load(&loader); + if (!program.ok()) { + ET_LOG(Error, "Program::load failed: 0x%x", + static_cast(program.error())); + return -1; + } + ET_LOG(Info, "Program loaded, %zu method(s)", program->num_methods()); + + const char *method_name = nullptr; + { + auto name_result = program->get_method_name(0); + if (!name_result.ok()) { + ET_LOG(Error, "No methods in program"); + return -2; + } + method_name = *name_result; + } + ET_LOG(Info, "Method: %s", method_name); + + MemoryAllocator method_allocator( + sizeof(method_allocator_pool), method_allocator_pool); + MemoryAllocator temp_allocator( + sizeof(temp_allocator_pool), temp_allocator_pool); + + auto method_meta = program->method_meta(method_name); + if (!method_meta.ok()) { + ET_LOG(Error, "Failed to get method meta"); + return -3; + } + + Span 
planned_span(planned_memory_pool, + sizeof(planned_memory_pool)); + HierarchicalAllocator planned_allocator({&planned_span, 1}); + MemoryManager memory_manager( + &method_allocator, &planned_allocator, &temp_allocator); + + Result method = program->load_method( + method_name, &memory_manager); + if (!method.ok()) { + ET_LOG(Error, "load_method failed: 0x%x", + static_cast(method.error())); + return -4; + } + ET_LOG(Info, "Method loaded (AXON delegates bound: %lu)", + (unsigned long)axon_delegate_total_calls); + + /* Test inputs — 8-dimensional vectors */ + float test_inputs[][8] = { + { 1.0f, 1.0f, 0.5f, -0.5f, 0.0f, 0.2f, -0.3f, 0.1f}, /* class 3 */ + {-1.0f, 1.0f, 0.3f, 0.7f, -0.1f, 0.0f, 0.4f, -0.2f}, /* class 1 */ + { 1.0f, -1.0f, -0.2f, 0.3f, 0.6f, -0.4f, 0.1f, 0.5f}, /* class 2 */ + {-1.0f, -1.0f, 0.0f, -0.8f, 0.2f, 0.5f, -0.6f, 0.0f}, /* class 0 */ + }; + const int num_tests = sizeof(test_inputs) / sizeof(test_inputs[0]); + + timing_init(); + timing_start(); + + for (int t = 0; t < num_tests; t++) { + Tensor::SizesType sizes[] = {1, 8}; + Tensor::DimOrderType dim_order[] = {0, 1}; + TensorImpl input_impl( + ScalarType::Float, 2, sizes, test_inputs[t], dim_order); + Tensor input_tensor(&input_impl); + + Error err = method->set_input(input_tensor, 0); + if (err != Error::Ok) { + ET_LOG(Error, "set_input failed: 0x%x", static_cast(err)); + return -5; + } + + timing_t t_start = timing_counter_get(); + err = method->execute(); + timing_t t_end = timing_counter_get(); + + if (err != Error::Ok) { + ET_LOG(Error, "execute failed: 0x%x", static_cast(err)); + return -6; + } + + uint64_t cycles = timing_cycles_get(&t_start, &t_end); + uint64_t ns = timing_cycles_to_ns(cycles); + + const EValue &output = method->get_output(0); + if (output.isTensor()) { + const auto &out = output.toTensor(); + const float *data = out.const_data_ptr(); + int best = 0; + for (int i = 1; i < out.numel() && i < 4; i++) { + if (data[i] > data[best]) best = i; + } + ET_LOG(Info, " input[%d]: 
class=%d (%.3f, %.3f, %.3f, %.3f) %llu us", + t, best, + (double)data[0], (double)data[1], + (double)data[2], (double)data[3], + (unsigned long long)(ns / 1000)); + } + } + + /* Dump AXON delegate profiling */ + axon_delegate_dump_profile(); + + ET_LOG(Info, "Done."); + return 0; +} diff --git a/examples/nordic/multi_layer/src/main.c b/examples/nordic/multi_layer/src/main.c new file mode 100644 index 00000000000..0d2840ca0d9 --- /dev/null +++ b/examples/nordic/multi_layer/src/main.c @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2026 iote.ai + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * Multi-layer AXON example — demonstrates multi-subgraph delegation. + */ + +#include +#include + +LOG_MODULE_REGISTER(multi_layer, LOG_LEVEL_INF); + +extern int run_inference(void); + +int main(void) +{ + LOG_INF("Multi-layer AXON — ExecuTorch multi-subgraph delegation"); + LOG_INF("Board: %s", CONFIG_BOARD_TARGET); + +#if defined(CONFIG_NRF_AXON) && CONFIG_NRF_AXON + LOG_INF("AXON NPU: enabled"); +#else + LOG_INF("AXON NPU: not available (CPU only)"); +#endif + + int ret = run_inference(); + if (ret != 0) { + LOG_ERR("Inference failed: %d", ret); + } + + return 0; +} diff --git a/examples/nordic/simple_rnn/.gitignore b/examples/nordic/simple_rnn/.gitignore new file mode 100644 index 00000000000..066c18181ff --- /dev/null +++ b/examples/nordic/simple_rnn/.gitignore @@ -0,0 +1,21 @@ +# Python venv (created by setup_export_env.sh) +.venv/ + +# Build artifacts +build/ + +# Generated model header (created by export_model.py) +src/model_pte.h + +# Generated AXON headers (created by export_model.py) +src/generated/axon_subgraph_*.h +src/generated/axon_subgraphs_table.h + +# uv lock file +uv.lock + +# Generated export wrapper +run_export.sh + +__pycache__/ +*.pyc diff --git a/examples/nordic/simple_rnn/CMakeLists.txt b/examples/nordic/simple_rnn/CMakeLists.txt new file mode 
100644 index 00000000000..11d81872ca2 --- /dev/null +++ b/examples/nordic/simple_rnn/CMakeLists.txt @@ -0,0 +1,135 @@ +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Hello AXON — minimal ExecuTorch + AXON NPU inference. +# +# Build: +# source nrf-connect-sdk-env.sh +# west build -b nrf54lm20dk/nrf54lm20b/cpuapp examples/nordic/hello_axon \ +# --no-sysbuild -- \ +# -DZEPHYR_EXTRA_MODULES=";" + +cmake_minimum_required(VERSION 3.24) + +# Skip install rules — avoids ExecuTorch export dependency issues with Zephyr +set(CMAKE_SKIP_INSTALL_RULES ON CACHE BOOL "" FORCE) + +find_package(Zephyr REQUIRED HINTS $ENV{ZEPHYR_BASE}) +project(simple_rnn) + +# Source files: main.c (pure C entry) + inference.cpp (C++ ExecuTorch runner) +target_sources(app PRIVATE + src/main.c + src/inference.cpp +) + +# AXON delegate (from our ExecuTorch fork's backends/nordic/runtime/) +if(CONFIG_NRF_AXON) + if(DEFINED ZEPHYR_EXECUTORCH_MODULE_DIR) + target_sources(app PRIVATE + ${ZEPHYR_EXECUTORCH_MODULE_DIR}/backends/nordic/runtime/AxonBackend.cpp + ${ZEPHYR_EXECUTORCH_MODULE_DIR}/backends/nordic/runtime/axon_op_extensions.c + ) + endif() +endif() + +# Include app source dirs + generated headers +target_include_directories(app PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CMAKE_CURRENT_SOURCE_DIR}/src/generated +) + +# Stub axon_subgraphs_table.h if not generated yet +file(MAKE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/generated) +if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/src/generated/axon_subgraphs_table.h) + file(WRITE ${CMAKE_CURRENT_SOURCE_DIR}/src/generated/axon_subgraphs_table.h +"/* Stub — run export_model.py to generate real AXON subgraph headers. 
*/\n" +"#pragma once\n" +"#include \"axon/nrf_axon_platform.h\"\n" +"#include \"drivers/axon/nrf_axon_nn_infer.h\"\n" +"#define NRF_AXON_MODEL_ALLOCATE_PACKED_OUTPUT_BUFFER 1\n" +"typedef struct {\n" +" const char *name;\n" +" const nrf_axon_nn_compiled_model_s *model;\n" +"} axon_subgraph_entry_t;\n" +"#define AXON_SUBGRAPHS_COUNT 0\n" +"static const axon_subgraph_entry_t axon_subgraphs[1] = {{0}};\n") +endif() + +# ExecuTorch setup +if(CONFIG_EXECUTORCH) + if(NOT DEFINED EXECUTORCH_DIR) + if(DEFINED ZEPHYR_EXECUTORCH_MODULE_DIR) + set(EXECUTORCH_DIR ${ZEPHYR_EXECUTORCH_MODULE_DIR}) + endif() + endif() + + set(EXECUTORCH_ROOT ${EXECUTORCH_DIR}) + include(${EXECUTORCH_DIR}/tools/cmake/Utils.cmake) + + # Kernel registry: 27 prim + 2 selective + 19 quantized = 48 → use 64 + target_compile_definitions(app PRIVATE MAX_KERNEL_NUM=64) + if(TARGET executorch_core) + target_compile_definitions(executorch_core PRIVATE MAX_KERNEL_NUM=64) + endif() + + # Build portable kernels (needed by selective op build) + if(NOT TARGET portable_kernels) + set(EXECUTORCH_PORTABLE_BUILD_KERNELS_ONLY ON) + add_subdirectory( + ${EXECUTORCH_DIR}/kernels/portable + ${CMAKE_CURRENT_BINARY_DIR}/executorch/kernels/portable + ) + unset(EXECUTORCH_PORTABLE_BUILD_KERNELS_ONLY) + endif() + + # Selective ops for the simple RNN model. + # The FC layers run on AXON. 
These ops run on CPU between subgraphs: + # add.out — combines fc_ih + fc_hh outputs + # tanh.out — recurrent hidden state activation + # To find ops for your model: + # python -c "from executorch.codegen.tools.selective_build import *; \ + # p=_get_program_from_buffer(open('model.pte','rb').read()); \ + # print(','.join(sorted(_get_program_operators(p))))" + set(ET_OPS_LIST "aten::add.out,aten::tanh.out" + CACHE STRING "CPU ops for simple_rnn (between AXON subgraphs)") + + include(${EXECUTORCH_DIR}/tools/cmake/Codegen.cmake) + gen_selected_ops( + LIB_NAME "simple_rnn_ops_lib" + ROOT_OPS "${ET_OPS_LIST}" + INCLUDE_ALL_OPS "" + ) + generate_bindings_for_kernels( + LIB_NAME "simple_rnn_ops_lib" + FUNCTIONS_YAML ${EXECUTORCH_DIR}/kernels/portable/functions.yaml + ) + gen_operators_lib( + LIB_NAME "simple_rnn_ops_lib" + KERNEL_LIBS portable_kernels + DEPS executorch + ) + target_link_libraries(app PRIVATE simple_rnn_ops_lib) + + # Link ExecuTorch (provides include paths + core runtime) + target_link_libraries(app PRIVATE libexecutorch) + + # Portable kernels + if(TARGET portable_kernels) + executorch_target_link_options_shared_lib(portable_kernels) + target_link_libraries(app PRIVATE portable_kernels) + endif() + + # Quantized ops (for q/dq at AXON delegation boundaries) + if(TARGET quantized_ops_lib) + executorch_target_link_options_shared_lib(quantized_ops_lib) + target_link_libraries(app PRIVATE quantized_ops_lib) + endif() + if(TARGET quantized_kernels) + executorch_target_link_options_shared_lib(quantized_kernels) + target_link_libraries(app PRIVATE quantized_kernels) + endif() +endif() diff --git a/examples/nordic/simple_rnn/README.md b/examples/nordic/simple_rnn/README.md new file mode 100644 index 00000000000..c017cd27ccf --- /dev/null +++ b/examples/nordic/simple_rnn/README.md @@ -0,0 +1,88 @@ +# Simple RNN — Multi-Subgraph AXON Delegation + +Demonstrates **multiple AXON subgraphs** in a single model. 
The RNN +has Linear layers delegated to the AXON NPU, separated by a recurrent +hidden state update (tanh) that runs on the CPU. This forces the +partitioner to create separate command buffers for each group of +delegatable layers. + +## Why multiple subgraphs + +The AXON NPU accelerates Linear (FC) layers but cannot execute +recurrent operations. In this RNN, the `tanh` activation on the +hidden state is not TOSA INT-compatible, so the partitioner splits +the model: + +``` +input (4-dim) hidden (8-dim) + | | + v v + fc_ih (4->8) fc_hh (8->8) ← AXON subgraph 1 + | | + +------ add ----+ + | + tanh ← CPU (breaks delegation) + | + +-------+-------+ + | | + v v + fc_out (8->2) h_new (8-dim) ← AXON subgraph 2 + | + output (2-dim) +``` + +Each subgraph has its own compiled command buffer. The delegate +lookup table maps subgraph names to compiled models. + +## Recurrent execution + +The firmware runs 4 RNN steps, feeding the hidden state output back +as input to the next step. Each step dispatches the AXON subgraphs +and runs tanh on the CPU between them. + +## Build and run + +Same pattern as `hello_axon` — see its README for prerequisites. + +```bash +cd examples/nordic/simple_rnn +./setup_export_env.sh # one-time +SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh # export model + +# In a new terminal: +source ~/ncs-workspace/nrf-connect-sdk-env.sh +cd +west build -b nrf54lm20dk/nrf54lm20b/cpuapp examples/nordic/simple_rnn \ + --no-sysbuild -- \ + -DZEPHYR_EXTRA_MODULES="$(pwd);$SDK_EDGE_AI_PATH" +west flash +``` + +## Expected output + +``` +Simple RNN - ExecuTorch multi-subgraph AXON delegation +AXON NPU: enabled +Loading model (4516 bytes)... 
+AxonBackend::init (delegate 0, processed=36 bytes) + AXON model 'rnn_step_4fcd48193cbf' bound (out: 1x8x1 byte_width=1) +AxonBackend::init (delegate 1, processed=36 bytes) + AXON model 'rnn_step_7ddecacbd5d9' bound (out: 1x2x1 byte_width=1) +Method loaded + step 0: out=(-0.334, -0.198) 629 us + step 1: out=(-0.699, -0.296) 691 us + step 2: out=(-0.433, -0.251) 688 us + step 3: out=(-0.919, 0.084) 691 us +=== AXON delegate profile === +handles bound: 2 +total infer cycles: 89897 (8 calls) +avg cycles/call: 11237 +Done. +``` + +Key observations: +- `handles bound: 2` — two separate AXON subgraphs +- `8 calls` — 2 subgraphs x 4 RNN steps +- Subgraph 0 (fc_ih + fc_hh → 8 outputs): 12K cycles/call +- Subgraph 1 (fc_out → 2 outputs): 10K cycles/call +- ~690 us per step total (AXON dispatch + CPU tanh + CPU add) diff --git a/examples/nordic/simple_rnn/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf b/examples/nordic/simple_rnn/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf new file mode 100644 index 00000000000..ede2123520a --- /dev/null +++ b/examples/nordic/simple_rnn/boards/nrf54lm20dk_nrf54lm20b_cpuapp.conf @@ -0,0 +1,9 @@ +# Board-specific config for nRF54LM20DK (nRF54LM20B with AXON NPU) + +# AXON NPU +CONFIG_NRF_AXON=y +CONFIG_NRF_AXON_INTERLAYER_BUFFER_SIZE=256 +CONFIG_NRF_AXON_PSUM_BUFFER_SIZE=0 + +# RRAM must stay in standby mode for AXON +CONFIG_MPSL_FORCE_RRAM_ON_ALL_THE_TIME=y diff --git a/examples/nordic/simple_rnn/export_model.py b/examples/nordic/simple_rnn/export_model.py new file mode 100644 index 00000000000..165aa3d096f --- /dev/null +++ b/examples/nordic/simple_rnn/export_model.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 iote.ai +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Export a simple RNN for the AXON NPU — multi-subgraph delegation. + +This model demonstrates why real models produce multiple AXON subgraphs. 
+A simple RNN has Linear layers (AXON-delegatable) separated by a +recurrent hidden state update (tanh — runs on CPU). The partitioner +cannot group the FC layers into one subgraph because the recurrent +loop between them is not TOSA-compatible. + +Model (single-step RNN): + input (4-dim) + hidden (8-dim) + → fc_ih: Linear(4 → 8) ← AXON subgraph A + → fc_hh: Linear(8 → 8) ← AXON subgraph B + → add + tanh ← CPU (recurrent state update) + → fc_out: Linear(8 → 2) ← AXON subgraph C + output (2-dim) + new_hidden (8-dim) + +The tanh activation on the hidden state is not TOSA INT-delegatable +(it requires the TABLE op which breaks the subgraph boundary), so +fc_ih, fc_hh, and fc_out become separate AXON subgraphs. + +Usage: + SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh +""" +from __future__ import annotations + +import os +from pathlib import Path + +import torch +import torch.nn as nn + + +class SimpleRNNStep(nn.Module): + """Single-step RNN cell with separate input and hidden projections. + + Unrolled for export: takes input + hidden, returns output + new_hidden. + The tanh between the Linear layers forces multi-subgraph delegation. 
+ """ + + def __init__(self, input_size=4, hidden_size=8, output_size=2): + super().__init__() + self.fc_ih = nn.Linear(input_size, hidden_size) # input → hidden + self.fc_hh = nn.Linear(hidden_size, hidden_size) # hidden → hidden + self.fc_out = nn.Linear(hidden_size, output_size) # hidden → output + + def forward(self, x, h): + # Input and hidden projections (each delegatable to AXON) + ih = self.fc_ih(x) + hh = self.fc_hh(h) + # Recurrent state update — tanh is NOT TOSA INT-delegatable + h_new = torch.tanh(ih + hh) + # Output projection (delegatable to AXON) + out = self.fc_out(h_new) + return out, h_new + + +def main(): + script_dir = Path(__file__).parent + build_dir = script_dir / "build" + build_dir.mkdir(exist_ok=True) + generated_dir = script_dir / "src" / "generated" + generated_dir.mkdir(parents=True, exist_ok=True) + + sdk_path = os.environ.get("SDK_EDGE_AI_PATH", os.path.expanduser("~/sdk-edge-ai")) + + # 1. Create model (no training needed — just demonstrating delegation) + print("Creating simple RNN step model...") + model = SimpleRNNStep(input_size=4, hidden_size=8, output_size=2) + model.eval() + print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}") + + # 2. 
Quantize + print("Quantizing to INT8...") + from executorch.backends.arm.tosa.specification import TosaSpecification + from executorch.backends.arm.quantizer import ( + TOSAQuantizer, + get_symmetric_quantization_config, + ) + from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e + + example_input = (torch.randn(1, 4), torch.randn(1, 8)) + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + + exported = torch.export.export(model, example_input, strict=False) + captured = exported.module() + + # Strip _guards_fn nodes + guard_nodes = [ + n for n in captured.graph.nodes + if n.op == "call_module" and "_guards" in str(n.target) + ] + for n in guard_nodes: + n.replace_all_uses_with(None) + captured.graph.erase_node(n) + for name in list(captured._modules.keys()): + if "_guards" in name: + delattr(captured, name) + captured.graph.lint() + captured.recompile() + + # Quantize only Linear layers — tanh stays in fp32 on CPU + quantizer = TOSAQuantizer(tosa_spec) + quantizer.set_module_type( + nn.Linear, get_symmetric_quantization_config(is_per_channel=False) + ) + prepared = prepare_pt2e(captured, quantizer) + # Calibrate + for _ in range(50): + prepared(torch.randn(1, 4), torch.randn(1, 8)) + quantized = convert_pt2e(prepared) + re_exported = torch.export.export(quantized, example_input, strict=False) + + # 3. 
Partition to AXON and export
+    print("Exporting with AXON backend...")
+    from executorch.backends.nordic.axon import AxonCompileSpec, AxonPartitioner
+    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
+
+    compile_spec = AxonCompileSpec(
+        sdk_edge_ai_path=sdk_path,
+        model_name="rnn_step",
+        axon_generated_dir=str(generated_dir),
+    )
+    partitioner = AxonPartitioner(compile_spec)
+
+    edge = to_edge_transform_and_lower(
+        re_exported,
+        partitioner=[partitioner],
+        compile_config=EdgeCompileConfig(_check_ir_validity=False),
+    )
+
+    pte_path = build_dir / "simple_rnn.pte"
+    edge.to_executorch().save(str(pte_path))
+    print(f" .pte: {pte_path} ({pte_path.stat().st_size} bytes)")
+
+    # 4. Generate C header
+    model_pte_h = script_dir / "src" / "model_pte.h"
+    pte_bytes = pte_path.read_bytes()
+    with open(model_pte_h, "w") as f:
+        f.write("/* Auto-generated from simple_rnn.pte */\n")
+        f.write("#include <stdint.h>\n\n")
+        f.write("static const uint8_t model_pte[] __attribute__((aligned(16))) = {\n")
+        for i, b in enumerate(pte_bytes):
+            if i % 16 == 0:
+                f.write(" ")
+            f.write(f"0x{b:02x},")
+            if i % 16 == 15:
+                f.write("\n")
+        f.write("\n};\n")
+        f.write(f"static const uint32_t model_pte_len = {len(pte_bytes)};\n")
+    print(f" C header: {model_pte_h}")
+
+    # List generated AXON headers
+    headers = sorted(generated_dir.glob("*.h"))
+    subgraph_headers = [h for h in headers
+                        if h.name.startswith("axon_subgraph_")
+                        and h.name != "axon_subgraphs_table.h"]
+    print(f"\n Generated {len(subgraph_headers)} AXON subgraph(s):")
+    for h in subgraph_headers:
+        print(f" {h.name}")
+    print(f" Lookup table: axon_subgraphs_table.h")
+
+    if len(subgraph_headers) >= 2:
+        print(f"\n Multi-subgraph delegation: {len(subgraph_headers)} separate command buffers")
+        print(" (The recurrent tanh between FC layers splits them into separate subgraphs.)")
+    else:
+        print(f"\n {len(subgraph_headers)} subgraph(s) generated.")
+
+    print("\nDone. 
Rebuild firmware to embed the model.") + + +if __name__ == "__main__": + main() diff --git a/examples/nordic/simple_rnn/prj.conf b/examples/nordic/simple_rnn/prj.conf new file mode 100644 index 00000000000..516019714d8 --- /dev/null +++ b/examples/nordic/simple_rnn/prj.conf @@ -0,0 +1,37 @@ +# Copyright (c) 2026 iote.ai +# SPDX-License-Identifier: BSD-3-Clause +# +# hello_axon — ExecuTorch + AXON NPU inference + +# Console / UART +CONFIG_CONSOLE=y +CONFIG_UART_CONSOLE=y +CONFIG_SERIAL=y + +# Logging +CONFIG_LOG=y +CONFIG_LOG_DEFAULT_LEVEL=3 +CONFIG_LOG_BACKEND_UART=y +CONFIG_LOG_PRINTK=y +CONFIG_PRINTK=y + +# Timing (for cycle counting) +CONFIG_TIMING_FUNCTIONS=y + +# Float printing +CONFIG_PICOLIBC_IO_FLOAT=y + +# Memory +CONFIG_HEAP_MEM_POOL_SIZE=32768 +CONFIG_MAIN_STACK_SIZE=16384 + +# C++ (required by ExecuTorch) +CONFIG_CPP=y +CONFIG_STD_CPP17=y +CONFIG_REQUIRES_FULL_LIBCPP=y + +# ExecuTorch +CONFIG_EXECUTORCH=y +CONFIG_EXECUTORCH_ENABLE_LOGGING=y +CONFIG_EXECUTORCH_BUILD_PORTABLE_OPS=n +CONFIG_EXECUTORCH_OPTIMIZE_FOR_SIZE=y diff --git a/examples/nordic/simple_rnn/pyproject.toml b/examples/nordic/simple_rnn/pyproject.toml new file mode 100644 index 00000000000..cf478c16fb2 --- /dev/null +++ b/examples/nordic/simple_rnn/pyproject.toml @@ -0,0 +1,15 @@ +# Copyright (c) 2026 iote.ai +# SPDX-License-Identifier: BSD-3-Clause +# +# Python project for hello_axon model export. +# See setup_export_env.sh for one-shot environment setup. + +[project] +name = "hello-axon" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "cffi>=1.15", + "numpy>=2.0", + "pyyaml", +] diff --git a/examples/nordic/simple_rnn/setup_export_env.sh b/examples/nordic/simple_rnn/setup_export_env.sh new file mode 100755 index 00000000000..64eed59310c --- /dev/null +++ b/examples/nordic/simple_rnn/setup_export_env.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# Copyright (c) 2026 iote.ai +# All rights reserved. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# One-shot setup of the Python environment for model export.
+#
+# Model export (PyTorch → ExecuTorch → AXON) and firmware build
+# (Zephyr + west) use DIFFERENT Python environments:
+#
+#   - Model export needs PyTorch, ExecuTorch, tosa-tools — large ML
+#     packages that are not part of the NCS toolchain.
+#
+#   - Firmware build uses the NCS toolchain Python, which provides
+#     west, Zephyr cmake modules, and the ARM cross-compiler. The NCS
+#     toolchain sets PYTHONHOME and PYTHONPATH to point at its own
+#     Python 3.12, which breaks imports for any other Python.
+#
+# This script creates an isolated .venv/ in this directory with the
+# export dependencies. It does NOT affect the NCS toolchain or any
+# other Python environment on the system. The generated run_export.sh
+# wrapper unsets PYTHONHOME/PYTHONPATH before running, so it works
+# even if you previously sourced nrf-connect-sdk-env.sh.
+#
+# ExecuTorch itself is NOT pip-installed (its setup.py triggers a
+# heavy cmake build). Instead, ExecuTorch's Python source tree is
+# added to PYTHONPATH at runtime — this is sufficient for the export
+# pipeline which only uses the Python backend code, not the C++ runtime.
+#
+# Prerequisites: uv (install with: pip install uv)
+#
+# Usage:
+#   cd examples/nordic/simple_rnn
+#   ./setup_export_env.sh
+#   SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh
+
+set -e
+
+# Unset the NCS toolchain's Python environment variables. The NCS
+# toolchain (activated by sourcing nrf-connect-sdk-env.sh) sets these
+# to point at its bundled Python 3.12. If left set, uv would try to
+# use the NCS Python's stdlib, causing "SRE module mismatch" errors
+# because uv's Python (3.13) and NCS's stdlib (3.12) are incompatible.
+unset PYTHONHOME
+unset PYTHONPATH
+
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+ET_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)"
+
+echo "=== Setting up simple_rnn export environment ==="
+echo "    Directory:  $SCRIPT_DIR"
+echo "    ExecuTorch: $ET_ROOT"
+
+# Check uv
+if ! command -v uv &>/dev/null; then
+    echo "ERROR: 'uv' not found. Install with: pip install uv"
+    exit 1
+fi
+
+cd "$SCRIPT_DIR"
+
+# Create venv and install base deps from pyproject.toml
+echo ""
+echo "--- Creating venv and installing base dependencies ---"
+uv sync
+
+# Install torch (CPU variant) — needed by ExecuTorch and the export pipeline
+echo ""
+echo "--- Installing PyTorch (CPU) ---"
+uv pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+# Install ExecuTorch Python packages.
+# We DON'T do `pip install -e` (which triggers a heavy cmake build).
+# Instead, we install the Python-only dependencies and add ExecuTorch
+# to the path at runtime via PYTHONPATH.
+echo ""
+echo "--- Installing ExecuTorch dependencies ---"
+uv pip install setuptools flatbuffers packaging "ruamel.yaml" tabulate
+
+# Install tosa-tools and torchao
+echo ""
+echo "--- Installing tosa-tools and torchao ---"
+uv pip install tosa-tools torchao
+
+# Create a wrapper script that sets PYTHONPATH for export
+cat > "$SCRIPT_DIR/run_export.sh" << 'WRAPPER'
+#!/bin/bash
+# Auto-generated by setup_export_env.sh
+#
+# Runs export_model.py in the isolated .venv with the correct PYTHONPATH.
+# Safe to run even if nrf-connect-sdk-env.sh was sourced in this shell —
+# we unset PYTHONHOME/PYTHONPATH to avoid NCS toolchain conflicts.
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+ET_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)"
+# Clear NCS Python env (PYTHONHOME would redirect stdlib to NCS's Python 3.12)
+unset PYTHONHOME
+unset PYTHONPATH
+# Add ExecuTorch Python source to the path (not pip-installed, see README)
+export PYTHONPATH="${ET_ROOT}/src"
+exec uv run --directory "$SCRIPT_DIR" python "$SCRIPT_DIR/export_model.py" "$@"
+WRAPPER
+chmod +x "$SCRIPT_DIR/run_export.sh"
+
+# Verify
+echo ""
+echo "--- Verifying installation ---"
+PYTHONPATH="${ET_ROOT}/src" uv run python -c "
+from executorch.backends.nordic.axon import AxonBackend, AxonQuantizer
+print('  ExecuTorch AXON backend: OK')
+import tosa
+print('  tosa-tools: OK')
+import torch
+print(f'  PyTorch: {torch.__version__}')
+print('Setup complete.')
+"
+
+echo ""
+echo "=== Done ==="
+echo ""
+echo "Export a model with:"
+echo "  cd $SCRIPT_DIR"
+echo "  SDK_EDGE_AI_PATH=~/sdk-edge-ai ./run_export.sh"
diff --git a/examples/nordic/simple_rnn/src/generated/.gitignore b/examples/nordic/simple_rnn/src/generated/.gitignore
new file mode 100644
index 00000000000..03f0c1f2a2a
--- /dev/null
+++ b/examples/nordic/simple_rnn/src/generated/.gitignore
@@ -0,0 +1,3 @@
+# Auto-generated by AXON backend; do not commit.
+*
+!.gitignore
diff --git a/examples/nordic/simple_rnn/src/inference.cpp b/examples/nordic/simple_rnn/src/inference.cpp
new file mode 100644
index 00000000000..05f402cd2c2
--- /dev/null
+++ b/examples/nordic/simple_rnn/src/inference.cpp
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) 2026 iote.ai
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * Simple RNN inference — demonstrates multi-subgraph delegation.
+ *
+ * The model has FC layers (AXON) separated by tanh (CPU), producing
+ * multiple delegate handles. Each AXON subgraph has its own compiled
+ * command buffer.
+ */
+
+#include <stdint.h>
+#include <string.h>
+
+#include <zephyr/kernel.h>
+#include <zephyr/timing/timing.h>
+
+/* NOTE(review): include targets below were reconstructed from the symbols
+ * used in this file (the original angle-bracket spans were lost) — verify. */
+#include <executorch/extension/data_loader/buffer_data_loader.h>
+#include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/core/hierarchical_allocator.h>
+#include <executorch/runtime/core/memory_allocator.h>
+#include <executorch/runtime/core/span.h>
+#include <executorch/runtime/executor/method.h>
+#include <executorch/runtime/executor/program.h>
+#include <executorch/runtime/platform/log.h>
+
+#include "model_pte.h"
+
+/* Profiling counters maintained by the AXON delegate runtime. */
+extern "C" {
+extern uint64_t axon_delegate_total_cycles;
+extern uint32_t axon_delegate_total_calls;
+void axon_delegate_dump_profile(void);
+}
+
+namespace et = executorch::runtime;
+using et::Error;
+using et::EValue;
+using et::HierarchicalAllocator;
+using et::MemoryAllocator;
+using et::MemoryManager;
+using et::Method;
+using et::Program;
+using et::Result;
+using et::Span;
+using executorch::extension::BufferDataLoader;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using exec_aten::TensorImpl;
+
+/* Static pools: method metadata, planned tensor memory, scratch space. */
+static uint8_t method_allocator_pool[32 * 1024];
+static uint8_t planned_memory_pool[32 * 1024];
+static uint8_t temp_allocator_pool[8 * 1024];
+
+extern "C" int run_inference(void)
+{
+	if (model_pte_len == 0) {
+		ET_LOG(Error, "No model embedded. Run: ./setup_export_env.sh && ./run_export.sh");
+		return -1;
+	}
+
+	ET_LOG(Info, "Loading model (%u bytes)...", model_pte_len);
+
+	BufferDataLoader loader(model_pte, model_pte_len);
+	Result<Program> program = Program::load(&loader);
+	if (!program.ok()) {
+		ET_LOG(Error, "Program::load failed: 0x%x",
+		       static_cast<uint32_t>(program.error()));
+		return -1;
+	}
+	ET_LOG(Info, "Program loaded, %zu method(s)", program->num_methods());
+
+	const char *method_name = nullptr;
+	{
+		auto name_result = program->get_method_name(0);
+		if (!name_result.ok()) return -2;
+		method_name = *name_result;
+	}
+	ET_LOG(Info, "Method: %s", method_name);
+
+	MemoryAllocator method_allocator(sizeof(method_allocator_pool), method_allocator_pool);
+	MemoryAllocator temp_allocator(sizeof(temp_allocator_pool), temp_allocator_pool);
+	auto method_meta = program->method_meta(method_name);
+	if (!method_meta.ok()) return -3;
+
+	Span<uint8_t> planned_span(planned_memory_pool, sizeof(planned_memory_pool));
+	HierarchicalAllocator planned_allocator({&planned_span, 1});
+	MemoryManager memory_manager(&method_allocator, &planned_allocator, &temp_allocator);
+
+	Result<Method> method = program->load_method(method_name, &memory_manager);
+	if (!method.ok()) {
+		ET_LOG(Error, "load_method failed: 0x%x", static_cast<uint32_t>(method.error()));
+		return -4;
+	}
+	ET_LOG(Info, "Method loaded");
+
+	/* The RNN step model takes two inputs: x (1,4) and h (1,8) */
+	float input_data[4] = {0.5f, -0.3f, 0.8f, -0.1f};
+	float hidden_data[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+
+	Tensor::SizesType input_sizes[] = {1, 4};
+	Tensor::SizesType hidden_sizes[] = {1, 8};
+	Tensor::DimOrderType dim2[] = {0, 1};
+
+	TensorImpl input_impl(ScalarType::Float, 2, input_sizes, input_data, dim2);
+	TensorImpl hidden_impl(ScalarType::Float, 2, hidden_sizes, hidden_data, dim2);
+	Tensor input_tensor(&input_impl);
+	Tensor hidden_tensor(&hidden_impl);
+
+	timing_init();
+	timing_start();
+
+	/* Run 4 RNN steps, feeding hidden state back */
+	for (int step = 0; step < 4; step++) {
+		Error err = method->set_input(input_tensor, 0);
+		if (err != Error::Ok) return -5;
+		err = method->set_input(hidden_tensor, 1);
+		if (err != Error::Ok) return -5;
+
+		timing_t t_start = timing_counter_get();
+		err = method->execute();
+		timing_t t_end = timing_counter_get();
+		if (err != Error::Ok) {
+			ET_LOG(Error, "step %d execute failed: 0x%x", step, (uint32_t)err);
+			return -6;
+		}
+
+		uint64_t cycles = timing_cycles_get(&t_start, &t_end);
+		uint64_t ns = timing_cycles_to_ns(cycles);
+
+		/* Output 0 = out (1,2), Output 1 = h_new (1,8) */
+		const EValue &out_val = method->get_output(0);
+		if (out_val.isTensor()) {
+			const auto &out = out_val.toTensor();
+			const float *d = out.const_data_ptr<float>();
+			ET_LOG(Info, "  step %d: out=(%.3f, %.3f)  %llu us",
+			       step, (double)d[0], (double)d[1],
+			       (unsigned long long)(ns / 1000));
+		}
+
+		/* Feed h_new back as hidden state for next step */
+		if (method->outputs_size() > 1) {
+			const EValue &h_val = method->get_output(1);
+			if (h_val.isTensor()) {
+				const auto &h_out = h_val.toTensor();
+				const float *h_data = h_out.const_data_ptr<float>();
+				for (int i = 0; i < 8 && i < h_out.numel(); i++) {
+					hidden_data[i] = h_data[i];
+				}
+			}
+		}
+	}
+
+	axon_delegate_dump_profile();
+	ET_LOG(Info, "Done.");
+	return 0;
+}
diff --git a/examples/nordic/simple_rnn/src/main.c b/examples/nordic/simple_rnn/src/main.c
new file mode 100644
index 00000000000..a7c623f9e68
--- /dev/null
+++ b/examples/nordic/simple_rnn/src/main.c
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2026 iote.ai
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ *
+ * Simple RNN — multi-subgraph AXON delegation example.
+ */
+
+#include <zephyr/kernel.h>
+#include <zephyr/logging/log.h>
+
+LOG_MODULE_REGISTER(simple_rnn, LOG_LEVEL_INF);
+
+extern int run_inference(void);
+
+int main(void)
+{
+	LOG_INF("Simple RNN - ExecuTorch multi-subgraph AXON delegation");
+	LOG_INF("Board: %s", CONFIG_BOARD_TARGET);
+
+#if defined(CONFIG_NRF_AXON) && CONFIG_NRF_AXON
+	LOG_INF("AXON NPU: enabled");
+#else
+	LOG_INF("AXON NPU: not available (CPU only)");
+#endif
+
+	int ret = run_inference();
+	if (ret != 0) {
+		LOG_ERR("Inference failed: %d", ret);
+	}
+
+	return 0;
+}
diff --git a/zephyr/CMakeLists.txt b/zephyr/CMakeLists.txt
index 6b6970f30ce..b18fcdfddc1 100644
--- a/zephyr/CMakeLists.txt
+++ b/zephyr/CMakeLists.txt
@@ -60,6 +60,9 @@ if(CONFIG_EXECUTORCH)
   if(CONFIG_ETHOS_U)
     set(EXECUTORCH_BUILD_ARM_BAREMETAL ON)
   endif()
+  if(CONFIG_NRF_AXON)
+    set(EXECUTORCH_BUILD_NORDIC_AXON ON)
+  endif()
 
   set(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL ON)
   set(EXECUTORCH_BUILD_KERNELS_QUANTIZED ON)
@@ -133,6 +136,7 @@ if(CONFIG_EXECUTORCH)
     cortex_m_ops_lib
     cortex_m_kernels
     cmsis-nn
+    executorch_delegate_axon
   )
 
   # Apply corrected flags to each target
@@ -171,4 +175,10 @@ if(CONFIG_EXECUTORCH)
     ${EXECUTORCH_DIR}/third-party/flatcc/include
   )
 
+  # Link AXON delegate if built
+  if(TARGET executorch_delegate_axon)
+    target_link_libraries(libexecutorch PUBLIC executorch_delegate_axon)
+ message(STATUS "Linking with Nordic AXON delegate") + endif() + endif() diff --git a/zephyr/Kconfig b/zephyr/Kconfig index ce977c93c67..d8c7c9d08aa 100644 --- a/zephyr/Kconfig +++ b/zephyr/Kconfig @@ -49,4 +49,12 @@ config EXECUTORCH_BUILD_PORTABLE_OPS operator building to include only needed operators, but the underlying kernel implementations will still be available. +config EXECUTORCH_BUILD_NORDIC_AXON + bool "Build Nordic AXON NPU delegate" + depends on NRF_AXON + help + Enable the AXON NPU delegate backend for Nordic nRF54LM20B. + Requires the Nordic sdk-edge-ai to be available in the build + environment. + endif # EXECUTORCH