From 8a28ceb84f705f57c1c6397faa061f103bc569a1 Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Mon, 20 Apr 2026 23:45:55 -0700 Subject: [PATCH 1/6] Added MDP generation to QEff Compile Signed-off-by: Mohit Mehta --- QEfficient/base/modeling_qeff.py | 52 ++++-- QEfficient/compile/mdp_generator.py | 249 ++++++++++++++++++++++++++++ 2 files changed, 291 insertions(+), 10 deletions(-) create mode 100644 QEfficient/compile/mdp_generator.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index e9213761d9..2655dbc45f 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -490,6 +490,7 @@ def _compile( specializations: Optional[List[Dict[str, int]]] = None, custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, + mdp_num_partitions: int = 1, num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, @@ -513,6 +514,11 @@ def _compile( :specializations (list): List of specializations to compile for :custom_io (dict): Custom IO to specify the input and outputs in different formats than default :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. + :mdp_num_partitions (int): Number of pipeline-parallel partitions for disaggregated prefill serving. + When > 1, the ONNX graph is read directly to generate a fully-populated MDP partition + config (nodeList per partition) without requiring a compiler round-trip. + Ignored when ``mdp_load_partition_config`` is already provided in compiler_options. + Defaults to 1 (template / tensor-slice MDP, existing behaviour). :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` :qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. 
``Defaults to None.`` @@ -578,22 +584,47 @@ def _compile( + [f"-m={onnx_path}"] ) - # MDP partition config: prioritize dump over load - mdp_dump_json_path = compiler_options.pop("mdp_dump_partition_config", None) + # MDP partition config selection (three priorities, highest first): + # 1. User explicitly provides a pre-built MDP JSON to load. + # 2. Disaggregated (pipeline-parallel) MDP — generate from ONNX topsort. + # 3. Template (tensor-slice) MDP — single partition, nodeList absent. mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None) + # Silently discard any stale mdp_dump_partition_config key that callers + # may still pass; the compiler-round-trip dump path is no longer supported. + compiler_options.pop("mdp_dump_partition_config", None) mdp_ts_json = None - if mdp_dump_json_path: - if mdp_ts_json_path: - logger.warning( - "Loading and Dumping partition is not supported at the same time. Prioritizing dump config over load config!" - ) - command.append(f"-mdp-dump-partition-config={mdp_dump_json_path}") - elif mdp_ts_json_path: + if mdp_ts_json_path: command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") mdp_ts_json = load_json(str(mdp_ts_json_path)) + elif mdp_num_partitions > 1: + # Disaggregated (pipeline-parallel) MDP: generate a fully-populated + # nodeList per partition directly from the ONNX graph — no compiler + # round-trip required. + from QEfficient.compile.mdp_generator import generate_disagg_mdp_partition_config + + num_layers = getattr(self, "num_layers", None) + if num_layers is None: + raise AttributeError("Model does not expose 'num_layers'. 
Cannot generate disagg MDP partition config.") + num_cores = compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) + logger.info( + f"Generating disagg MDP partition config from ONNX: " + f"num_devices={mdp_ts_num_devices}, num_partitions={mdp_num_partitions}, " + f"num_layers={num_layers}, num_cores={num_cores}" + ) + mdp_ts_json = generate_disagg_mdp_partition_config( + onnx_path=str(onnx_path), + num_devices=mdp_ts_num_devices, + num_partitions=mdp_num_partitions, + num_layers=num_layers, + num_cores=num_cores, + ) + mdp_ts_json_path = compile_dir / f"mdp_disagg_{mdp_ts_num_devices}d_{mdp_num_partitions}p.json" + create_json(str(mdp_ts_json_path), mdp_ts_json) + command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") elif mdp_ts_num_devices > 1: - # Generate mdp config only if neither dump nor load is provided and num_devices > 1 + # Template (tensor-slice) MDP: single partition, empty nodeList. + # Used when PP is disabled (stages=1). Compiler fills the nodeList. mdp_ts_json = generate_mdp_partition_config( mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) ) @@ -618,6 +649,7 @@ def _compile( "specializations": specializations, "custom_io": custom_io, "mdp_ts_num_devices": mdp_ts_num_devices, + "mdp_num_partitions": mdp_num_partitions, "mdp_ts_json": mdp_ts_json, "num_speculative_tokens": num_speculative_tokens, "prefill_only": prefill_only, diff --git a/QEfficient/compile/mdp_generator.py b/QEfficient/compile/mdp_generator.py new file mode 100644 index 0000000000..58a9ffe71c --- /dev/null +++ b/QEfficient/compile/mdp_generator.py @@ -0,0 +1,249 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
"""MDP generator for disaggregated prefill serving (PP-enabled, TS-enabled, stages>1)."""

import logging
from typing import Any, Dict, List, Optional, Set, Tuple

logger = logging.getLogger(__name__)

# Op types the compiler never constant-folds (ProtobufLoader.cpp:68).
_NEVER_FOLD = frozenset({"Loop", "Const", "Identity", "If", "DequantizeLinear"})

# Op types registered via DEFINEKNOWNCUSTOMOP in ONNXModelLoaderCustomOp.cpp.
# The compiler loads these as a single named node (not expanded/inlined).
_KNOWN_CUSTOM_OPS = frozenset({"CustomRMSNorm", "CastToUInt4"})

# Local functions with at least this many nodes are not inlined by the compiler.
_INLINE_NODE_LIMIT = 100

# Name-component prefixes that are followed by a transformer layer index.
_LAYER_PREFIXES = ("layers.", "h.", "layer_")


def _get_compiler_folded_nodes(graph) -> Set[str]:
    """Return node names the compiler will fold away during ONNX import.

    Mirrors computeIsConstantFoldable() in ONNXModelLoader.cpp: a node is
    foldable if every one of its inputs is a compile-time constant (initializer,
    Constant op output, or output of another foldable node). Folded nodes are
    absent from the compiler IR, so including them in nodeList is harmless but
    excluding them produces a cleaner MDP closer to the compiler dump.

    Nodes whose op type is in ``_NEVER_FOLD`` and nodes without inputs are never
    reported. Unnamed nodes cannot be reported either (there is no name to
    return), but they still propagate constness to their consumers.

    Args:
        graph: ONNX GraphProto-like object exposing ``initializer`` and ``node``.

    Returns:
        Set of names of the named nodes considered constant-foldable.
    """
    # Tensor names whose value is known at compile time, seeded with the
    # initializers (model weights / constants) and every Constant op output.
    const_values: Set[str] = {init.name for init in graph.initializer}
    for node in graph.node:
        if node.op_type == "Constant":
            const_values.update(out for out in node.output if out)

    foldable_nodes: Set[str] = set()
    pending = [node for node in graph.node if node.op_type not in _NEVER_FOLD and node.input]

    # Fixed point: keep sweeping the undecided nodes until a sweep resolves none.
    while pending:
        undecided = []
        for node in pending:
            if all(inp in const_values for inp in node.input if inp):
                if node.name:
                    foldable_nodes.add(node.name)
                const_values.update(out for out in node.output if out)
            else:
                undecided.append(node)
        if len(undecided) == len(pending):
            break
        pending = undecided

    return foldable_nodes


def _get_layer_num(node_name: str) -> Optional[int]:
    """Return transformer layer index from node name, or None.

    Supports:
      - layers.N  (Llama/Mistral/Qwen/Gemma/Granite)
      - h.N       (GPT-2)
      - layer_N// prefix (subfunctions/merged ONNX where each subgraph is
                  prefixed with the layer index it belongs to, e.g.
                  "layer_3//model/embed_tokens/Gather")

    The first "/"-separated component that carries a valid index wins.
    """
    for part in node_name.split("/"):
        for prefix in _LAYER_PREFIXES:
            if part.startswith(prefix):
                suffix = part[len(prefix) :]
                # isdecimal() (unlike isdigit()) only accepts what int() can parse.
                if suffix.isdecimal():
                    return int(suffix)
                break
    return None


def _get_inlined_node_map(model) -> Tuple[Dict[str, List[str]], Set[str]]:
    """Classify ONNX local functions and build inlined sub-node names.

    The compiler inlines a local function body into the parent graph during
    ONNX import if it has < 100 nodes AND is not a known custom op
    (ONNXModelLoaderSubFuns.cpp). Inlined call-sites do not appear in the
    compiler IR; their sub-nodes are named <call-site name>/<function node name>.
    Known custom ops (registered via DEFINEKNOWNCUSTOMOP) keep their
    call-site name in the IR and must be included in nodeList as-is.

    Args:
        model: ONNX ModelProto-like object exposing ``functions`` and ``graph.node``.

    Returns:
        inlined_node_map:  dict mapping call-site name -> list of inlined
                           sub-node names (unnamed function nodes are skipped).
        non_inlined_funcs: set of function names that are NOT inlined
                           (known custom ops or >= 100 nodes); their
                           call-site names are valid nodeList entries.
    """
    local_functions = {f.name: f for f in model.functions}
    logger.info("Found %d local function types: %s", len(local_functions), set(local_functions))

    inlined_funcs: Set[str] = set()
    non_inlined_funcs: Set[str] = set()
    for func_name, func in local_functions.items():
        if func_name in _KNOWN_CUSTOM_OPS or len(func.node) >= _INLINE_NODE_LIMIT:
            non_inlined_funcs.add(func_name)
            logger.info("  %s: not inlined", func_name)
        else:
            inlined_funcs.add(func_name)
            logger.info("  %s: %d nodes, will inline", func_name, len(func.node))

    inlined_node_map: Dict[str, List[str]] = {}
    for node in model.graph.node:
        if node.op_type in inlined_funcs:
            body = local_functions[node.op_type].node
            inlined_node_map[node.name] = [f"{node.name}/{fn.name}" for fn in body if fn.name]

    logger.info("Inlined sub-nodes mapped for %d call-sites", len(inlined_node_map))
    return inlined_node_map, non_inlined_funcs


def _verify_topological_order(graph) -> None:
    """Raise ValueError if a node consumes a tensor whose producer appears later.

    Graph inputs and initializers are excluded because no node produces them.
    Inputs that no node in ``graph`` produces at all are ignored.
    """
    external_names: Set[str] = {inp.name for inp in graph.input}
    external_names.update(init.name for init in graph.initializer)

    output_to_node: Dict[str, str] = {}
    for node in graph.node:
        for out in node.output:
            if out:  # "" marks optional unused outputs
                output_to_node[out] = node.name

    seen_outputs: Set[str] = set()
    for node in graph.node:
        for inp in node.input:
            if not inp or inp in external_names:
                continue
            if inp in output_to_node and inp not in seen_outputs:
                raise ValueError(
                    f"ONNX graph has a cycle or violates topological order: "
                    f"node '{node.name}' consumes '{inp}' produced by "
                    f"'{output_to_node[inp]}', but that producer has not appeared yet."
                )
        seen_outputs.update(out for out in node.output if out)


def _assign_nodes_to_partitions(
    nodes,
    folded_nodes: Set[str],
    inlined_functions: Set[str],
    inlined_node_map: Dict[str, List[str]],
    num_partitions: int,
    layers_per_partition: int,
) -> List[List[str]]:
    """Distribute node names over ``num_partitions`` lists by layer index.

    Nodes that carry a layer index go to ``layer // layers_per_partition``
    (clamped to the last partition). Nodes without one follow the most recent
    layer-indexed node in iteration order, or partition 0 before any layer has
    been seen. Inlined call-sites are not listed themselves; their sub-nodes
    are appended after the regular nodes of the partition chosen for the
    call-site at its own position in the graph.
    """
    last_partition = num_partitions - 1
    partitions: List[List[str]] = [[] for _ in range(num_partitions)]
    call_site_partition: Dict[str, int] = {}
    current_partition = 0

    for node in nodes:
        if not node.name:
            continue

        layer_num = _get_layer_num(node.name)
        if layer_num is None:
            partition_idx = current_partition
        else:
            partition_idx = min(layer_num // layers_per_partition, last_partition)

        if node.op_type in inlined_functions:
            # Remember where the call-site sits so that its sub-nodes do not
            # all end up in whichever partition happens to be processed last.
            call_site_partition.setdefault(node.name, partition_idx)
            continue
        if node.name in folded_nodes:
            continue

        current_partition = partition_idx
        partitions[partition_idx].append(node.name)

    for call_site_name, inlined_nodes in inlined_node_map.items():
        partition_idx = call_site_partition.get(call_site_name, current_partition)
        partitions[partition_idx].extend(inlined_nodes)

    return partitions


def generate_disagg_mdp_partition_config(
    onnx_path: str,
    num_devices: int,
    num_partitions: int,
    num_layers: int,
    num_cores: int = 16,
) -> Dict[str, Any]:
    """Generate a pipeline-partitioned MDP config from an exported ONNX graph.

    Assigns nodes to partitions by transformer layer index. Non-layer nodes
    (embeddings, lm_head) follow the nearest layer in topological order.
    nodeList is a superset of the compiler dump; the compiler silently ignores
    optimized-away names. Inlined local function call-sites (CtxScatterCB,
    CtxGatherCB) are excluded; their sub-nodes are assigned automatically.
    Known custom ops (CustomRMSNorm) are included by call-site name.

    For PP+TS: num_devices // num_partitions devices per partition; the
    compiler applies tensor-slicing within each stage.

    Args:
        onnx_path:      Path to the exported ONNX file.
        num_devices:    Total devices (num_partitions * ts_per_stage).
        num_partitions: Number of pipeline stages.
        num_layers:     Number of transformer layers.
        num_cores:      NSP cores per device (default 16).

    Returns:
        dict with keys 'connections' and 'partitions'.

    Raises:
        ValueError: If ``num_partitions`` is not positive, exceeds
            ``num_devices`` or exceeds ``num_layers``, or if the ONNX graph is
            not topologically ordered.
    """
    if num_partitions <= 0:
        raise ValueError(f"Invalid number of partitions: {num_partitions}")

    if num_partitions > num_devices:
        raise ValueError(
            f"Num of partitions should be <= number of devices. Found {num_partitions} partitions and {num_devices} devices"
        )

    # Fewer layers than stages would make layers_per_partition zero and the
    # layer -> partition division below fail with ZeroDivisionError.
    if num_layers < num_partitions:
        raise ValueError(
            f"Num of layers should be >= number of partitions. Found {num_layers} layers and {num_partitions} partitions"
        )

    if num_devices % num_partitions:
        logger.warning(
            "num_devices (%d) is not a multiple of num_partitions (%d); the last %d device(s) are not assigned to any partition",
            num_devices,
            num_partitions,
            num_devices % num_partitions,
        )

    # Imported here so that argument validation and the pure-Python helpers of
    # this module do not require onnx to be importable.
    import onnx

    layers_per_partition = num_layers // num_partitions
    model = onnx.load(onnx_path, load_external_data=False)

    # Verify topological order (ONNX spec §3.3). Fails loudly on malformed exports.
    _verify_topological_order(model.graph)

    logger.info("Computing constant-foldable nodes...")
    folded_nodes = _get_compiler_folded_nodes(model.graph)
    logger.info("Found %d compiler-folded nodes (excluded from nodeList)", len(folded_nodes))

    inlined_node_map, non_inlined_functions = _get_inlined_node_map(model)
    inlined_functions = {f.name for f in model.functions} - non_inlined_functions

    partitions = _assign_nodes_to_partitions(
        model.graph.node,
        folded_nodes,
        inlined_functions,
        inlined_node_map,
        num_partitions,
        layers_per_partition,
    )

    for i, partition in enumerate(partitions):
        logger.info("Partition %d: %d nodes", i, len(partition))
        if not partition:
            logger.warning("Partition %d has an empty nodeList; check that node names carry layer indices", i)
    logger.info("Total nodes in MDP: %d", sum(len(p) for p in partitions))

    # PP-only: 1 device/partition; PP+TS: num_devices//num_partitions devices/partition.
    device_ids = list(range(num_devices))
    devices_per_partition = num_devices // num_partitions
    partition_objs = []
    for i, node_list in enumerate(partitions):
        assigned_devices = device_ids[i * devices_per_partition : (i + 1) * devices_per_partition]
        partition_objs.append(
            {
                "name": f"Partition{i}",
                "nodeList": node_list,
                "devices": [{"deviceId": dev_id, "numCores": num_cores} for dev_id in assigned_devices],
            }
        )

    return {
        "connections": [{"devices": device_ids, "type": "p2p"}],
        "partitions": partition_objs,
    }
foldable_nodes: Set[str] = set() - while(True): + while True: changed = False for node in graph.node: if not node.name or node.name in foldable_nodes: @@ -111,9 +112,7 @@ def _get_inlined_node_map(model) -> tuple: for node in model.graph.node: if node.op_type in inlined_funcs: func = local_functions[node.op_type] - inlined_node_map[node.name] = [ - f"{node.name}/{fn.name}" for fn in func.node if fn.name - ] + inlined_node_map[node.name] = [f"{node.name}/{fn.name}" for fn in func.node if fn.name] logger.info(f"Inlined sub-nodes mapped for {len(inlined_node_map)} call-sites") return inlined_node_map, non_inlined_funcs From 6bef372eee11357e38c2cefb502a46fff1eff51c Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Wed, 22 Apr 2026 01:41:18 -0700 Subject: [PATCH 3/6] Add compiler options - 'stages' Signed-off-by: Mohit Mehta --- QEfficient/base/modeling_qeff.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 2655dbc45f..2610beee7b 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -597,7 +597,8 @@ def _compile( if mdp_ts_json_path: command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") mdp_ts_json = load_json(str(mdp_ts_json_path)) - elif mdp_num_partitions > 1: + elif mdp_num_partitions > 1 or "stages" in compiler_options: + mdp_num_partitions = compiler_options.pop("stages", mdp_num_partitions) # Disaggregated (pipeline-parallel) MDP: generate a fully-populated # nodeList per partition directly from the ONNX graph — no compiler # round-trip required. 
From f924a1d865fdaaeac73953f276f8ef8b1a977985 Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Mon, 25 May 2026 01:49:26 -0700 Subject: [PATCH 4/6] Added support for layerwise export Signed-off-by: Mohit Mehta --- QEfficient/compile/mdp_generator.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/QEfficient/compile/mdp_generator.py b/QEfficient/compile/mdp_generator.py index cbc0f606ec..53faa64132 100644 --- a/QEfficient/compile/mdp_generator.py +++ b/QEfficient/compile/mdp_generator.py @@ -61,7 +61,12 @@ def _get_compiler_folded_nodes(graph) -> Set[str]: def _get_layer_num(node_name: str) -> Optional[int]: """Return transformer layer index from node name, or None. - Supports layers.N (Llama/Mistral/Qwen/Gemma/Granite) and h.N (GPT-2). + Supports: + - layers.N (Llama/Mistral/Qwen/Gemma/Granite) + - h.N (GPT-2) + - layer_N// prefix (subfunctions/merged ONNX where each subgraph is + prefixed with the layer index it belongs to, e.g. + "layer_3//model/embed_tokens/Gather") """ for part in node_name.split("/"): if part.startswith("layers."): @@ -72,6 +77,10 @@ def _get_layer_num(node_name: str) -> Optional[int]: suffix = part[len("h.") :] if suffix.isdigit(): return int(suffix) + elif part.startswith("layer_"): + suffix = part[len("layer_") :] + if suffix.isdigit(): + return int(suffix) return None @@ -92,8 +101,14 @@ def _get_inlined_node_map(model) -> tuple: (known custom ops or >= 100 nodes); their call-site names are valid nodeList entries. """ - # Registered with DEFINEKNOWNCUSTOMOP in ONNXModelLoader.cpp - _KNOWN_CUSTOM_OPS = frozenset({"CustomRMSNorm"}) + # Op types registered via DEFINEKNOWNCUSTOMOP in ONNXModelLoaderCustomOp.cpp. + # The compiler loads these as a single named node (not expanded/inlined). 
+ _KNOWN_CUSTOM_OPS = frozenset( + { + "CustomRMSNorm", + "CastToUInt4", + } + ) local_functions = {f.name: f for f in model.functions} logger.info(f"Found {len(local_functions)} local function types: {set(local_functions.keys())}") @@ -195,7 +210,7 @@ def generate_disagg_mdp_partition_config( max_layer_seen = -1 for node in model.graph.node: - if not node.name.startswith("/"): + if not node.name: continue if node.name in folded_nodes: continue From 5b9fb71fb010bc662f23cb58e91c987dc6978072 Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Fri, 29 May 2026 02:06:18 -0700 Subject: [PATCH 5/6] Minor Changes Addressed Rishin Comments Signed-off-by: Mohit Mehta --- QEfficient/base/modeling_qeff.py | 10 +++------- QEfficient/compile/mdp_generator.py | 9 ++++++++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index a1c66cc360..6116ae723e 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -26,6 +26,7 @@ ) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.blocking.blocking_configurator import build_transformer_blocking_config_for_transform +from QEfficient.compile.mdp_generator import generate_disagg_mdp_partition_config from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.models.pytorch_transforms import ( @@ -516,7 +517,7 @@ def _compile( specializations: Optional[List[Dict[str, int]]] = None, custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, - mdp_num_partitions: int = 1, + mdp_num_partitions: Optional[int] = 1, num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, @@ -613,20 +614,15 @@ def _compile( # 2. Disaggregated (pipeline-parallel) MDP — generate from ONNX topsort. # 3. 
Template (tensor-slice) MDP — single partition, nodeList absent. mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None) - # Silently discard any stale mdp_dump_partition_config key that callers - # may still pass; the compiler-round-trip dump path is no longer supported. - compiler_options.pop("mdp_dump_partition_config", None) mdp_ts_json = None if mdp_ts_json_path: command.append(f"-mdp-load-partition-config={mdp_ts_json_path}") mdp_ts_json = load_json(str(mdp_ts_json_path)) - elif mdp_num_partitions > 1 or "stages" in compiler_options: - mdp_num_partitions = compiler_options.pop("stages", mdp_num_partitions) + elif mdp_num_partitions > 1: # Disaggregated (pipeline-parallel) MDP: generate a fully-populated # nodeList per partition directly from the ONNX graph — no compiler # round-trip required. - from QEfficient.compile.mdp_generator import generate_disagg_mdp_partition_config num_layers = getattr(self, "num_layers", None) if num_layers is None: diff --git a/QEfficient/compile/mdp_generator.py b/QEfficient/compile/mdp_generator.py index 53faa64132..1d6de9da09 100644 --- a/QEfficient/compile/mdp_generator.py +++ b/QEfficient/compile/mdp_generator.py @@ -162,7 +162,14 @@ def generate_disagg_mdp_partition_config( Returns: dict with keys 'connections' and 'partitions'. """ - assert num_partitions <= num_devices, f"num_partitions ({num_partitions}) must be <= num_devices ({num_devices})" + + if num_partitions <= 0: + raise ValueError(f"Invalid number of partitions: {num_partitions}") + + if num_partitions > num_devices: + raise ValueError( + f"Num of partitions should be <= number of devices. 
Found {num_partitions} partitions and {num_devices} devices" + ) layers_per_partition = num_layers // num_partitions model = onnx.load(onnx_path, load_external_data=False) From e176948a5abcaebe81c8be62551730565f59dd09 Mon Sep 17 00:00:00 2001 From: Mohit Mehta Date: Tue, 2 Jun 2026 02:12:18 -0700 Subject: [PATCH 6/6] Add support for VLMs Signed-off-by: Mohit Mehta --- QEfficient/base/modeling_qeff.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6116ae723e..170020e396 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -625,8 +625,12 @@ def _compile( # round-trip required. num_layers = getattr(self, "num_layers", None) + if getattr(self, "model", None) and getattr(self.model, "language_model", None) and not num_layers: + num_layers = getattr(self.model.language_model.config, "num_hidden_layers", None) if num_layers is None: - raise AttributeError("Model does not expose 'num_layers'. Cannot generate disagg MDP partition config.") + raise AttributeError( + "Model or Language Model does not expose 'num_layers' or 'num_hidden_layers' respectively. Cannot generate disagg MDP partition config." + ) num_cores = compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES) logger.info( f"Generating disagg MDP partition config from ONNX: "