Problem
Gemma 4 MoE models (e.g., `google/gemma-4-26B-A4B-it`) store expert weights as fused 3D `nn.Parameter` tensors:

```python
# In Gemma4TextExperts:
self.gate_up_proj = nn.Parameter(torch.empty(num_experts, 2 * intermediate_dim, hidden_dim))
self.down_proj = nn.Parameter(torch.empty(num_experts, hidden_dim, intermediate_dim))
```
This differs from Mixtral/Qwen/Nemotron, which use an `nn.ModuleList` of `nn.Linear` layers. Since modelopt's quantizer only discovers `nn.Linear` modules, it silently skips 91% of the model's weights (the experts are the bulk of a MoE model).

Running `hf_ptq.py --qformat nvfp4` on Gemma 4 26B-A4B produces a 46 GB output instead of the expected ~16 GB: the expert tensors remain in BF16.
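The coverage gap can be reproduced on a toy module as a quick diagnostic: count how many parameters live inside `nn.Linear` modules (what a Linear-only quantizer sees) versus the total. The classes below are hypothetical stand-ins, not the real Gemma code:

```python
import torch
import torch.nn as nn

class ToyFusedExperts(nn.Module):
    """Stand-in for fused 3D expert weights: invisible to Linear-only discovery."""
    def __init__(self, num_experts=8, hidden_dim=16, intermediate_dim=32):
        super().__init__()
        self.gate_up_proj = nn.Parameter(torch.zeros(num_experts, 2 * intermediate_dim, hidden_dim))
        self.down_proj = nn.Parameter(torch.zeros(num_experts, hidden_dim, intermediate_dim))

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(16, 16, bias=False)  # discoverable by the quantizer
        self.experts = ToyFusedExperts()           # silently skipped

def linear_coverage(model: nn.Module) -> float:
    """Fraction of parameters owned by nn.Linear modules."""
    in_linear = sum(
        p.numel()
        for m in model.modules() if isinstance(m, nn.Linear)
        for p in m.parameters(recurse=False)
    )
    total = sum(p.numel() for p in model.parameters())
    return in_linear / total

print(f"quantizable fraction: {linear_coverage(ToyModel()):.1%}")  # 2.0% on this toy
```

On the real 26B-A4B checkpoint the same kind of walk is what surfaces the ~91% blind spot.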
Solution
We wrote a `_QuantGemma4TextExperts` plugin following the same pattern as `_QuantQwen35MoeExperts` (line 784 of `plugins/huggingface.py`). It:

- Registers as a `QuantModule` for `Gemma4TextExperts`
- Unfuses the 3D `gate_up_proj` and `down_proj` into individual `nn.Linear` layers (gate, up, and down projections) for each of the 128 experts
- Provides `forward()`, `__getitem__`, `__len__`, and `__iter__` matching the original expert routing logic
- modelopt then quantizes each expert as a standard `nn.Linear`
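Why unfusing works can be shown on a minimal toy (hypothetical names, not the actual plugin): once each expert's projections are registered as named child modules via `add_module`, a plain module walk finds them as `nn.Linear` instances.

```python
import torch.nn as nn

class ToyUnfusedExperts(nn.Module):
    """Register each expert's projections as named children so that
    isinstance-based module discovery can see them."""
    def __init__(self, num_experts=4, hidden_dim=8, intermediate_dim=16):
        super().__init__()
        self.num_experts = num_experts
        for idx in range(num_experts):
            expert = nn.Module()
            expert.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
            expert.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
            expert.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)
            self.add_module(str(idx), expert)  # same trick as the plugin

    def __getitem__(self, idx):
        return getattr(self, str(int(idx)))

    def __len__(self):
        return self.num_experts

linears = [m for m in ToyUnfusedExperts().modules() if isinstance(m, nn.Linear)]
print(len(linears))  # 4 experts x 3 projections = 12
```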
Plugin code
```python
class _Gemma4ExpertModule(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)


class _QuantGemma4TextExperts(QuantModule):
    def _setup(self):
        from accelerate import init_empty_weights

        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

        def _copy_weight(module, weight):
            module.to_empty(device=device)
            with torch.no_grad():
                module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

        expert_dim = self.intermediate_dim
        with init_empty_weights():
            expert_modules = nn.ModuleList([
                _Gemma4ExpertModule(self.hidden_dim, expert_dim)
                for _ in range(self.num_experts)
            ])
        for idx in range(self.num_experts):
            # First expert_dim rows of the fused tensor are the gate, the rest are up
            _copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :expert_dim, :])
            _copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, expert_dim:, :])
            _copy_weight(expert_modules[idx].down_proj, self.down_proj[idx])
        # Drop the fused 3D parameters so only the unfused Linear layers remain
        delattr(self, "gate_up_proj")
        delattr(self, "down_proj")
        for idx in range(self.num_experts):
            self.add_module(str(idx), expert_modules[idx])

    def __len__(self):
        return self.num_experts

    def __iter__(self):
        for idx in range(self.num_experts):
            yield getattr(self, str(idx))

    def __getitem__(self, idx):
        return getattr(self, str(int(idx)))

    def forward(self, hidden_states, top_k_index, top_k_weights):
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            if expert_idx == self.num_experts:
                continue
            with torch.no_grad():
                top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            expert = self[expert_idx]
            gate = expert.gate_proj(current_state)
            up = expert.up_proj(current_state)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = expert.down_proj(current_hidden_states)
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
        return final_hidden_states
```
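The slicing convention in `_setup` (gate in the first `intermediate_dim` rows of the fused tensor, up in the rest) can be sanity-checked in isolation. This standalone sketch assumes the fused path computes `x @ gate_up_proj[e].T` per expert, which matches the `(num_experts, 2 * intermediate_dim, hidden_dim)` shape from the Problem section:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_experts, hidden_dim, intermediate_dim = 4, 8, 16

gate_up_proj = torch.randn(num_experts, 2 * intermediate_dim, hidden_dim)
x = torch.randn(5, hidden_dim)

# Fused projection for expert 0 under the assumed convention
fused = x @ gate_up_proj[0].T  # shape (5, 2 * intermediate_dim)

# Unfused Linear layers, sliced exactly as the plugin does
gate = nn.Linear(hidden_dim, intermediate_dim, bias=False)
up = nn.Linear(hidden_dim, intermediate_dim, bias=False)
with torch.no_grad():
    gate.weight.copy_(gate_up_proj[0, :intermediate_dim, :])
    up.weight.copy_(gate_up_proj[0, intermediate_dim:, :])

assert torch.allclose(fused[:, :intermediate_dim], gate(x), atol=1e-6)
assert torch.allclose(fused[:, intermediate_dim:], up(x), atol=1e-6)
print("gate/up slicing matches the fused projection")
```

If the two asserts hold, the unfused experts are numerically equivalent to the fused path before quantization.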
Registration
```python
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextExperts

QuantModuleRegistry.register({Gemma4TextExperts: "hf.Gemma4TextExperts"})(_QuantGemma4TextExperts)
```
Results

The quantized checkpoint serves in vLLM with `--quantization modelopt --moe-backend marlin`.
Additional notes

- Vision encoder should be excluded from quantization (`*vision*`, `*embed_vision*` → `enable: False`)
- Post-export key renaming is needed: modelopt exports `experts.E.proj.weight`, but vLLM expects `moe.experts.E.proj.weight`
- Calibration: 4096 samples with natural expert routing (forcing `moe_calib_experts_ratio` degrades quality)
- Requires `transformers >= 5.4` for Gemma 4 architecture support
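The post-export renaming noted above can be done with a small state-dict pass. `rename_expert_keys` is a hypothetical helper, and the exact prefix to insert may differ by exporter version:

```python
import re

def rename_expert_keys(state_dict):
    """Insert 'moe.' before 'experts.<idx>.' so keys match what vLLM expects.

    Hypothetical sketch of the post-export fix-up; adjust the pattern if the
    exported prefix differs.
    """
    renamed = {}
    for key, value in state_dict.items():
        # Negative lookbehind avoids double-prefixing already-correct keys.
        renamed[re.sub(r"(?<!moe\.)\bexperts\.(\d+)\.", r"moe.experts.\1.", key)] = value
    return renamed

before = {"model.layers.0.experts.3.gate_proj.weight": "w"}
print(rename_expert_keys(before))
# {'model.layers.0.moe.experts.3.gate_proj.weight': 'w'}
```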
Environment