From 9ce06a5e175042ddfd5ed960f4c484e6633350a5 Mon Sep 17 00:00:00 2001
From: Naveen Suda
Date: Tue, 14 Apr 2026 21:48:17 -0700
Subject: [PATCH] Add per-layer MLP type support for executorch export (#18856)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Add per-layer MLP type support to the ExecuTorch export path. This allows hybrid models to configure FFN blocks per layer (e.g. skip FFN on specified layers), reducing model size and inference latency.

The per-layer config uses an mlp_type list in ModelArgs, where each layer can be set to "default" (standard FFN) or "skip" (no FFN block). This is extensible to future MLP types.

- Add mlp_type field to ModelArgs (model_args.py) — optional list of strings, one per layer
- Update TransformerBlock.__init__ to accept mlp_type string and skip FFN/ffn_norm creation when mlp_type == "skip" (llama_transformer.py)
- Update TransformerBlock.from_type() to read mlp_type from ModelArgs per layer
- Update TransformerBlock.forward() to pass through attention output directly when mlp_type == "skip"

Reviewed By: ifed-ucsd

Differential Revision: D100682545
---
 examples/models/llama/llama_transformer.py | 31 +++++++++++++++-------
 examples/models/llama/model_args.py        |  3 +++
 2 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index cb87995aaf6..e74ae810a02 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -98,7 +98,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs, attention: Attention):
+    def __init__(
+        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+    ):
         """
         Transformer block with support for pre-norm and post-norm.
         Args:
@@ -106,6 +108,8 @@ def __init__(self, args: ModelArgs, attention: Attention):
             attention (Attention): attention object to use in the transformer
                 block. See `attention.py` for types of attention. Make sure
                 the attention type is registered in the ATTENTION_REGISTRY.
+            mlp_type (str): MLP type for this layer. "default" for standard
+                FFN, "skip" for no FFN block.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
@@ -113,11 +117,14 @@ def __init__(self, args: ModelArgs, attention: Attention):
         self.dim = args.dim
         self.head_dim = args.head_dim
         self.attention = attention
+        self.mlp_type = mlp_type.lower()
 
         assert (
             args.hidden_dim is not None
         ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
-        if args.moe:
+        if self.mlp_type == "skip":
+            pass  # No FFN block for this layer
+        elif args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         elif args.target_modules is not None and (
             "down_proj" in args.target_modules
@@ -136,11 +143,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
             eps=args.norm_eps,
             add_unit_offset=args.rms_norm_add_unit_offset,
         )
-        self.ffn_norm = RMSNorm(
-            args.dim,
-            eps=args.norm_eps,
-            add_unit_offset=args.rms_norm_add_unit_offset,
-        )
+        if self.mlp_type != "skip":
+            self.ffn_norm = RMSNorm(
+                args.dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )
 
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
@@ -156,9 +164,12 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
                 f"Unknown attention type: {args.attention_type}. "
                 f"Available: {list(ATTENTION_REGISTRY.keys())}"
             )
+        mlp_type = "default"
+        if args.mlp_type is not None and layer_id < len(args.mlp_type):
+            mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention)
+        return TransformerBlock(args, attention, mlp_type=mlp_type)
 
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention(
@@ -167,7 +178,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
         if not isinstance(self.attention, AttentionSkip):
             h = x + h
 
-        if hasattr(self, "block_sparse_moe"):
+        if self.mlp_type == "skip":
+            out = h
+        elif hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 402c6c39750..104e9fe2ddd 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -145,6 +145,9 @@ class ModelArgs:
     attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     # Hybrid models can have layer types different from attention
     layer_types: Optional[list] = None
+    # Per-layer MLP type: "default" for standard FFN, "skip" for no FFN block.
+    # Indexed by layer id (e.g. mlp_type[0] applies to layer 0).
+    mlp_type: Optional[list] = None
     model_architecture: Optional[str] = (
         None  # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
     )