
Commit 04faf26

navsud authored and facebook-github-bot committed
Add per-layer MLP type support for executorch export (#18856)
Summary:
Add per-layer MLP type support to the ExecuTorch export path. This allows hybrid models to configure FFN blocks per layer (e.g. skip the FFN on specified layers), reducing model size and inference latency. The per-layer config uses an mlp_type list in ModelArgs, where each layer can be set to "default" (standard FFN) or "skip" (no FFN block). This is extensible to future MLP types.

- Add mlp_type field to ModelArgs (model_args.py): an optional list of strings, one per layer
- Update TransformerBlock.__init__ to accept an mlp_type string and skip FFN/ffn_norm creation when mlp_type == "skip" (llama_transformer.py)
- Update TransformerBlock.from_type() to read mlp_type from ModelArgs per layer
- Update TransformerBlock.forward() to pass the attention output through directly when mlp_type == "skip"

Reviewed By: ifed-ucsd

Differential Revision: D100682545
1 parent fe71bd4 commit 04faf26

2 files changed

Lines changed: 25 additions & 9 deletions


examples/models/llama/llama_transformer.py

Lines changed: 22 additions & 9 deletions
@@ -98,26 +98,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs, attention: Attention):
+    def __init__(
+        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+    ):
         """
         Transformer block with support for pre-norm and post-norm.
         Args:
             args (ModelArgs): model configuration parameters.
             attention (Attention): attention object to use in the transformer
                 block. See `attention.py` for types of attention. Make sure
                 the attention type is registered in the ATTENTION_REGISTRY.
+            mlp_type (str): MLP type for this layer. "default" for standard
+                FFN, "skip" for no FFN block.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
         self.attention = attention
+        self.mlp_type = mlp_type.lower()
 
         assert (
             args.hidden_dim is not None
         ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
-        if args.moe:
+        if self.mlp_type == "skip":
+            pass  # No FFN block for this layer
+        elif args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         elif args.target_modules is not None and (
             "down_proj" in args.target_modules
@@ -136,11 +143,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
             eps=args.norm_eps,
             add_unit_offset=args.rms_norm_add_unit_offset,
         )
-        self.ffn_norm = RMSNorm(
-            args.dim,
-            eps=args.norm_eps,
-            add_unit_offset=args.rms_norm_add_unit_offset,
-        )
+        if self.mlp_type != "skip":
+            self.ffn_norm = RMSNorm(
+                args.dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )
 
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
@@ -156,9 +164,12 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
                 f"Unknown attention type: {args.attention_type}. "
                 f"Available: {list(ATTENTION_REGISTRY.keys())}"
             )
+        mlp_type = "default"
+        if args.mlp_type is not None and layer_id < len(args.mlp_type):
+            mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention)
+        return TransformerBlock(args, attention, mlp_type=mlp_type)
 
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention(
@@ -167,7 +178,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
         if not isinstance(self.attention, AttentionSkip):
             h = x + h
 
-        if hasattr(self, "block_sparse_moe"):
+        if self.mlp_type == "skip":
+            out = h
+        elif hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
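
For context, a minimal standalone sketch of the dispatch pattern the block now follows. ToyBlock and its plain Linear FFN are hypothetical stand-ins rather than the real TransformerBlock; only the mlp_type control flow mirrors the diff above:

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical block illustrating the per-layer mlp_type dispatch."""

    def __init__(self, dim: int, mlp_type: str = "default"):
        super().__init__()
        self.mlp_type = mlp_type.lower()
        if self.mlp_type != "skip":
            # Only build the FFN path (and its norm) when the layer keeps an FFN.
            self.ffn_norm = nn.LayerNorm(dim)
            self.feed_forward = nn.Linear(dim, dim)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        if self.mlp_type == "skip":
            return h  # attention output passes straight through
        return h + self.feed_forward(self.ffn_norm(h))


h = torch.randn(1, 4, 8)
assert torch.equal(ToyBlock(8, mlp_type="skip")(h), h)

A "skip" layer therefore allocates no FFN parameters and performs no FFN compute, which is where the model-size and latency savings come from.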

examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
@@ -145,6 +145,9 @@ class ModelArgs:
     attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     # Hybrid models can have layer types different from attention
     layer_types: Optional[list] = None
+    # Per-layer MLP type: "default" for standard FFN, "skip" for no FFN block.
+    # Indexed by layer id (e.g. mlp_type[0] applies to layer 0).
+    mlp_type: Optional[list] = None
     model_architecture: Optional[str] = (
         None  # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
     )
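
As a rough usage sketch: the values below are illustrative, the import path is guessed from the file path shown above, and fields other than mlp_type are assumed to exist on ModelArgs. A hybrid model that drops the FFN on its last two layers might be configured like this:

from examples.models.llama.model_args import ModelArgs  # path assumed from the diff above

# Hypothetical 4-layer config: layers 0-1 keep the standard FFN, layers 2-3 skip it.
args = ModelArgs(
    dim=512,
    n_layers=4,
    n_heads=8,
    hidden_dim=2048,
    mlp_type=["default", "default", "skip", "skip"],
)

TransformerBlock.from_type() falls back to "default" for any layer index not covered by the list (or when mlp_type is None), so shorter lists and existing configs keep the current behavior.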
