diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md
index 1cc1acfbf9..83a89ad37d 100755
--- a/examples/llm_ptq/README.md
+++ b/examples/llm_ptq/README.md
@@ -109,6 +109,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
| Gemma 3 | ✅2 | - | ✅ | - | - |
| QWen 2, 2.5 4 | ✅ | ✅ | ✅ | ✅ | ✅ |
| QWen3, 3.5 MOE, Next 6 | ✅ | - | - | - | ✅ |
+| QWen3.5 6 | ✅ | - | ✅ | - | - |
| QwQ | ✅ | - | - | - | ✅ |
| DeepSeek V3, R1, V3.1, V3.27 | - | - | - | - | ✅ |
| GLM-4.78 | ✅ | - | - | - | ✅ |
diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index c2d4d4bfca..ac07295220 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -209,6 +209,7 @@ def build_quant_cfg(
model_type,
moe_calib_experts_ratio: float | None = None,
) -> dict[str, Any]:
+ """Build quantization config with model-specific overrides for AWQ, SmoothQuant, and VLM."""
quant_cfg = copy.deepcopy(quant_cfg)
if "awq" in str(quant_cfg.get("algorithm")):
from modelopt.torch.quantization.config import find_quant_cfg_entry_by_path
@@ -252,6 +253,17 @@ def build_quant_cfg(
quant_cfg["quant_cfg"].append({"quantizer_name": "*image*", "enable": False})
quant_cfg["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False})
+ if model_type == "qwen3_5":
+        # GatedDeltaNet's in_proj_b and in_proj_a have very narrow output dimensions
+        # (hidden_size -> num_v_heads, e.g. 1024 -> 16); quantizing them causes accuracy loss.
+ quant_cfg["quant_cfg"].append({"quantizer_name": "*in_proj_b*", "enable": False})
+ quant_cfg["quant_cfg"].append({"quantizer_name": "*in_proj_a*", "enable": False})
+ # TRT-LLM's Qwen3.5 linear-attention packing only supports weight/weight_scale_inv;
+ # disable activation quantization so input_scale is not exported for these layers.
+ quant_cfg["quant_cfg"].append(
+ {"quantizer_name": "*linear_attn*input_quantizer", "enable": False}
+ )
+
return quant_cfg
diff --git a/examples/vlm_ptq/README.md b/examples/vlm_ptq/README.md
index 8b9c31aa42..2a4a16f1e2 100644
--- a/examples/vlm_ptq/README.md
+++ b/examples/vlm_ptq/README.md
@@ -38,6 +38,7 @@ Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#getting-started) fo
| VILA | ✅ | ✅ | ✅ | ✅ | - |
| Phi-3-vision, Phi-4-multimodal | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen2, 2.5-VL | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Qwen3.5 | ✅ | - | ✅ | - | - |
| Gemma3 | ✅ | - | - | - | - |
> *1.Only TensorRT-LLM checkpoint export is supported. Not compatible with the TensorRT-LLM torch backend* \
diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py
index 3bd72d9de9..ceea524c88 100755
--- a/modelopt/torch/export/model_utils.py
+++ b/modelopt/torch/export/model_utils.py
@@ -29,6 +29,8 @@
"MPT": "mpt",
"Bloom": "bloom",
"ChatGLM": "chatglm",
+ "Qwen3_5Moe": "qwen3_5moe",
+ "Qwen3_5": "qwen3_5",
"Qwen3Moe": "qwen3moe",
"Qwen3Next": "qwen3next",
"QWen": "qwen",
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 4ceb51cd2c..8b3c7612fc 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -1221,7 +1221,7 @@ def _update_svdquant(modules, new_pre_quant_scale):
# Mathematical equivalence:
# Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
# After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
- (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
+ (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP", "Qwen3_5MLP"], ("up_proj", "down_proj")),
]
diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py
index 8fe2f68b32..4a338ca67e 100644
--- a/tests/_test_utils/torch/transformers_models.py
+++ b/tests/_test_utils/torch/transformers_models.py
@@ -35,6 +35,11 @@
T5ForConditionalGeneration,
)
+try:
+ from transformers import Qwen3_5TextConfig
+except ImportError:
+ Qwen3_5TextConfig = None
+
import modelopt.torch.opt as mto
SEED = 1234
@@ -107,6 +112,7 @@ def get_tiny_qwen3_moe(**config_kwargs) -> PreTrainedModel:
def create_tiny_qwen3_moe_dir(
tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs
) -> Path:
+ """Save a tiny Qwen3 MoE model (and optional tokenizer) to a temp directory."""
qwen3_moe_dir = Path(tmp_path) / "tiny_qwen3_moe"
if with_tokenizer:
tokenizer = AutoTokenizer.from_pretrained(
@@ -117,9 +123,42 @@ def create_tiny_qwen3_moe_dir(
get_tiny_qwen3_moe(**config_kwargs).save_pretrained(qwen3_moe_dir)
return qwen3_moe_dir
+##### Qwen3.5 (hybrid linear attention + full attention) #####
+def get_tiny_qwen3_5(**config_kwargs) -> PreTrainedModel:
+ """Create a tiny Qwen3.5 model with hybrid GatedDeltaNet + full attention layers for testing."""
+ if Qwen3_5TextConfig is None:
+ pytest.skip("Qwen3_5TextConfig not available (requires transformers >= 4.57)")
+
+ set_seed(SEED)
+
+ kwargs = {
+ "dtype": torch.bfloat16,
+ "hidden_size": 32,
+ "intermediate_size": 32,
+ "num_hidden_layers": 4,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 1,
+ "head_dim": 16,
+ "linear_num_key_heads": 4,
+ "linear_num_value_heads": 4,
+ "linear_key_head_dim": 8,
+ "linear_value_head_dim": 8,
+ "linear_conv_kernel_dim": 4,
+ "full_attention_interval": 4,
+ "attn_output_gate": True,
+ "max_position_embeddings": 32,
+ "vocab_size": 32,
+ "rms_norm_eps": 1e-6,
+ }
+ kwargs.update(**config_kwargs)
+ tiny_qwen3_5 = AutoModelForCausalLM.from_config(Qwen3_5TextConfig(**kwargs))
+
+ return tiny_qwen3_5
+
##### GPT-OSS #####
def get_tiny_gpt_oss(**config_kwargs) -> PreTrainedModel:
+ """Create a tiny GPT-OSS MoE model for testing."""
set_seed(SEED)
kwargs = {
diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py
index 692ab07d4a..326781e723 100644
--- a/tests/unit/torch/quantization/plugins/test_huggingface.py
+++ b/tests/unit/torch/quantization/plugins/test_huggingface.py
@@ -24,6 +24,7 @@
create_tiny_llama_dir,
get_tiny_gpt_oss,
get_tiny_llama,
+ get_tiny_qwen3_5,
get_tiny_qwen3_moe,
tf_modelopt_state_and_output_tester,
)
@@ -235,6 +236,7 @@ def test_is_homogeneous_hf_model_gpt_oss():
def test_hf_decoder_discoverer_registration_path():
+ """Verify HF decoder layer discoverer is registered and returns correct layers."""
model = get_tiny_llama()
assert any(
is_supported is is_homogeneous_hf_model and discoverer is get_homogeneous_hf_decoder_layers
@@ -243,3 +245,73 @@ def test_hf_decoder_discoverer_registration_path():
assert LayerActivationCollector.get_decoder_layers(model) is get_homogeneous_hf_decoder_layers(
model
)
+
+
+@pytest.mark.parametrize(
+ "quant_config",
+ [mtq.FP8_DEFAULT_CFG, mtq.INT4_AWQ_CFG],
+ ids=["fp8", "int4_awq"],
+)
+def test_qwen3_5_hybrid_attention_quantize(quant_config):
+ """Verify FP8 and AWQ quantization works for Qwen3.5 hybrid (GatedDeltaNet + Attention)."""
+ import copy
+
+ model = get_tiny_qwen3_5()
+
+ quant_cfg = copy.deepcopy(quant_config)
+ if quant_config is mtq.INT4_AWQ_CFG:
+ for entry in quant_cfg["quant_cfg"]:
+ if entry["quantizer_name"] == "*weight_quantizer":
+ entry.setdefault("cfg", {})["block_sizes"] = {-1: 16}
+ break
+
+ # Disable narrow GatedDeltaNet projections (same as example_utils does for qwen3_5)
+ quant_cfg["quant_cfg"].append({"quantizer_name": "*in_proj_b*", "enable": False})
+ quant_cfg["quant_cfg"].append({"quantizer_name": "*in_proj_a*", "enable": False})
+ # Disable activation quantization for linear attention (TRT-LLM packing limitation)
+ quant_cfg["quant_cfg"].append(
+ {"quantizer_name": "*linear_attn*input_quantizer", "enable": False}
+ )
+
+ def calib_fn(model):
+ """Run calibration forward passes with dummy inputs."""
+ x = model.dummy_inputs["input_ids"]
+ for _ in range(2):
+ model(x)
+
+ mtq.quantize(model, quant_cfg, calib_fn)
+
+ # Verify the model still produces output
+ with torch.no_grad():
+ out = model(model.dummy_inputs["input_ids"])
+ assert out.logits is not None
+
+ # Verify both GatedDeltaNet and Attention linear layers got quantized
+ has_gdn_quantized = False
+ has_attn_quantized = False
+ for name, module in model.named_modules():
+ if hasattr(module, "weight_quantizer") and hasattr(module, "weight"):
+ if "linear_attn.in_proj_qkv" in name and module.weight_quantizer.is_enabled:
+ has_gdn_quantized = True
+ if "self_attn.q_proj" in name and module.weight_quantizer.is_enabled:
+ has_attn_quantized = True
+ assert has_gdn_quantized, "GatedDeltaNet linear layers should be quantized"
+ assert has_attn_quantized, "Attention linear layers should be quantized"
+
+ # Verify narrow projections are NOT quantized
+ for name, module in model.named_modules():
+ if "in_proj_b" in name and hasattr(module, "weight_quantizer"):
+ assert not module.weight_quantizer.is_enabled, (
+ f"in_proj_b should have quantization disabled: {name}"
+ )
+ if "in_proj_a" in name and hasattr(module, "weight_quantizer"):
+ assert not module.weight_quantizer.is_enabled, (
+ f"in_proj_a should have quantization disabled: {name}"
+ )
+
+ # Verify linear_attn input quantizers are disabled (no input_scale for TRT-LLM export)
+ for name, module in model.named_modules():
+ if "linear_attn" in name and hasattr(module, "input_quantizer"):
+ assert not module.input_quantizer.is_enabled, (
+ f"linear_attn input_quantizer should be disabled: {name}"
+ )