Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions examples/models/llama/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, attention: Attention):
def __init__(
self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
):
"""
Transformer block with support for pre-norm and post-norm.
Args:
args (ModelArgs): model configuration parameters.
attention (Attention): attention object to use in the transformer
block. See `attention.py` for types of attention. Make sure
the attention type is registered in the ATTENTION_REGISTRY.
mlp_type (str): MLP type for this layer. "default" for standard
FFN, "skip" for no FFN block.
"""
super().__init__()
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.head_dim
self.attention = attention
self.mlp_type = mlp_type.lower()

assert (
args.hidden_dim is not None
), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
if args.moe:
if self.mlp_type == "skip":
pass # No FFN block for this layer
elif args.moe:
self.block_sparse_moe = MOEFeedForward(args)
elif args.target_modules is not None and (
"down_proj" in args.target_modules
Expand All @@ -136,11 +143,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
self.ffn_norm = RMSNorm(
args.dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
if self.mlp_type != "skip":
self.ffn_norm = RMSNorm(
args.dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)

@classmethod
def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
Expand All @@ -156,9 +164,12 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
f"Unknown attention type: {args.attention_type}. "
f"Available: {list(ATTENTION_REGISTRY.keys())}"
)
mlp_type = "default"
if args.mlp_type is not None and layer_id < len(args.mlp_type):
mlp_type = args.mlp_type[layer_id]
cls = ATTENTION_REGISTRY[args.attention_type]
attention = cls(args, layer_id, rope, **args.attention_kwargs)
return TransformerBlock(args, attention)
return TransformerBlock(args, attention, mlp_type=mlp_type)

def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
h, attn_options_update = self.attention(
Expand All @@ -167,7 +178,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
if not isinstance(self.attention, AttentionSkip):
h = x + h

if hasattr(self, "block_sparse_moe"):
if self.mlp_type == "skip":
out = h
elif hasattr(self, "block_sparse_moe"):
out = h + self.block_sparse_moe(self.ffn_norm(h))
else:
out = h + self.feed_forward(self.ffn_norm(h))
Expand Down
3 changes: 3 additions & 0 deletions examples/models/llama/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class ModelArgs:
attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
# Hybrid models can have layer types different from attention
layer_types: Optional[list] = None
# Per-layer MLP type: "default" for standard FFN, "skip" for no FFN block.
# Indexed by layer id (e.g. mlp_type[0] applies to layer 0).
mlp_type: Optional[list] = None
model_architecture: Optional[str] = (
None # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
)
Expand Down
Loading