Skip to content

Commit a341ca0

Browse files
navsud authored and facebook-github-bot committed
Add per-layer MLP type support for on-device ANE export
Summary: Add per-layer MLP type support to the ExecuTorch export path. This allows hybrid models to configure FFN blocks per layer (e.g. skip FFN on specified layers), reducing model size and inference latency. The per-layer config uses an mlp_type list in ModelArgs, where each layer can be set to "default" (standard FFN) or "skip" (no FFN block). This is extensible to future MLP types. - Add mlp_type field to ModelArgs (model_args.py) — optional list of strings, one per layer - Update TransformerBlock.__init__ to accept mlp_type string and skip FFN/ffn_norm creation when mlp_type == "skip" (llama_transformer.py) - Update TransformerBlock.from_type() to read mlp_type from ModelArgs per layer - Update TransformerBlock.forward() to pass through attention output directly when mlp_type == "skip" Differential Revision: D100682545
1 parent fe71bd4 commit a341ca0

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

examples/models/llama/llama_transformer.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,26 +98,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
9898

9999

100100
class TransformerBlock(nn.Module):
101-
def __init__(self, args: ModelArgs, attention: Attention):
101+
def __init__(self, args: ModelArgs, attention: Attention, mlp_type: str = "default"):
102102
"""
103103
Transformer block with support for pre-norm and post-norm.
104104
Args:
105105
args (ModelArgs): model configuration parameters.
106106
attention (Attention): attention object to use in the transformer
107107
block. See `attention.py` for types of attention. Make sure
108108
the attention type is registered in the ATTENTION_REGISTRY.
109+
mlp_type (str): MLP type for this layer. "default" for standard
110+
FFN, "skip" for no FFN block.
109111
"""
110112
super().__init__()
111113
self.use_kv_cache = args.use_kv_cache
112114
self.n_heads = args.n_heads
113115
self.dim = args.dim
114116
self.head_dim = args.head_dim
115117
self.attention = attention
118+
self.mlp_type = mlp_type
116119

117120
assert (
118121
args.hidden_dim is not None
119122
), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
120-
if args.moe:
123+
if mlp_type == "skip":
124+
pass # No FFN block for this layer
125+
elif args.moe:
121126
self.block_sparse_moe = MOEFeedForward(args)
122127
elif args.target_modules is not None and (
123128
"down_proj" in args.target_modules
@@ -136,11 +141,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
136141
eps=args.norm_eps,
137142
add_unit_offset=args.rms_norm_add_unit_offset,
138143
)
139-
self.ffn_norm = RMSNorm(
140-
args.dim,
141-
eps=args.norm_eps,
142-
add_unit_offset=args.rms_norm_add_unit_offset,
143-
)
144+
if mlp_type != "skip":
145+
self.ffn_norm = RMSNorm(
146+
args.dim,
147+
eps=args.norm_eps,
148+
add_unit_offset=args.rms_norm_add_unit_offset,
149+
)
144150

145151
@classmethod
146152
def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
@@ -156,9 +162,12 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
156162
f"Unknown attention type: {args.attention_type}. "
157163
f"Available: {list(ATTENTION_REGISTRY.keys())}"
158164
)
165+
mlp_type = "default"
166+
if args.mlp_type is not None and layer_id < len(args.mlp_type):
167+
mlp_type = args.mlp_type[layer_id]
159168
cls = ATTENTION_REGISTRY[args.attention_type]
160169
attention = cls(args, layer_id, rope, **args.attention_kwargs)
161-
return TransformerBlock(args, attention)
170+
return TransformerBlock(args, attention, mlp_type=mlp_type)
162171

163172
def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
164173
h, attn_options_update = self.attention(
@@ -167,7 +176,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
167176
if not isinstance(self.attention, AttentionSkip):
168177
h = x + h
169178

170-
if hasattr(self, "block_sparse_moe"):
179+
if self.mlp_type == "skip":
180+
out = h
181+
elif hasattr(self, "block_sparse_moe"):
171182
out = h + self.block_sparse_moe(self.ffn_norm(h))
172183
else:
173184
out = h + self.feed_forward(self.ffn_norm(h))

examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ class ModelArgs:
145145
attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
146146
# Hybrid models can have layer types different from attention
147147
layer_types: Optional[list] = None
148+
# Per-layer MLP type: "default" for standard FFN, "skip" for no FFN block.
149+
# Indexed by layer id (e.g. mlp_type[0] applies to layer 0).
150+
mlp_type: Optional[list] = None
148151
model_architecture: Optional[str] = (
149152
None # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
150153
)

0 commit comments

Comments (0)