diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index cb87995aaf6..e74ae810a02 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -98,7 +98,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs, attention: Attention): + def __init__( + self, args: ModelArgs, attention: Attention, mlp_type: str = "default" + ): """ Transformer block with support for pre-norm and post-norm. Args: @@ -106,6 +108,8 @@ def __init__(self, args: ModelArgs, attention: Attention): attention (Attention): attention object to use in the transformer block. See `attention.py` for types of attention. Make sure the attention type is registered in the ATTENTION_REGISTRY. + mlp_type (str): MLP type for this layer. "default" for standard + FFN, "skip" for no FFN block. """ super().__init__() self.use_kv_cache = args.use_kv_cache @@ -113,11 +117,14 @@ def __init__(self, args: ModelArgs, attention: Attention): self.dim = args.dim self.head_dim = args.head_dim self.attention = attention + self.mlp_type = mlp_type.lower() assert ( args.hidden_dim is not None ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock." - if args.moe: + if self.mlp_type == "skip": + pass # No FFN block for this layer + elif args.moe: self.block_sparse_moe = MOEFeedForward(args) elif args.target_modules is not None and ( "down_proj" in args.target_modules @@ -136,11 +143,12 @@ def __init__(self, args: ModelArgs, attention: Attention): eps=args.norm_eps, add_unit_offset=args.rms_norm_add_unit_offset, ) - self.ffn_norm = RMSNorm( - args.dim, - eps=args.norm_eps, - add_unit_offset=args.rms_norm_add_unit_offset, - ) + if self.mlp_type != "skip": + self.ffn_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) @classmethod def from_type(cls, layer_id, args, rope) -> "TransformerBlock": @@ -156,9 +164,12 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock": f"Unknown attention type: {args.attention_type}. " f"Available: {list(ATTENTION_REGISTRY.keys())}" ) + mlp_type = "default" + if args.mlp_type is not None and layer_id < len(args.mlp_type): + mlp_type = args.mlp_type[layer_id] cls = ATTENTION_REGISTRY[args.attention_type] attention = cls(args, layer_id, rope, **args.attention_kwargs) - return TransformerBlock(args, attention) + return TransformerBlock(args, attention, mlp_type=mlp_type) def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN h, attn_options_update = self.attention( @@ -167,7 +178,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: if not isinstance(self.attention, AttentionSkip): h = x + h - if hasattr(self, "block_sparse_moe"): + if self.mlp_type == "skip": + out = h + elif hasattr(self, "block_sparse_moe"): out = h + self.block_sparse_moe(self.ffn_norm(h)) else: out = h + self.feed_forward(self.ffn_norm(h)) diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 402c6c39750..104e9fe2ddd 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -145,6 +145,9 @@ class ModelArgs: attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) # Hybrid models can have layer types different from attention layer_types: Optional[list] = None + # Per-layer MLP type: "default" for standard FFN, "skip" for no FFN block. + # Indexed by layer id (e.g. mlp_type[0] applies to layer 0). + mlp_type: Optional[list] = None model_architecture: Optional[str] = ( None # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now. )