
Commit 35590b6

coderabbit

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>

1 parent c9153ed commit 35590b6

4 files changed: 19 additions & 8 deletions

examples/speculative_decoding/main.py

Lines changed: 2 additions & 2 deletions
@@ -268,10 +268,10 @@ def train():
         mtsp.convert(model, [("medusa", config)])
     elif training_args.mode == "eagle3":
         # Validate and rewrite eagle config fields
-        EagleConfig.model_validate(
+        eagle_cfg = EagleConfig.model_validate(
             eagle_cfg,
             context={"training_args": training_args, "data_args": data_args},
-        )
+        ).model_dump()
         mtsp.convert(model, [("eagle", eagle_cfg)])

     # Load draft vocab cache if the draft model uses a compressed vocabulary
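For context, this hunk switches from validate-only to validate-and-capture: the raw dict is run through the pydantic model so validators can rewrite fields, then dumped back to a plain dict before being handed to mtsp.convert. A minimal sketch of that pattern, assuming pydantic v2 (the EagleConfig class and the num_layers field below are stand-ins for illustration, not the real ModelOpt definitions):

from pydantic import BaseModel, ValidationInfo, field_validator


class EagleConfig(BaseModel):
    # Stand-in field for illustration only.
    num_layers: int = 1

    @field_validator("num_layers")
    @classmethod
    def _rewrite_num_layers(cls, v: int, info: ValidationInfo) -> int:
        # Validators can consult the extra context passed to model_validate().
        training_args = (info.context or {}).get("training_args")
        return v if training_args is None else max(v, 1)


raw_cfg = {"num_layers": 1}
# Validate (letting validators rewrite fields), then dump back to a plain dict
# so downstream APIs that expect dicts keep working.
validated_cfg = EagleConfig.model_validate(
    raw_cfg, context={"training_args": None}
).model_dump()
print(validated_cfg)  # {'num_layers': 1}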

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 5 additions & 4 deletions
@@ -187,14 +187,15 @@ def _get_config_from_draft_or_base(key: str, model: nn.Module):
             new_value = str(new_value).replace("torch.", "")
             template_config[key] = new_value

-        # Inject export rope scaling override (validated at config time to require
-        # training rope_type == "default").
+        # Inject export rope scaling override when training rope_type is "default".
+        rope_cfg = self.model.eagle_config.rope_scaling or {}
+        training_rope_type = rope_cfg.get("rope_type") or rope_cfg.get("type")
         eagle_export_rope_scaling = getattr(self.model, "eagle_export_rope_scaling", None)
-        if eagle_export_rope_scaling:
+        if eagle_export_rope_scaling and training_rope_type == "default":
             template_config["rope_scaling"] = eagle_export_rope_scaling

         # In transformers 5.x, rope_theta is under rope_scaling, not the main config.
-        rope_cfg = self.model.eagle_config.rope_scaling
+        # Always source from the training rope config (rope_theta is not in export overrides).
         if template_config.get("rope_theta") is None and rope_cfg:
             template_config["rope_theta"] = rope_cfg.get("rope_theta")

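The key change in this hunk is the tolerant rope-type lookup: rope_scaling may be None, and some checkpoints (the "kimik2-style" configs covered by the new test below) store the type under "type" rather than "rope_type". A small self-contained sketch of that lookup; get_rope_type is a hypothetical helper used only for illustration, not part of the ModelOpt API:

def get_rope_type(rope_scaling: dict | None) -> str | None:
    # rope_scaling may be absent or None; fall back to an empty dict before .get().
    cfg = rope_scaling or {}
    # Newer configs use "rope_type"; older ones use "type".
    return cfg.get("rope_type") or cfg.get("type")


assert get_rope_type(None) is None
assert get_rope_type({"type": "yarn", "factor": 32.0}) == "yarn"
assert get_rope_type({"rope_type": "default"}) == "default"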

modelopt/torch/speculative/config.py

Lines changed: 2 additions & 2 deletions
@@ -148,8 +148,8 @@ def _derive_eagle_offline(cls, data: Any, info: ValidationInfo) -> Any:
     def _check_rope_scaling_consistency(self) -> "EagleConfig":
         if not self.eagle_export_rope_scaling:
             return self
-        rope_cfg = self.eagle_architecture_config.get("rope_scaling", {})
-        rope_type = rope_cfg.get("rope_type")
+        rope_cfg = self.eagle_architecture_config.get("rope_scaling", {}) or {}
+        rope_type = rope_cfg.get("rope_type") or rope_cfg.get("type")
         if rope_type is not None and rope_type != "default":
             raise ValueError(
                 f"eagle_export_rope_scaling is set but eagle_architecture_config has "
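The added "or {}" matters because dict.get's default only applies when the key is missing, not when the key is present with an explicit None value, which is how some architecture configs spell out "no rope scaling". A quick illustration in plain Python:

arch_cfg = {"rope_scaling": None}

# The default {} is NOT used here, because the key exists (with value None):
assert arch_cfg.get("rope_scaling", {}) is None

# Chaining "or {}" restores a safe empty dict to call .get() on:
rope_cfg = arch_cfg.get("rope_scaling", {}) or {}
assert rope_cfg.get("rope_type") is None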

tests/unit/torch/speculative/test_eagle_config.py

Lines changed: 10 additions & 0 deletions
@@ -36,6 +36,16 @@ def test_rope_consistency_error_non_default_rope_type():
         EagleConfig.model_validate(cfg)


+def test_rope_consistency_error_non_default_rope_type_alt_key():
+    """Error when rope_scaling uses 'type' key instead of 'rope_type' (kimik2-style)."""
+    cfg = {
+        "eagle_export_rope_scaling": {"rope_type": "yarn", "factor": 32.0},
+        "eagle_architecture_config": {"rope_scaling": {"type": "yarn"}},
+    }
+    with pytest.raises(ValidationError, match="rope_type='yarn'"):
+        EagleConfig.model_validate(cfg)
+
+
 def test_rope_consistency_ok_default_rope_type():
     """No error when training rope_type is 'default'."""
     cfg = {
