From 16833dfc2313500c195c35db0d4645a1748ca24c Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Mon, 20 Apr 2026 23:45:55 -0700 Subject: [PATCH 1/5] Added MDP generation to QEff Compile Signed-off-by: Mohit Mehta --- QEfficient/base/modeling_qeff.py | 56 +++++-- QEfficient/compile/mdp_generator.py | 249 ++++++++++++++++++++++++++++ 2 files changed, 293 insertions(+), 12 deletions(-) create mode 100644 QEfficient/compile/mdp_generator.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 17b87afd14..ab36248d96 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -699,7 +699,8 @@ def _compile( specializations: Optional[List[Dict[str, int]]] = None, custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, - num_speculative_tokens: Optional[Union[int, List[int]]] = None, + mdp_num_partitions: int = 1, + num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, use_onnx_subfunctions: bool = False, @@ -721,7 +722,12 @@ def _compile( :specializations (list): List of specializations to compile for :custom_io (dict): Custom IO to specify the input and outputs in different formats than default :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. - :num_speculative_tokens (int | List[int], optional): Number of speculative tokens for TLM decode. A plain int K compiles one decode specialization (seq_len=K+1). A list [K0, K1, ...] compiles one specialization per value, enabling per-step dispatch to the cheapest kernel. + :mdp_num_partitions (int): Number of pipeline-parallel partitions for disaggregated prefill serving. + When > 1, the ONNX graph is read directly to generate a fully-populated MDP partition + config (nodeList per partition) without requiring a compiler round-trip. + Ignored when ``mdp_load_partition_config`` is already provided in compiler_options. 
+ Defaults to 1 (template / tensor-slice MDP, existing behaviour). + :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` :qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. ``Defaults to None.`` :compiler_options: Pass any compiler option as input. @@ -793,22 +799,47 @@ def _compile( + [f"-m={onnx_path}"] ) - # MDP partition config: prioritize dump over load - mdp_dump_json_path = compiler_options.pop("mdp_dump_partition_config", None) + # MDP partition config selection (three priorities, highest first): + # 1. User explicitly provides a pre-built MDP JSON to load. + # 2. Disaggregated (pipeline-parallel) MDP — generate from ONNX topsort. + # 3. Template (tensor-slice) MDP — single partition, nodeList absent. mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None) + # Silently discard any stale mdp_dump_partition_config key that callers + # may still pass; the compiler-round-trip dump path is no longer supported. + compiler_options.pop("mdp_dump_partition_config", None) mdp_ts_json = None - if mdp_dump_json_path: - if mdp_ts_json_path: - logger.warning( - "Loading and Dumping partition is not supported at the same time. Prioritizing dump config over load config!" - ) - command.append(f"-mdp-dump-partition-config={mdp_dump_json_path}") - elif mdp_ts_json_path: + if mdp_ts_json_path: command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") mdp_ts_json = load_json(str(mdp_ts_json_path)) + elif mdp_num_partitions > 1: + # Disaggregated (pipeline-parallel) MDP: generate a fully-populated + # nodeList per partition directly from the ONNX graph — no compiler + # round-trip required. 
+ from QEfficient.compile.mdp_generator import generate_disagg_mdp_partition_config + + num_layers = getattr(self, "num_layers", None) + if num_layers is None: + raise AttributeError("Model does not expose 'num_layers'. Cannot generate disagg MDP partition config.") + num_cores = compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) + logger.info( + f"Generating disagg MDP partition config from ONNX: " + f"num_devices={mdp_ts_num_devices}, num_partitions={mdp_num_partitions}, " + f"num_layers={num_layers}, num_cores={num_cores}" + ) + mdp_ts_json = generate_disagg_mdp_partition_config( + onnx_path=str(onnx_path), + num_devices=mdp_ts_num_devices, + num_partitions=mdp_num_partitions, + num_layers=num_layers, + num_cores=num_cores, + ) + mdp_ts_json_path = compile_dir / f"mdp_disagg_{mdp_ts_num_devices}d_{mdp_num_partitions}p.json" + create_json(str(mdp_ts_json_path), mdp_ts_json) + command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") elif mdp_ts_num_devices > 1: - # Generate mdp config only if neither dump nor load is provided and num_devices > 1 + # Template (tensor-slice) MDP: single partition, empty nodeList. + # Used when PP is disabled (stages=1). Compiler fills the nodeList. mdp_ts_json = generate_mdp_partition_config( mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) ) @@ -833,6 +864,7 @@ def _compile( "specializations": specializations, "custom_io": custom_io, "mdp_ts_num_devices": mdp_ts_num_devices, + "mdp_num_partitions": mdp_num_partitions, "mdp_ts_json": mdp_ts_json, "num_speculative_tokens": num_speculative_tokens, "prefill_only": prefill_only, diff --git a/QEfficient/compile/mdp_generator.py b/QEfficient/compile/mdp_generator.py new file mode 100644 index 0000000000..58a9ffe71c --- /dev/null +++ b/QEfficient/compile/mdp_generator.py @@ -0,0 +1,249 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. 
and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +"""MDP generator for disaggregated prefill serving (PP-enabled, TS-enabled, stages>1).""" + +from typing import Any, Dict, List, Optional, Set +import onnx +import logging + +logger = logging.getLogger(__name__) + + +def _get_compiler_folded_nodes(graph) -> Set[str]: + """Return node names the compiler will fold away during ONNX import. + + Mirrors computeIsConstantFoldable() in ONNXModelLoader.cpp: a node is + foldable if every one of its inputs is a compile-time constant (initializer, + Constant op output, or output of another foldable node). Folded nodes are + absent from the compiler IR, so including them in nodeList is harmless but + excluding them produces a cleaner MDP closer to the compiler dump. + + Op types that the compiler never folds (ProtobufLoader.cpp:68): + Loop, Const, Identity, If, DequantizeLinear + """ + # const_values: output tensor names whose value is known at compile time. + # Seeded with all initializer names (model weights / constants). + const_values: Set[str] = {init.name for init in graph.initializer} + + # Constant op outputs are trivially compile-time constants; collect them + # upfront so the fixed-point loop below only needs one pass for everything else. + for node in graph.node: + if node.op_type == "Constant": + const_values.update(out for out in node.output if out) + + # Never-folded op types (compiler explicitly skips these - ProtobufLoader.cpp:68). + _NEVER_FOLD = frozenset({"Loop", "Const", "Identity", "If", "DequantizeLinear"}) + + # Keep marking nodes foldable until no new ones are found. 
+ foldable_nodes: Set[str] = set() + while(True): + changed = False + for node in graph.node: + if not node.name or node.name in foldable_nodes: + continue + if node.op_type in _NEVER_FOLD or not node.input: + continue + if all(inp in const_values for inp in node.input if inp): + foldable_nodes.add(node.name) + const_values.update(out for out in node.output if out) + changed = True + if not changed: + break + + return foldable_nodes + + +def _get_layer_num(node_name: str) -> Optional[int]: + """Return transformer layer index from node name, or None. + + Supports layers.N (Llama/Mistral/Qwen/Gemma/Granite) and h.N (GPT-2). + """ + for part in node_name.split("/"): + if part.startswith("layers."): + suffix = part[len("layers.") :] + if suffix.isdigit(): + return int(suffix) + elif part.startswith("h."): + suffix = part[len("h.") :] + if suffix.isdigit(): + return int(suffix) + return None + + +def _get_inlined_node_map(model) -> tuple: + """Classify ONNX local functions and build inlined sub-node names. + + The compiler inlines a local function body into the parent graph during + ONNX import if it has < 100 nodes AND is not a known custom op + (ONNXModelLoaderSubFuns.cpp). Inlined call-sites do not appear in the + compiler IR; their sub-nodes are named /. + Known custom ops (registered via DEFINEKNOWNCUSTOMOP) keep their + call-site name in the IR and must be included in nodeList as-is. + + Returns: + inlined_node_map: dict mapping call-site name -> list of inlined + sub-node names (/). + non_inlined_funcs: set of function names that are NOT inlined + (known custom ops or >= 100 nodes); their + call-site names are valid nodeList entries. 
+ """ + # Registered with DEFINEKNOWNCUSTOMOP in ONNXModelLoader.cpp + _KNOWN_CUSTOM_OPS = frozenset({"CustomRMSNorm"}) + + local_functions = {f.name: f for f in model.functions} + logger.info(f"Found {len(local_functions)} local function types: {set(local_functions.keys())}") + + inlined_funcs: Set[str] = set() + non_inlined_funcs: Set[str] = set() + for func_name, func in local_functions.items(): + if func_name in _KNOWN_CUSTOM_OPS or len(func.node) >= 100: + non_inlined_funcs.add(func_name) + logger.info(f" {func_name}: not inlined") + else: + inlined_funcs.add(func_name) + logger.info(f" {func_name}: {len(func.node)} nodes, will inline") + + inlined_node_map: Dict[str, List[str]] = {} + for node in model.graph.node: + if node.op_type in inlined_funcs: + func = local_functions[node.op_type] + inlined_node_map[node.name] = [ + f"{node.name}/{fn.name}" for fn in func.node if fn.name + ] + + logger.info(f"Inlined sub-nodes mapped for {len(inlined_node_map)} call-sites") + return inlined_node_map, non_inlined_funcs + + +def generate_disagg_mdp_partition_config( + onnx_path: str, + num_devices: int, + num_partitions: int, + num_layers: int, + num_cores: int = 16, +) -> Dict[str, Any]: + """Generate a pipeline-partitioned MDP config from an exported ONNX graph. + + Assigns nodes to partitions by transformer layer index. Non-layer nodes + (embeddings, lm_head) follow the nearest layer in topological order. + nodeList is a superset of the compiler dump; the compiler silently ignores + optimized-away names. Inlined local function call-sites (CtxScatterCB, + CtxGatherCB) are excluded; their /nNN sub-nodes are assigned automatically. + Known custom ops (CustomRMSNorm) are included by call-site name. + + For PP+TS: num_devices // num_partitions devices per partition; the + compiler applies tensor-slicing within each stage. + + Args: + onnx_path: Path to the exported ONNX file. + num_devices: Total devices (num_partitions * ts_per_stage). 
+ num_partitions: Number of pipeline stages. + num_layers: Number of transformer layers. + num_cores: NSP cores per device (default 16). + + Returns: + dict with keys 'connections' and 'partitions'. + """ + assert num_partitions <= num_devices, f"num_partitions ({num_partitions}) must be <= num_devices ({num_devices})" + + layers_per_partition = num_layers // num_partitions + model = onnx.load(onnx_path, load_external_data=False) + + # Verify topological order (ONNX spec §3.3). Fails loudly on malformed exports. + # Graph inputs and initializers are excluded — they are not produced by any node. + graph_input_names: Set[str] = {inp.name for inp in model.graph.input} + initializer_names: Set[str] = {init.name for init in model.graph.initializer} + external_names: Set[str] = graph_input_names | initializer_names + + output_to_node: Dict[str, str] = {} + for node in model.graph.node: + for out in node.output: + if out: # "" marks optional unused outputs + output_to_node[out] = node.name + + seen_outputs: Set[str] = set() + for node in model.graph.node: + for inp in node.input: + if not inp: + continue + if inp in external_names: + continue + if inp in output_to_node and inp not in seen_outputs: + raise ValueError( + f"ONNX graph has a cycle or violates topological order: " + f"node '{node.name}' consumes '{inp}' produced by " + f"'{output_to_node[inp]}', but that producer has not appeared yet." + ) + for out in node.output: + if out: + seen_outputs.add(out) + + logger.info("Computing constant-foldable nodes...") + folded_nodes = _get_compiler_folded_nodes(model.graph) + logger.info(f"Found {len(folded_nodes)} compiler-folded nodes (excluded from nodeList)") + + inlined_node_map, non_inlined_functions = _get_inlined_node_map(model) + inlined_functions = {f.name for f in model.functions} - non_inlined_functions + + # First pass: assign main graph nodes to partitions by layer index. 
+ partitions: List[List[str]] = [[] for _ in range(num_partitions)] + current_layer_partition = 0 + seen_first_layer = False + max_layer_seen = -1 + + for node in model.graph.node: + if not node.name.startswith("/"): + continue + if node.name in folded_nodes: + continue + if node.op_type in inlined_functions: + continue # inlined; sub-nodes added in second pass + + layer_num = _get_layer_num(node.name) + if layer_num is not None: + max_layer_seen = max(max_layer_seen, layer_num) + seen_first_layer = True + partition_idx = min(layer_num // layers_per_partition, num_partitions - 1) + current_layer_partition = partition_idx + partitions[partition_idx].append(node.name) + else: + if not seen_first_layer: + partitions[0].append(node.name) + else: + partitions[current_layer_partition].append(node.name) + + # Second pass: add inlined sub-nodes, inheriting their call-site's partition. + for call_site_name, inlined_nodes in inlined_node_map.items(): + layer_num = _get_layer_num(call_site_name) + if layer_num is not None: + partition_idx = min(layer_num // layers_per_partition, num_partitions - 1) + else: + partition_idx = current_layer_partition + partitions[partition_idx].extend(inlined_nodes) + + for i, partition in enumerate(partitions): + logger.info(f"Partition {i}: {len(partition)} nodes") + logger.info(f"Total nodes in MDP: {sum(len(p) for p in partitions)}") + + # PP-only: 1 device/partition; PP+TS: num_devices//num_partitions devices/partition. 
+ device_ids = list(range(num_devices)) + devices_per_partition = num_devices // num_partitions + partition_objs = [] + for i, node_list in enumerate(partitions): + assigned_devices = device_ids[i * devices_per_partition : (i + 1) * devices_per_partition] + partition_objs.append( + { + "name": f"Partition{i}", + "nodeList": node_list, + "devices": [{"deviceId": dev_id, "numCores": num_cores} for dev_id in assigned_devices], + } + ) + + return { + "connections": [{"devices": device_ids, "type": "p2p"}], + "partitions": partition_objs, + } From bc006dde1d28f140b9e7281760c21804418b617b Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Mon, 20 Apr 2026 23:50:47 -0700 Subject: [PATCH 2/5] Formatting and Linting Signed-off-by: Mohit Mehta --- QEfficient/compile/mdp_generator.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/QEfficient/compile/mdp_generator.py b/QEfficient/compile/mdp_generator.py index 58a9ffe71c..cbc0f606ec 100644 --- a/QEfficient/compile/mdp_generator.py +++ b/QEfficient/compile/mdp_generator.py @@ -6,9 +6,10 @@ # ----------------------------------------------------------------------------- """MDP generator for disaggregated prefill serving (PP-enabled, TS-enabled, stages>1).""" +import logging from typing import Any, Dict, List, Optional, Set + import onnx -import logging logger = logging.getLogger(__name__) @@ -40,7 +41,7 @@ def _get_compiler_folded_nodes(graph) -> Set[str]: # Keep marking nodes foldable until no new ones are found. 
foldable_nodes: Set[str] = set() - while(True): + while True: changed = False for node in graph.node: if not node.name or node.name in foldable_nodes: @@ -111,9 +112,7 @@ def _get_inlined_node_map(model) -> tuple: for node in model.graph.node: if node.op_type in inlined_funcs: func = local_functions[node.op_type] - inlined_node_map[node.name] = [ - f"{node.name}/{fn.name}" for fn in func.node if fn.name - ] + inlined_node_map[node.name] = [f"{node.name}/{fn.name}" for fn in func.node if fn.name] logger.info(f"Inlined sub-nodes mapped for {len(inlined_node_map)} call-sites") return inlined_node_map, non_inlined_funcs From 7a0d6511fe66d24f818c2189d05a43575bd9ad6d Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Wed, 22 Apr 2026 01:41:18 -0700 Subject: [PATCH 3/5] Add compiler options - 'stages' Signed-off-by: Mohit Mehta --- QEfficient/base/modeling_qeff.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index ab36248d96..b25ee4d0bf 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -812,7 +812,8 @@ def _compile( if mdp_ts_json_path: command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") mdp_ts_json = load_json(str(mdp_ts_json_path)) - elif mdp_num_partitions > 1: + elif mdp_num_partitions > 1 or "stages" in compiler_options: + mdp_num_partitions = compiler_options.pop("stages", mdp_num_partitions) # Disaggregated (pipeline-parallel) MDP: generate a fully-populated # nodeList per partition directly from the ONNX graph — no compiler # round-trip required. 
From 8193f30a335bcd3195767bf7e026d14ae0e2bfeb Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Mon, 25 May 2026 01:49:26 -0700 Subject: [PATCH 4/5] Added support for layerwise export Signed-off-by: Mohit Mehta --- QEfficient/compile/mdp_generator.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/QEfficient/compile/mdp_generator.py b/QEfficient/compile/mdp_generator.py index cbc0f606ec..53faa64132 100644 --- a/QEfficient/compile/mdp_generator.py +++ b/QEfficient/compile/mdp_generator.py @@ -61,7 +61,12 @@ def _get_compiler_folded_nodes(graph) -> Set[str]: def _get_layer_num(node_name: str) -> Optional[int]: """Return transformer layer index from node name, or None. - Supports layers.N (Llama/Mistral/Qwen/Gemma/Granite) and h.N (GPT-2). + Supports: + - layers.N (Llama/Mistral/Qwen/Gemma/Granite) + - h.N (GPT-2) + - layer_N// prefix (subfunctions/merged ONNX where each subgraph is + prefixed with the layer index it belongs to, e.g. + "layer_3//model/embed_tokens/Gather") """ for part in node_name.split("/"): if part.startswith("layers."): @@ -72,6 +77,10 @@ def _get_layer_num(node_name: str) -> Optional[int]: suffix = part[len("h.") :] if suffix.isdigit(): return int(suffix) + elif part.startswith("layer_"): + suffix = part[len("layer_") :] + if suffix.isdigit(): + return int(suffix) return None @@ -92,8 +101,14 @@ def _get_inlined_node_map(model) -> tuple: (known custom ops or >= 100 nodes); their call-site names are valid nodeList entries. """ - # Registered with DEFINEKNOWNCUSTOMOP in ONNXModelLoader.cpp - _KNOWN_CUSTOM_OPS = frozenset({"CustomRMSNorm"}) + # Op types registered via DEFINEKNOWNCUSTOMOP in ONNXModelLoaderCustomOp.cpp. + # The compiler loads these as a single named node (not expanded/inlined). 
+ _KNOWN_CUSTOM_OPS = frozenset( + { + "CustomRMSNorm", + "CastToUInt4", + } + ) local_functions = {f.name: f for f in model.functions} logger.info(f"Found {len(local_functions)} local function types: {set(local_functions.keys())}") @@ -195,7 +210,7 @@ def generate_disagg_mdp_partition_config( max_layer_seen = -1 for node in model.graph.node: - if not node.name.startswith("/"): + if not node.name: continue if node.name in folded_nodes: continue From fe974d0e5d9d2c34d294144e53970481d613158b Mon Sep 17 00:00:00 2001 From: Ann Date: Thu, 4 Jun 2026 22:40:40 +0530 Subject: [PATCH 5/5] Added inference serving with DMA slicing for KV handoff Signed-off-by: Ann --- QEfficient/base/modeling_qeff.py | 2 +- QEfficient/generation/cloud_infer_kv_slice.py | 805 ++++++++++++++++++ ...n3moe_disagg_mode_with_chunking_kvslice.py | 282 ++++++ 3 files changed, 1088 insertions(+), 1 deletion(-) create mode 100644 QEfficient/generation/cloud_infer_kv_slice.py create mode 100644 examples/disagg_serving/qwen3moe_disagg_mode_with_chunking_kvslice.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index b25ee4d0bf..8c359efc99 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -14,7 +14,7 @@ import warnings from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import onnx import torch diff --git a/QEfficient/generation/cloud_infer_kv_slice.py b/QEfficient/generation/cloud_infer_kv_slice.py new file mode 100644 index 0000000000..e4a415c70d --- /dev/null +++ b/QEfficient/generation/cloud_infer_kv_slice.py @@ -0,0 +1,805 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +cloud_infer_KV_share.py +======================= +optimised KV-cache handoff via shared DMA buffers and slice operations. + +Design +------ +KV handoff (zero-copy) +---------------------- +On the last prefill chunk, before calling np_run_pipeline(): + + session.set_data_for_kv_handoff( + kv_cache_buffers, # shared numpy arrays + [("batch_index", bidx), ("ctx_start", 0)], # slice offsets + exec_obj_idx, + session.prefill_buff_map[:-1], # KV outputs only (no logits) + ) + +This calls execObj.setDataWithSlices() which wires the prefill RetainedState +output buffers directly into the decode session's KV input slots via a sliced +DMA descriptor — the KV tensors never pass through Python/numpy. + +""" + +from __future__ import annotations + +import json +import logging +import os +import platform +import sys +from pathlib import Path +from queue import Queue +from typing import Any + +import numpy as np + +# ── SDK imports ─────────────────────────────────────────────────────────────── +try: + import qaicrt +except ImportError: + sys.path.append(f"/opt/qti-aic/dev/lib/{platform.machine()}") + import qaicrt + +try: + import QAicApi_pb2 as aicapi +except ImportError: + sys.path.append("/opt/qti-aic/dev/python") + import QAicApi_pb2 as aicapi + +logger = logging.getLogger(__name__) + +# ── Env-var names (match vllm convention) ──────────────────────────────────── +_PREFILL_QUEUE_LEN_ENV = "VLLM_QAIC_PREFILL_QUEUE_LEN" +_ASYNC_SCHEDULING_TIMEOUT_ENV = "VLLM_QAIC_ASYNC_SCHEDULING_EXEC_TIMEOUT" + +# ── dtype mapping ───────────────────────────────────────────────────────────── +AIC_TO_NP: dict[int, np.dtype] = { + getattr(aicapi, "BFLOAT16_TYPE", 11): np.dtype(np.float16), + aicapi.FLOAT_TYPE: np.dtype(np.float32), + aicapi.FLOAT_16_TYPE: np.dtype(np.float16), + aicapi.INT8_Q_TYPE: np.dtype(np.int8), + aicapi.UINT8_Q_TYPE: np.dtype(np.uint8), + 
aicapi.INT16_Q_TYPE: np.dtype(np.int16), + aicapi.INT32_Q_TYPE: np.dtype(np.int32), + aicapi.INT32_I_TYPE: np.dtype(np.int32), + aicapi.INT64_I_TYPE: np.dtype(np.int64), + aicapi.INT8_TYPE: np.dtype(np.int8), +} + + +# ───────────────────────────────────────────────────────────────────────────── +# Helper: build a descriptive error message when waitForCompletion fails +# ───────────────────────────────────────────────────────────────────────────── +def _build_completion_error(bindings, allowed_shapes: list, buf_dims_for_idx) -> str: + msg = "Failed to run" + if not allowed_shapes: + return msg + msg += '\n\n(Only if "No matching dimension found" error is present above)' + msg += "\nAllowed shapes:" + for i, allowed_shape in enumerate(allowed_shapes): + msg += f"\n{i}\n" + for binding, (elemsize, shape), (_, passed_shape) in zip(bindings, allowed_shape, buf_dims_for_idx): + if passed_shape[0] == 0: + if not binding.is_partial_buf_allowed: + logger.warning("Partial buffer not allowed for: %s", binding.name) + continue + msg += f"{binding.name}:\t{elemsize}\t{shape}\n" + msg += "\n\nPassed shapes:\n" + for binding, (elemsize, shape) in zip(bindings, buf_dims_for_idx): + if shape[0] == 0: + continue + msg += f"{binding.name}:\t{elemsize}\t{shape}\n" + return msg + + +# ───────────────────────────────────────────────────────────────────────────── +class QAICInferenceSession: + """ + Optimised QAIC inference session with KV-cache handoff via shared DMA + buffers and slice operations. + + Parameters + ---------- + qpc_path : str | Path + Path to the compiled QPC binary directory. + full_batch_size : int + Maximum batch size the QPC was compiled for. Used to compute the + batch_index modulo when wiring KV slices. + device_ids : list[int] | None + Device IDs to use. Single-device → non-MQ path; multi-device → MQ + via devMapping. None → default context (device 0). + activate : bool + Activate the program immediately on construction. 
+ enable_debug_logs : bool + Enable QAIC runtime debug logging. + stages : int + Number of pipeline stages for pipelined prefill. + prefill exec-obj pool size = stages + 1 (overridable via env-var). + cluster_id : str | None + "prefill" → only prefill exec-objs, no decode slot. + "decode" → only one decode exec-obj, no prefill pool. + None → combined mode: one decode slot + prefill pool. + use_async_scheduling : bool + When True (and cluster_id is None) allocates an extra prefill exec-obj + so a new chunk can be enqueued before the previous one completes. + """ + + def __init__( + self, + qpc_path: str | Path, + full_batch_size: int = 1, + device_ids: list[int] | None = None, + activate: bool = True, + enable_debug_logs: bool = False, + stages: int | None = 1, + cluster_id: str | None = None, + use_async_scheduling: bool = False, + ) -> None: + self.stages: int = stages if stages is not None else 1 + self.cluster_id = cluster_id + self.full_batch_size = full_batch_size + self.async_scheduling_exec_timeout: int = int(os.getenv(_ASYNC_SCHEDULING_TIMEOUT_ENV, 300)) + + # ── Exec-obj pool layout ───────────────────────────────────────────── + # Layout in self.execObj list: + # [0 .. decode_num) → decode slot(s) + # [decode_num .. 
queue_len) → prefill pool + # + # cluster_id="decode" : 1 decode slot, 0 prefill + # cluster_id="prefill" : 0 decode slots, stages+1 prefill + # cluster_id=None : 1 decode slot + 1 (or 2 if async) prefill + if cluster_id == "decode": + self.prefill_num_execObj: int = 0 + self.decode_num_execObj: int = 1 + self.decode_execObj_idx: int | None = 0 + elif cluster_id == "prefill": + self.prefill_num_execObj = int(os.getenv(_PREFILL_QUEUE_LEN_ENV, self.stages + 1)) + self.decode_num_execObj = 0 + self.decode_execObj_idx = None + else: + _default_prefill = 2 if use_async_scheduling else 1 + self.prefill_num_execObj = int(os.getenv(_PREFILL_QUEUE_LEN_ENV, _default_prefill)) + self.decode_num_execObj = 1 + self.decode_execObj_idx = 0 + + self.queue_len: int = self.prefill_num_execObj + self.decode_num_execObj + + # Thread-safe queue of available prefill exec-obj indices + self.prefill_available_exec_objs: Queue[int] = Queue() + _prefill_start = self.decode_num_execObj + for i in range(_prefill_start, _prefill_start + self.prefill_num_execObj): + self.prefill_available_exec_objs.put(i) + + logger.debug( + "cluster_id=%s async=%s prefill_exec_objs=%d decode_exec_objs=%d", + cluster_id, + use_async_scheduling, + self.prefill_num_execObj, + self.decode_num_execObj, + ) + + # ── Context + Queue ────────────────────────────────────────────────── + if device_ids is not None: + devices = qaicrt.QIDList(device_ids) + self.context = qaicrt.Context(devices) + self.queue = qaicrt.Queue(self.context, device_ids[0]) + else: + self.context = qaicrt.Context() + self.queue = qaicrt.Queue(self.context, 0) + + if enable_debug_logs: + assert self.context.setLogLevel(qaicrt.QLogLevel.QL_DEBUG) == qaicrt.QStatus.QS_SUCCESS, ( + "Failed to setLogLevel" + ) + + # One thread per queue — avoids head-of-line blocking between + # concurrent prefill and decode enqueues + _qprops = qaicrt.QAicQueueProperties() + _qprops.numThreadsPerQueue = 1 + self.queue.initProperties(_qprops) + + # ── Load QPC + IO 
descriptor ───────────────────────────────────────── + qpc = qaicrt.Qpc(str(qpc_path)) + iodesc = aicapi.IoDesc() + status, iodesc_data = qpc.getIoDescriptor() + assert status == qaicrt.QStatus.QS_SUCCESS, "Failed to getIoDescriptor" + iodesc.ParseFromString(bytes(iodesc_data)) + + self.allowed_shapes: list = [ + [(AIC_TO_NP[x.type].itemsize, list(x.dims)) for x in allowed_shape.shapes] + for allowed_shape in iodesc.allowed_shapes + ] + self.bindings = iodesc.selected_set.bindings + self.binding_index_map: dict[str, int] = {b.name: b.index for b in self.bindings} + + # ── Program ────────────────────────────────────────────────────────── + prog_props = qaicrt.QAicProgramProperties() + prog_props.dataPathTimeoutMs = 60_000 + + _dev_id_non_mq = None + if device_ids: + if len(device_ids) == 1: + _dev_id_non_mq = device_ids[0] + else: + prog_props.devMapping = ":".join(map(str, device_ids)) + + self.program = qaicrt.Program(self.context, _dev_id_non_mq, qpc, prog_props) + assert self.program.load() == qaicrt.QStatus.QS_SUCCESS, "Failed to load program" + + self.activate_done = False + if activate: + self.activate() + + # ── Per-exec-obj qbuffers and buf_dims ─────────────────────────────── + # Each exec-obj gets its own buffer list so concurrent in-flight + # submissions don't clobber each other's buffer descriptors. + self.qbuffers: list[list[qaicrt.QBuffer]] = [ + [qaicrt.QBuffer(bytes(b.size)) for b in self.bindings] for _ in range(self.queue_len) + ] + self.buf_dims: list[qaicrt.BufferDimensionsVecRef] = [ + qaicrt.BufferDimensionsVecRef([(AIC_TO_NP[b.type].itemsize, list(b.dims)) for b in self.bindings]) + for _ in range(self.queue_len) + ] + + # ── KV slicing spec (for zero-copy DMA handoff) ────────────────────── + # Created once from the program; reused across all setDataWithSlices calls. 
+ self.kv_slicing_spec_handle = None + self.repetition_penalty_spec_handle = None + + if "past_key.0_RetainedState" in self.binding_index_map: + _rs_binding = self.bindings[self.binding_index_map["past_key.0_RetainedState"]] + # kv_shape / kv_size used externally to allocate shared KV buffers + self.kv_shape: list[int] = list(_rs_binding.dims) + self.kv_shape[0] = 1 # per-batch-slot shape + self.kv_size: np.dtype = AIC_TO_NP[_rs_binding.type] + self.kv_slicing_spec_handle = self.get_slicing_spec_handle(self.get_json_for_kv_cache_slicing()) + + if "past_repetition_penalty_buffer" in self.input_names: + self.repetition_penalty_map: list[tuple[str, int]] = [ + ( + "past_repetition_penalty_buffer", + self.binding_index_map["past_repetition_penalty_buffer"], + ) + ] + self.repetition_penalty_spec_handle = self.get_slicing_spec_handle( + self.get_json_for_repetition_penalty_slicing() + ) + + # ── Buffer maps (sorted by layer then key/value) ───────────────────── + # decode_buff_map : input KV buffers (past_key.*, past_value.*) + # prefill_buff_map : output KV retained states + logits (logits last) + def _kv_sort_key(item: tuple[str, int]) -> tuple[int, int]: + name = item[0] + layer = int(name.split(".")[1]) if "." in name else 0 + kind = 0 if name.startswith("past_key") else 1 + return (layer, kind) + + self.decode_buff_map: list[tuple[str, int]] = sorted( + [ + (name, self.binding_index_map[name]) + for name in self.input_names + if name.startswith("past_key") or name.startswith("past_value") + ], + key=_kv_sort_key, + ) + + # decode_rs_buff_map : output RetainedState KV bindings (past_key.*_RetainedState) + # Stores OUTPUT binding indices — used by set_data_for_kv_handoff on the + # decode session to wire RetainedState outputs directly into kv_cache arrays. + # Distinct from decode_buff_map which stores INPUT binding indices. 
# ── Properties ────────────────────────────────────────────────────────────


@property
def input_names(self) -> list[str]:
    """Names of every binding whose direction is input, in binding order."""
    return [b.name for b in self.bindings if b.dir == aicapi.BUFFER_IO_TYPE_INPUT]


@property
def output_names(self) -> list[str]:
    """Names of every binding whose direction is output, in binding order."""
    return [b.name for b in self.bindings if b.dir == aicapi.BUFFER_IO_TYPE_OUTPUT]


# ── Lifecycle ─────────────────────────────────────────────────────────────


def activate(self) -> None:
    """Activate the program and create one ExecObj per pool slot.

    ``activate_done`` is set only after both steps succeed. If creating the
    exec objects fails, the program is deactivated again before the error is
    re-raised, so the session is never left half-active (a state in which
    ``deactivate()`` would try to delete a missing ``execObj`` attribute).
    """
    self.program.activate()
    try:
        self.execObj = [qaicrt.ExecObj(self.context, self.program) for _ in range(self.queue_len)]
    except Exception:
        # Roll back so the program is not left activated without exec objects.
        self.program.deactivate()
        raise
    self.activate_done = True


def deactivate(self) -> None:
    """Release the ExecObjs and deactivate the program; no-op when inactive."""
    if not self.activate_done:
        return
    del self.execObj
    self.program.deactivate()
    self.activate_done = False


# ── Slicing spec helpers ──────────────────────────────────────────────────


def get_json_for_kv_cache_slicing(self) -> str:
    """
    JSON spec for setDataWithSlices on past_key.* / past_value.*.

    Slice dimensions: [batch_index, :, ctx_start, :] — allows writing a
    single batch slot starting at a given context position without copying
    the full KV tensor. The element size is taken from the ``past_key.0``
    binding and reused for both the key and the value patterns.

    Returns
    -------
    str
        The serialized spec, with one ``BufferSpecs`` entry per name pattern.
    """
    elem_size = AIC_TO_NP[self.bindings[self.binding_index_map["past_key.0"]].type].itemsize
    spec = {
        "BufferSpecs": [
            {
                "Name": pattern,
                "ElemSize": elem_size,
                "DimSpecs": [
                    {"start": "batch_index"},  # dim 0: which batch slot to start at
                    {"start": 0},  # dim 1: all heads, start at 0
                    {"start": "ctx_start"},  # dim 2: which context position to start at
                    {"start": 0},
                ],
            }
            for pattern in ("past_key.*", "past_value.*")
        ]
    }
    return json.dumps(spec)


def get_json_for_repetition_penalty_slicing(self) -> str:
    """JSON spec for setDataWithSlices on past_repetition_penalty_buffer."""
    binding = self.bindings[self.binding_index_map["past_repetition_penalty_buffer"]]
    spec = {
        "BufferSpecs": [
            {
                "Name": "past_repetition_penalty_buffer",
                "ElemSize": AIC_TO_NP[binding.type].itemsize,
                "DimSpecs": [
                    {"start": "batch_index"},
                    {"start": 0},
                ],
            }
        ]
    }
    return json.dumps(spec)


def get_slicing_spec_handle(self, buffer_spec_json: str):
    """Create a slicing spec handle on the program from a JSON spec string.

    Asserts that ``createSlicingSpecHandle`` reports success and returns the
    handle it produced.
    """
    status, handle = self.program.createSlicingSpecHandle(buffer_spec_json)
    assert status == qaicrt.QStatus.QS_SUCCESS, "Failed to createSlicingSpecHandle"
    return handle


# ── Buffer helpers ────────────────────────────────────────────────────────


def get_bindings(self, binding_names: list[str]) -> list:
    """Return the bindings whose name is in `binding_names`, in binding order.

    Names without a matching binding are ignored.
    """
    wanted = frozenset(binding_names)  # O(1) membership per binding
    return [b for b in self.bindings if b.name in wanted]
def get_bindings_shapes(self, binding_names: list[str]) -> dict[str, list[list[int]]]:
    """Return all allowed shapes for the requested buffer names.

    For each known name the result holds one shape per entry of
    ``allowed_shapes``. Unknown names are logged and left out of the result.
    """
    result: dict[str, list[list[int]]] = {}
    for name in binding_names:
        if name not in self.binding_index_map:
            logger.warning("Unable to find binding: %s", name)
            continue
        idx = self.binding_index_map[name]
        # Each allowed-shape entry is a list of (itemsize, dims) per binding.
        result[name] = [shape[idx][1] for shape in self.allowed_shapes]
    return result


def get_logits_ndim(self) -> int:
    """Return the number of dimensions of the logits output binding.

    Reads directly from the selected_set binding metadata rather than
    allowed_shapes, because allowed_shapes is empty for single-specialisation
    QPCs (the common case) and would always return the default 3.
    Falls back to 3 (with a warning) when there is no ``logits`` binding.
    """
    if "logits" not in self.binding_index_map:
        logger.warning("logits binding not found, defaulting ndim to 3")
        return 3
    return len(self.bindings[self.binding_index_map["logits"]].dims)


def set_buffers(self, buffers: dict[str, np.ndarray], index: int = 0) -> None:
    """
    Register numpy arrays in the qbuffer/buf_dims lists of exec-obj slot `index`.

    Each array is made C-contiguous before being wrapped in a QBuffer. When
    that produces a new array object, the entry in `buffers` is replaced with
    it, so the caller's dict refers to the array that was actually registered.
    Names without a binding are logged and skipped.
    """
    for name, buf in buffers.items():
        if name not in self.binding_index_map:
            logger.warning("Buffer: %s not found", name)
            continue
        buf_idx = self.binding_index_map[name]
        contiguous = np.ascontiguousarray(buf)
        if contiguous is not buf:
            logger.warning("Non-contiguous buffer for '%s'; copying to contiguous.", name)
            # Rebinding an existing key is safe while iterating .items().
            buffers[name] = contiguous
            buf = contiguous
        self.qbuffers[index][buf_idx] = qaicrt.QBuffer(buf)
        self.buf_dims[index][buf_idx] = (
            buf.itemsize,
            buf.shape if buf.ndim > 0 else (1,),
        )


def skip_buffers(self, buffer_names: list[str], index: int = 0) -> None:
    """Mark buffers as skipped (empty) for exec-obj slot `index`."""
    self.set_buffers({name: np.array([]) for name in buffer_names}, index)


def unskip_buffers(self, buffer_names: list[str], index: int = 0) -> None:
    """Restore skipped buffers to zero-filled arrays of their binding shape."""
    bufs: dict[str, np.ndarray] = {
        b.name: np.zeros(list(b.dims), dtype=AIC_TO_NP[b.type]) for b in self.get_bindings(buffer_names)
    }
    self.set_buffers(bufs, index)


def get_tuple_list_from_dict(self, inputs: dict[str, Any]) -> list[tuple[int, np.ndarray]]:
    """
    Convert {name: array} → [(binding_index, array)] tuple list.

    This is the format passed to execObj.setData() and
    execObj.setDataWithSlices() elsewhere in this class. Names without a
    binding are logged and skipped; entries whose value is None are skipped
    silently. Arrays are passed through as-is (no copy is made here).
    """
    result: list[tuple[int, np.ndarray]] = []
    for name, buf in inputs.items():
        if name not in self.binding_index_map:
            logger.warning("Buffer: %s not found", name)
            continue
        if buf is None:
            continue
        result.append((self.binding_index_map[name], buf))
    return result


def _make_inputs_contiguous(self, inputs: dict[str, Any]) -> None:
    """Replace every array in `inputs` with a C-contiguous equivalent, in place.

    None values are left untouched: converting them would yield a 0-d object
    array, which get_tuple_list_from_dict() would no longer recognise as an
    entry to skip.
    """
    for key, value in inputs.items():
        if value is None:
            continue
        inputs[key] = np.ascontiguousarray(value)
+ """ + if isinstance(buffers, (list, np.ndarray)): + assert buff_map is not None, "buff_map required when buffers is a list" + assert len(buffers) == len(buff_map) or len(buffers) + 1 == len(buff_map), ( + "buffers length must match buff_map (or buff_map may include logits entry)" + ) + slices_as_tuple_list = [(entry[1], buf) for entry, buf in zip(buff_map, buffers)] + else: + slices_as_tuple_list = self.get_tuple_list_from_dict(buffers) + + status, _ = self.execObj[index].setDataWithSlices( + slices_as_tuple_list, # [(binding_idx, numpy_array), ...] + slicing_spec_handle, # compiled address descriptor + slicing_parameters, # parametric values + ) + assert status == qaicrt.QStatus.QS_SUCCESS, "Failed to setDataWithSlices" + return buffers + + def set_data_for_kv_handoff( + self, + kv_cache_buffers, + slicing_parameters: list[tuple[str, int]], + index: int = 0, + buff_map: list[tuple[str, int]] | None = None, + ): + """ + Wire prefill RetainedState outputs directly into the shared KV cache + buffers via a sliced DMA descriptor. + + Call this BEFORE np_run_pipeline() on the last chunk so the runtime + knows where to write the KV outputs without any numpy copy. + + Parameters + ---------- + kv_cache_buffers : list[np.ndarray] | dict[str, np.ndarray] + Shared KV numpy arrays (one per layer key/value). + slicing_parameters : list[tuple[str, int]] + e.g. [("batch_index", bidx), ("ctx_start", 0)] + index : int + Exec-obj slot index. + buff_map : list[tuple[str, int]] | None + Typically session.prefill_buff_map[:-1] (KV only, no logits). 
+ """ + return self._set_data_with_slices( + kv_cache_buffers, + slicing_parameters, + self.kv_slicing_spec_handle, + index, + buff_map, + ) + + def set_data_for_repetition_penalty( + self, + repetition_penalty_buffers, + slicing_parameters: list[tuple[str, int]], + index: int = 0, + ): + """Wire repetition-penalty buffer via sliced DMA.""" + return self._set_data_with_slices( + repetition_penalty_buffers, + slicing_parameters, + self.repetition_penalty_spec_handle, + index, + self.repetition_penalty_map, + ) + + # ── Inference entry points ──────────────────────────────────────────────── + + def np_run( + self, + inputs: dict[str, Any], + slicing_parameters: list[tuple[str, int]] | None = None, + is_prefill: bool = True, + ) -> int: + """ + Fire-and-forget enqueue for decode or non-pipelined prefill. + + Blocks until an exec-obj slot is available (prefill pool or fixed + decode slot), then calls setData / setDataWithSlices and enqueues. + Returns the exec-obj index so the caller can call complete_inf() later. + + Parameters + ---------- + inputs : dict[str, np.ndarray] + Model inputs. Arrays are made contiguous in-place. + slicing_parameters : list[tuple[str, int]] | None + If provided, uses setDataWithSlices with kv_slicing_spec_handle. + is_prefill : bool + True → draw from prefill_available_exec_objs pool. + False → use fixed decode_execObj_idx slot. + """ + if is_prefill: + exec_idx = self.prefill_available_exec_objs.get(timeout=self.async_scheduling_exec_timeout) + else: + assert self.decode_execObj_idx is not None, "decode_execObj_idx is None — session not configured for decode" + exec_idx = self.decode_execObj_idx + + # Copy input arrays into qbuffers/buf_dims so all buffers — inputs and + # outputs — are registered in a single setData(qbuffers, buf_dims) call. + # Overload-4 setData(tuple_list) replaces the entire registration, so + # calling it with only inputs would wipe out the logits output slot. 
+ self.set_buffers(inputs, exec_idx) + + if slicing_parameters is None: + status = self.execObj[exec_idx].setData(self.qbuffers[exec_idx], self.buf_dims[exec_idx]) + assert status == qaicrt.QStatus.QS_SUCCESS, "setData failed" + else: + self._make_inputs_contiguous(inputs) + slices = self.get_tuple_list_from_dict(inputs) + status, _ = self.execObj[exec_idx].setDataWithSlices( + slices, self.kv_slicing_spec_handle, slicing_parameters + ) + assert status == qaicrt.QStatus.QS_SUCCESS, "setDataWithSlices failed" + + try: + assert self.queue.enqueue(self.execObj[exec_idx]) == qaicrt.QStatus.QS_SUCCESS, "enqueue failed" + except Exception as exc: + logger.error("Error while enqueuing: %s", exc) + return 0 + + return exec_idx + + def np_run_pipeline( + self, + inputs: dict[str, np.ndarray], + slicing_parameters: list[tuple[str, int]] | None = None, + last_chunk: bool = False, + kv_cache_buffers=None, + ) -> int: + """ + Pipelined prefill enqueue with KV-cache handoff on the last chunk. + + Draws an exec-obj from the prefill pool (blocks if none available), + optionally wires the KV RetainedState outputs into the shared KV cache + via set_data_for_kv_handoff() when last_chunk=True, then enqueues. + + The key difference from np_run: + • On last_chunk=True, set_data_for_kv_handoff() is called BEFORE + setData so the runtime wires the prefill output KV directly into + the decode session's KV input slots — zero-copy, no numpy involved. + • Non-last chunks do not need KV outputs (they are skipped by default). + + Parameters + ---------- + inputs : dict[str, np.ndarray] + Chunk inputs (input_ids, position_ids, batch_index, logits buffer). + slicing_parameters : list[tuple[str, int]] | None + Passed to setDataWithSlices if provided. + last_chunk : bool + True on the final chunk of a request. Triggers KV handoff wiring. + kv_cache_buffers : list[np.ndarray] | None + Shared KV numpy arrays to wire into. Must be non-None when + last_chunk=True. 
+ """ + logger.debug("Waiting for prefill exec-obj (pipeline)") + exec_idx = self.prefill_available_exec_objs.get(timeout=self.async_scheduling_exec_timeout) + logger.debug("Got prefill exec-obj %d", exec_idx) + + if last_chunk: + assert kv_cache_buffers is not None, "kv_cache_buffers must be provided for last_chunk=True" + batch_index = int(inputs["batch_index"].item()) + # Wire prefill RetainedState outputs → shared KV cache (zero-copy) + self.set_data_for_kv_handoff( + kv_cache_buffers, + [ + ("batch_index", batch_index % self.full_batch_size), + ("ctx_start", 0), + ], + exec_idx, + self.prefill_buff_map[:-1], # KV entries only, exclude logits + ) + + # Must use overload-4 setData(tuple_list) here — NOT setData(qbuffers, buf_dims). + # + # Reason: setDataWithSlices (called above in set_data_for_kv_handoff) wires + # the KV RetainedState output slots (idx 6-9) to the shared kv_cache arrays. + # Per the SDK docs, setData(qbuffers, buf_dims) — overload 1 — overwrites ALL + # slots including 6-9, destroying the KV handoff wiring regardless of call order. + # + # setData(tuple_list) — overload 4 — only touches the indices explicitly listed. + # get_tuple_list_from_dict(inputs) produces [(0, input_ids), (1, position_ids), + # (10, logits_buf)] — indices 6-9 are absent, so the setDataWithSlices wiring + # on those slots is never touched and survives intact. 
+ self._make_inputs_contiguous(inputs) + tuple_list = self.get_tuple_list_from_dict(inputs) + + if slicing_parameters is None: + status = self.execObj[exec_idx].setData(tuple_list) + assert status == qaicrt.QStatus.QS_SUCCESS, "setData failed" + else: + status, _ = self.execObj[exec_idx].setDataWithSlices( + tuple_list, self.kv_slicing_spec_handle, slicing_parameters + ) + assert status == qaicrt.QStatus.QS_SUCCESS, "setDataWithSlices failed" + + assert self.queue.enqueue(self.execObj[exec_idx]) == qaicrt.QStatus.QS_SUCCESS, "enqueue failed" + + return exec_idx + + def complete_inf(self, index: int, is_prefill: bool = True) -> None: + """ + Wait for exec-obj `index` to finish and release it back to the pool. + + Parameters + ---------- + index : int + Exec-obj index returned by np_run / np_run_pipeline. + is_prefill : bool + True → return the slot to prefill_available_exec_objs. + False → decode slot is fixed; nothing to return. + """ + if self.execObj[index].waitForCompletion() != qaicrt.QStatus.QS_SUCCESS: + raise ValueError(_build_completion_error(self.bindings, self.allowed_shapes, self.buf_dims[index])) + logger.debug("Releasing exec-obj %d (is_prefill=%s)", index, is_prefill) + if is_prefill: + self.prefill_available_exec_objs.put(index) + + def get_outputs(self, index: int = 0) -> dict[str, np.ndarray]: + """ + Read output buffers from exec-obj `index` after complete_inf(). + + Returns a dict of {output_name: numpy_array} for all non-empty outputs. + + Shape is read from binding metadata (always correct) rather than from + buf_dims, because np_run / np_run_pipeline use the tuple-list setData + path which bypasses set_buffers() and never updates buf_dims. 
def get_outputs(self, index: int = 0) -> dict[str, np.ndarray]:
    """
    Collect the output arrays of exec-obj `index` once complete_inf() returned.

    Every output binding whose returned buffer is non-empty is decoded with
    the binding's dtype and reshaped to the binding's dims. The shape comes
    from the binding metadata rather than from buf_dims, which the tuple-list
    setData path in np_run_pipeline does not update.

    Returns
    -------
    dict[str, np.ndarray]
        {output_name: array} for all non-empty outputs.
    """
    status, out_qbuffers = self.execObj[index].getData()
    assert status == qaicrt.QStatus.QS_SUCCESS, "getData failed"

    decoded: dict[str, np.ndarray] = {}
    for b in self.bindings:
        if b.dir == aicapi.BUFFER_IO_TYPE_OUTPUT:
            payload = bytes(out_qbuffers[b.index])
            if payload:  # empty payload → output was skipped
                flat = np.frombuffer(payload, dtype=AIC_TO_NP[b.type])
                decoded[b.name] = flat.reshape(list(b.dims))
    return decoded


def run(self, inputs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """
    Blocking convenience call: enqueue, wait, and fetch the outputs.

    Chains np_run(is_prefill=True), complete_inf() and get_outputs() on the
    same exec-obj slot, mirroring the run() interface of the original
    cloud_infer.py QAICInferenceSession so existing callers need no changes.
    """
    slot = self.np_run(inputs, is_prefill=True)
    self.complete_inf(slot, is_prefill=True)
    outputs = self.get_outputs(slot)
    return outputs


# ── Misc helpers (kept for API compatibility) ─────────────────────────────


def create_numpy_buffers(
    self,
    input_dict: dict,
    direction: str,
    shape: list[int],
    size: np.dtype,
) -> None:
    """Fill `input_dict` with one KV buffer per matching binding name.

    direction "in" selects the past_key* / past_value* inputs; "out" selects
    the past_key* / past_value* outputs ending in "_RetainedState". Each
    selected name gets its own zero array of `shape` and dtype `size`, or an
    empty array when `shape` is empty.

    Raises
    ------
    ValueError
        If `direction` is neither "in" nor "out".
    """
    kv_prefixes = ("past_key", "past_value")
    if direction == "in":
        selected = [n for n in self.input_names if n.startswith(kv_prefixes)]
    elif direction == "out":
        selected = [n for n in self.output_names if n.startswith(kv_prefixes) and n.endswith("_RetainedState")]
    else:
        raise ValueError(f"Invalid direction '{direction}'; expected 'in' or 'out'")

    for kv_name in selected:
        if shape:
            input_dict[kv_name] = np.zeros(shape, dtype=size)
        else:
            input_dict[kv_name] = np.array([])


def create_output_buffers(
    self,
    input_dict: dict,
    shape: list[int],
    size: np.dtype,
    buffer_name: str = "logits",
) -> None:
    """Store an uninitialised array for output `buffer_name` in `input_dict`.

    When `buffer_name` has no binding, a warning is logged and `input_dict`
    is left unchanged.
    """
    if buffer_name in self.binding_index_map:
        input_dict[buffer_name] = np.empty(shape, dtype=size)
    else:
        logger.warning("Buffer: %s not found", buffer_name)


def extract_outputs(self, input_dict: dict) -> dict[str, np.ndarray]:
    """Return the entries of `input_dict` whose key is an output binding name."""
    picked: dict[str, np.ndarray] = {}
    for out_name in self.output_names:
        if out_name in input_dict:
            picked[out_name] = input_dict[out_name]
    return picked
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
"""
Disaggregated prefill/decode example: chunked prefill on one session, decode on
another, with the KV cache handed over through shared numpy buffers.

Session layout
--------------
  prefill_session : cluster_id="prefill", stages=STAGES
                    exec-obj pool: prefill only, no decode slot
                    (pool size is printed when the session is loaded)

  decode_session  : cluster_id="decode"
                    exec-obj pool: decode only, no prefill pool

Shared KV buffers
-----------------
  kv_cache : list[np.ndarray] — one array per (layer, key/value) pair,
             shape = prefill_session.kv_shape, dtype = prefill_session.kv_size
             Allocated once; written by prefill via DMA slice, read by decode
             via the same pointer — no copy at the prefill→decode boundary.

Both sessions are deactivated when the script finishes, including when
inference fails part-way through.
"""

import math
import time

import numpy as np
from QEfficient.generation.cloud_infer_KV_share import QAICInferenceSession
from transformers import AutoConfig, AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

# ─────────────────────────────────────────────────────────────────────────────
# Configuration
# ─────────────────────────────────────────────────────────────────────────────
MODEL_ID = "Qwen/Qwen3-30B-A3B-Instruct-2507"
PREFILL_SEQ_LEN = 128
CTX_LEN = 256  # = PREFILL_SEQ_LEN * 2
STAGES = 4  # prefill pp stages

qeff_model = QEFFAutoModelForCausalLM.from_pretrained(MODEL_ID)
PREFILL_QPC_PATH = qeff_model.compile(
    prefill_seq_len=PREFILL_SEQ_LEN,
    ctx_len=CTX_LEN,
    num_cores=16,
    mxfp6_matmul=True,
    mxint8_kv_cache=True,
    num_devices=8,  # TS=2 num_devices/STAGES
    split_retained_state_io=True,
    mos=1,
    aic_enable_depth_first=True,
    num_speculative_tokens=None,
    prefill_only=True,
    enable_chunking=True,
    stages=STAGES,
    # use_onnx_subfunctions=True,
)

DECODE_QPC_PATH = qeff_model.compile(
    prefill_seq_len=1,
    ctx_len=CTX_LEN,
    num_cores=16,
    mxfp6_matmul=True,
    mxint8_kv_cache=True,
    num_devices=2,
    split_retained_state_io=True,
    mos=1,
    aic_enable_depth_first=True,
    num_speculative_tokens=None,
    offload_pt_weights=False,
    retain_full_kv=True,
)

PROMPT = """
Explain quantum computing in simple terms.
"""

# ─────────────────────────────────────────────────────────────────────────────
# Load tokenizer and model config
# ─────────────────────────────────────────────────────────────────────────────
print("Loading tokenizer and config …")
config = AutoConfig.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
num_layers = config.num_hidden_layers
print(f" num_hidden_layers = {num_layers}")

# ─────────────────────────────────────────────────────────────────────────────
# Tokenise and build chunked inputs
# ─────────────────────────────────────────────────────────────────────────────
raw_inputs = tokenizer(PROMPT, return_tensors="np", padding=True)
# Tokens left in the context window after the prompt.
generation_len = CTX_LEN - int(raw_inputs["attention_mask"].sum(1, keepdims=True).max())

# Round the prompt length up to a whole number of prefill chunks.
padded_len = raw_inputs["input_ids"].shape[1]
num_chunks = math.ceil(padded_len / PREFILL_SEQ_LEN)  # ceil divide
padded_len = num_chunks * PREFILL_SEQ_LEN

inputs = tokenizer(
    PROMPT,
    return_tensors="np",
    padding="max_length",
    max_length=padded_len,
)
# Real tokens get their position; padding positions are marked with -1.
inputs["position_ids"] = np.where(
    inputs.pop("attention_mask"),
    np.arange(padded_len),
    -1,
)
inputs.pop("token_type_ids", None)
inputs.pop("past_key_values", None)

print(f" prompt tokens = {padded_len} num_chunks = {num_chunks} generation_len = {generation_len}")

# Sessions are created inside the try block so that whichever of them exists
# is deactivated in the finally clause, even when inference raises.
prefill_session = None
decode_session = None
try:
    # ─────────────────────────────────────────────────────────────────────────
    # Load sessions
    # ─────────────────────────────────────────────────────────────────────────
    print("\nLoading prefill session (cluster_id='prefill', stages) …")
    prefill_session = QAICInferenceSession(
        qpc_path=PREFILL_QPC_PATH,
        full_batch_size=1,
        device_ids=[0, 1, 2, 3, 4, 5, 6, 7],
        cluster_id="prefill",
        stages=STAGES,
    )
    print(" prefill_session loaded ✓")
    print(f" prefill exec-obj pool size : {prefill_session.prefill_num_execObj}")
    print(f" kv_shape : {prefill_session.kv_shape} kv_dtype : {prefill_session.kv_size}")

    print("\nLoading decode session (cluster_id='decode') …")
    decode_session = QAICInferenceSession(
        qpc_path=DECODE_QPC_PATH,
        full_batch_size=1,
        device_ids=[8, 9],  # 2-device MQ decode
        cluster_id="decode",
    )
    print(" decode_session loaded ✓")

    # ─────────────────────────────────────────────────────────────────────────
    # Allocate shared KV cache buffers
    #
    # One numpy array per (layer, key/value) pair.
    # These arrays are shared between prefill (written via DMA slice) and decode
    # (read as direct inputs) — no copy at the boundary.
    # ─────────────────────────────────────────────────────────────────────────
    kv_cache: list[np.ndarray] = [
        np.zeros(prefill_session.kv_shape, dtype=prefill_session.kv_size)
        for _ in prefill_session.prefill_buff_map[:-1]  # KV entries only, no logits
    ]
    # Logits output buffer for the last prefill chunk (written in-place by the QPC)
    logits_buf = np.zeros((1, 1, config.vocab_size), dtype=np.float32)

    print(f"\nAllocated {len(kv_cache)} shared KV buffers shape={prefill_session.kv_shape}")

    # ─────────────────────────────────────────────────────────────────────────
    # Chunked prefill
    #
    # Non-last chunks: np_run(is_prefill=True)
    # Last chunk: np_run_pipeline(last_chunk=True, kv_cache_buffers=kv_cache)
    #             set_data_for_kv_handoff() wires RetainedState outputs into
    #             kv_cache via setDataWithSlices before enqueue — zero-copy.
    # ─────────────────────────────────────────────────────────────────────────
    print("\n── Chunked prefill ──")

    for chunk_idx in range(num_chunks):
        start = chunk_idx * PREFILL_SEQ_LEN
        end = start + PREFILL_SEQ_LEN
        is_last = chunk_idx == num_chunks - 1

        chunk_inputs = {
            "input_ids": inputs["input_ids"][:, start:end],
            "position_ids": inputs["position_ids"][:, start:end],
            "batch_index": np.array([[0]], dtype=np.int64),
        }
        if is_last:
            chunk_inputs["logits"] = logits_buf

        t0 = time.perf_counter()

        if is_last:
            # Last chunk: wire KV outputs into shared kv_cache via DMA slice,
            # then enqueue. No numpy copy of KV at the prefill→decode boundary.
            exec_idx = prefill_session.np_run_pipeline(
                inputs=chunk_inputs,
                last_chunk=True,
                kv_cache_buffers=kv_cache,
            )
        else:
            # Non-last chunks:
            # KV RetainedState outputs are skipped (not needed mid-pipeline).
            exec_idx = prefill_session.np_run(
                inputs=chunk_inputs,
                is_prefill=True,
            )

        prefill_session.complete_inf(exec_idx, is_prefill=True)
        elapsed = time.perf_counter() - t0

        print(f" chunk {chunk_idx + 1}/{num_chunks} last={is_last} time={elapsed * 1000:.1f} ms")

    # After the last chunk, logits_buf has been written in-place by the QPC
    prefill_logits = logits_buf
    first_token = int(np.argmax(prefill_logits))
    print(f"\n Prefill done. First generated token id = {first_token}")

    # ─────────────────────────────────────────────────────────────────────────
    # Decode — shared DMA KV buffer design
    #
    # kv_cache is the single shared numpy buffer across all decode steps:
    #   - prefill wrote into it via setDataWithSlices (zero-copy DMA slice)
    #   - each decode step reads KV from it as input (decode_buff_map INPUT slots)
    #   - each decode step writes updated KV back into it via set_data_for_kv_handoff
    #     (setDataWithSlices on RetainedState OUTPUT slots → same kv_cache arrays)
    #
    # decode_inputs[kv_name] always points at kv_cache[i] — never reassigned.
    # After complete_inf, kv_cache is updated in-place with the new token's KV.
    # ─────────────────────────────────────────────────────────────────────────
    print("\n── Decode ──")

    next_pos = int(np.max(inputs["position_ids"])) + 1

    decode_inputs: dict[str, np.ndarray] = {
        "input_ids": np.array([[first_token]], dtype=np.int64),
        "position_ids": np.array([[next_pos]], dtype=np.int64),
        "logits": logits_buf,
    }
    for (kv_name, _), kv_buf in zip(decode_session.decode_buff_map, kv_cache):
        decode_inputs[kv_name] = kv_buf  # same kv_cache array, no copy

    all_tokens = [first_token]

    def run_decode_step():
        """Run one decode step on decode_session and return its outputs."""
        # Wire RetainedState OUTPUT slots → kv_cache via setDataWithSlices.
        # Runtime DMA-writes updated KV directly into kv_cache after inference.
        # decode_rs_buff_map holds OUTPUT binding indices.
        decode_session.set_data_for_kv_handoff(
            kv_cache,
            [("batch_index", 0), ("ctx_start", 0)],
            decode_session.decode_execObj_idx,
            decode_session.decode_rs_buff_map,
        )
        exec_idx = decode_session.np_run(decode_inputs, is_prefill=False)
        decode_session.complete_inf(exec_idx, is_prefill=False)
        return decode_session.get_outputs(exec_idx)

    t0 = time.perf_counter()
    dec_outputs = run_decode_step()
    print(f" First decode step time={(time.perf_counter() - t0) * 1000:.1f} ms")

    next_token = int(np.argmax(dec_outputs["logits"]))
    all_tokens.append(next_token)
    next_pos += 1

    # ─────────────────────────────────────────────────────────────────────────
    # Decode loop — kv_cache updated in-place each step, no KV copy or reassignment
    # ─────────────────────────────────────────────────────────────────────────
    # Two tokens already exist (one from prefill, one from the first decode
    # step); clamp so the reported step count is never negative.
    decode_steps = max(generation_len - 2, 0)
    t_loop_start = time.perf_counter()

    for _ in range(decode_steps):
        decode_inputs["input_ids"] = np.array([[next_token]], dtype=np.int64)
        decode_inputs["position_ids"] = np.array([[next_pos]], dtype=np.int64)

        dec_outputs = run_decode_step()

        next_token = int(np.argmax(dec_outputs["logits"]))
        all_tokens.append(next_token)
        next_pos += 1

    t_loop_end = time.perf_counter()

    # ─────────────────────────────────────────────────────────────────────────
    # Results
    # ─────────────────────────────────────────────────────────────────────────
    tok_per_sec = decode_steps / (t_loop_end - t_loop_start) if decode_steps > 0 else 0.0

    print(f"\n decode tok/sec = {tok_per_sec:.2f} ({decode_steps} steps)")
    print(f"\ninput\n{PROMPT}\noutput\n{tokenizer.decode(all_tokens)}")
finally:
    # Same order as before (prefill first, then decode); skip a session that
    # was never created.
    for session in (prefill_session, decode_session):
        if session is not None:
            session.deactivate()