Add _QuantGemma4TextExperts plugin for fused 3D MoE expert quantization #1173

@marioiseli89

Description

Problem

Gemma 4 MoE models (e.g., google/gemma-4-26B-A4B-it) store expert weights as fused 3D nn.Parameter tensors:

# In Gemma4TextExperts:
self.gate_up_proj = nn.Parameter(torch.empty(num_experts, 2 * intermediate_dim, hidden_dim))
self.down_proj = nn.Parameter(torch.empty(num_experts, hidden_dim, intermediate_dim))

This differs from Mixtral/Qwen/Nemotron, which store experts as an nn.ModuleList of nn.Linear layers. Because modelopt's quantizer only discovers nn.Linear modules, it silently skips 91% of the model's weights (the experts make up the bulk of a MoE model).

Running hf_ptq.py --qformat nvfp4 on Gemma 4 26B-A4B therefore produces a 46GB output instead of the expected ~16GB: the expert tensors remain in BF16.
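A quick way to see this coverage gap is to measure what fraction of a module's parameters live inside nn.Linear submodules. The sketch below uses toy dimensions, not Gemma's real sizes, and a hypothetical linear_coverage helper:

```python
import torch
import torch.nn as nn

def linear_coverage(model: nn.Module) -> float:
    """Fraction of parameters that live inside nn.Linear modules
    (the only modules the quantizer discovers by default)."""
    total = sum(p.numel() for p in model.parameters())
    in_linear = sum(
        p.numel()
        for m in model.modules()
        if isinstance(m, nn.Linear)
        for p in m.parameters(recurse=False)
    )
    return in_linear / total

# Toy stand-in for a fused-expert block; dims are illustrative, not Gemma's.
class FusedExperts(nn.Module):
    def __init__(self, num_experts=8, hidden=16, inter=32):
        super().__init__()
        self.gate_up_proj = nn.Parameter(torch.empty(num_experts, 2 * inter, hidden))
        self.down_proj = nn.Parameter(torch.empty(num_experts, hidden, inter))
        self.router = nn.Linear(hidden, num_experts, bias=False)

block = FusedExperts()
print(f"{linear_coverage(block):.1%} of weights visible to the quantizer")  # prints 1.0%
```

Only the router's 128 weights are discoverable; the 12,288 fused expert weights are invisible, which mirrors the 91% gap seen on the real model.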

Solution

We wrote a _QuantGemma4TextExperts plugin following the same pattern as _QuantQwen35MoeExperts (line 784 of plugins/huggingface.py). It:

  1. Registers as a QuantModule for Gemma4TextExperts
  2. Unfuses the 3D gate_up_proj and down_proj tensors into per-expert gate_proj, up_proj, and down_proj nn.Linear layers
  3. Provides forward(), __getitem__, __len__, __iter__ matching the original expert routing logic
  4. modelopt then quantizes each expert as a standard nn.Linear
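As a sanity check on the slicing convention in step 2 (gate weights occupy the first intermediate_dim rows of gate_up_proj, up weights the rest), a minimal equivalence check with toy sizes:

```python
import torch

# Toy sizes, not Gemma's; verifies that unfusing preserves the math.
num_experts, hidden, inter = 4, 8, 16
gate_up = torch.randn(num_experts, 2 * inter, hidden)

x = torch.randn(5, hidden)          # a batch of token activations
e = 2                               # pick one expert

fused = x @ gate_up[e].T            # (5, 2*inter): gate and up stacked
gate_w, up_w = gate_up[e, :inter, :], gate_up[e, inter:, :]
unfused = torch.cat([x @ gate_w.T, x @ up_w.T], dim=-1)

assert torch.allclose(fused, unfused)  # slicing preserves the fused result
```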

Plugin code

class _Gemma4ExpertModule(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)


class _QuantGemma4TextExperts(QuantModule):
    def _setup(self):
        from accelerate import init_empty_weights

        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

        def _copy_weight(module, weight):
            # Materialize the meta-device module, then copy the weight slice in.
            module.to_empty(device=device)
            with torch.no_grad():
                module.weight.data = weight.detach().to(dtype=dtype, device=device)

        expert_dim = self.intermediate_dim
        # Allocate the per-expert Linear layers on the meta device first so we
        # never hold two full copies of the expert weights at once.
        with init_empty_weights():
            expert_modules = nn.ModuleList([
                _Gemma4ExpertModule(self.hidden_dim, expert_dim)
                for _ in range(self.num_experts)
            ])

        for idx in range(self.num_experts):
            # gate_up_proj stacks gate (rows [:expert_dim]) over up (rows [expert_dim:]).
            _copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :expert_dim, :])
            _copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, expert_dim:, :])
            _copy_weight(expert_modules[idx].down_proj, self.down_proj[idx])

        # Drop the fused 3D parameters and register the unfused experts as
        # numbered submodules so the quantizer discovers them as nn.Linear.
        delattr(self, "gate_up_proj")
        delattr(self, "down_proj")
        for idx in range(self.num_experts):
            self.add_module(str(idx), expert_modules[idx])

    def __len__(self):
        return self.num_experts

    def __iter__(self):
        for idx in range(self.num_experts):
            yield getattr(self, str(idx))

    def __getitem__(self, idx):
        return getattr(self, str(int(idx)))

    def forward(self, hidden_states, top_k_index, top_k_weights):
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # (num_experts, top_k, num_tokens) mask of which tokens hit which expert.
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            if expert_idx == self.num_experts:
                continue
            with torch.no_grad():
                top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            expert = self[expert_idx]
            # Standard gated-MLP expert: act(gate(x)) * up(x), then down-projection.
            gate = expert.gate_proj(current_state)
            up = expert.up_proj(current_state)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = expert.down_proj(current_hidden_states)
            # Scale by the router weight for this (token, slot) pair and accumulate.
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
        return final_hidden_states

Registration

from transformers.models.gemma4.modeling_gemma4 import Gemma4TextExperts
QuantModuleRegistry.register({Gemma4TextExperts: "hf.Gemma4TextExperts"})(_QuantGemma4TextExperts)

Additional notes

  • Vision encoder should be excluded from quantization (set the `*vision*` and `*embed_vision*` patterns to `enable: False` in the quant config)
  • Post-export key renaming needed: modelopt exports experts.E.proj.weight but vLLM expects moe.experts.E.proj.weight
  • Calibration: 4096 samples with natural expert routing (forced moe_calib_experts_ratio degrades quality)
  • Requires transformers >= 5.4 for Gemma 4 architecture support
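The key-renaming note above can be sketched as a post-export pass over the state dict. The exact key layout (`experts.<E>.<proj>.weight` nested under each layer) is an assumption here, as is the remap_expert_keys helper name:

```python
import re

def remap_expert_keys(state_dict: dict) -> dict:
    """Insert the 'moe.' prefix that vLLM expects in front of each
    'experts.<E>.' segment. Assumes keys have not already been remapped."""
    pattern = re.compile(r"(^|\.)experts\.(\d+)\.")
    return {
        pattern.sub(r"\1moe.experts.\2.", key): value
        for key, value in state_dict.items()
    }

remapped = remap_expert_keys(
    {"model.layers.0.mlp.experts.3.gate_proj.weight": "..."}
)
print(list(remapped))  # ['model.layers.0.mlp.moe.experts.3.gate_proj.weight']
```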
