From 4971eb3f29f9ba47791b7226f3eb90f7152c3643 Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Wed, 6 May 2026 16:55:03 +0800 Subject: [PATCH 1/9] Add Kimi K2.5 model support --- gptqmodel/models/auto.py | 3 ++ gptqmodel/models/definitions/kimi_k25.py | 65 ++++++++++++++++++++++++ tests/test_kimi_k25_support.py | 63 +++++++++++++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 gptqmodel/models/definitions/kimi_k25.py create mode 100644 tests/test_kimi_k25_support.py diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 15df19a38..feddf57e9 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +# ruff: noqa: I001 from __future__ import annotations @@ -116,6 +117,7 @@ from .definitions.internlm2 import InternLM2QModel # noqa: E402 from .definitions.internvl_chat import InternVLChatQModel # noqa: E402 from .definitions.klear import KlearQModel # noqa: E402 +from .definitions.kimi_k25 import KimiK25QModel # noqa: E402 from .definitions.laguna import LagunaQModel # noqa: E402 from .definitions.lfm2_moe import LFM2MoeQModel # noqa: E402 from .definitions.llada2 import LLaDA2MoeQModel @@ -182,6 +184,7 @@ "brumby": BrumbyQModel, "gpt_neo": GptNeoQModel, "kimi_k2": DeepSeekV3QModel, # 100% DeepSeekV3QModel clone + "kimi_k25": KimiK25QModel, "klear": KlearQModel, "laguna": LagunaQModel, "gpt_neox": GPTNeoXQModel, diff --git a/gptqmodel/models/definitions/kimi_k25.py b/gptqmodel/models/definitions/kimi_k25.py new file mode 100644 index 000000000..8bf91884c --- /dev/null +++ b/gptqmodel/models/definitions/kimi_k25.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from ...utils.model import get_module +from ..base import BaseQModel +from ..moe_lifecycle import GateUpDownMoELifecycleHooks + + +class KimiK25QModel(BaseQModel): + # Kimi-K2.5 wraps a DeepSeek-V3 text backbone with a vision tower and + # projector. Quantize the language model and keep the vision path in base. 
+ require_trust_remote_code = True + + require_load_processor = True + + pre_lm_head_norm_module = "language_model.model.norm" + + dynamic_expert_index = "n_routed_experts" + + layer_modules_strict = False + + moe_lifecycle_hooks = GateUpDownMoELifecycleHooks() + + module_tree = [ + "language_model", + "model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "self_attn": ("q_proj:0", "q_a_proj:0", "kv_a_proj_with_mqa:0", "q_b_proj:1", "kv_b_proj:1", "o_proj:2"), + "post_attention_layernorm": ("post_attention_layernorm:!",), + "mlp:moe": { + "": ("gate_proj:0", "up_proj:0", "down_proj:1"), + "experts": { + "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), + }, + "shared_experts": ("gate_proj:0", "up_proj:0", "down_proj:1"), + }, + }, + ] + + @classmethod + def get_base_modules(cls, model): + prefix, core_model = cls._resolve_multimodal_layout(model) + base_modules = [] + for name, _ in core_model.named_children(): + if name != "language_model": + base_modules.append(f"{prefix}.{name}" if prefix else name) + return base_modules + + @classmethod + def _resolve_multimodal_layout(cls, model): + for prefix in ("model", ""): + core_model = get_module(model, prefix) if prefix else model + if core_model is None: + continue + if hasattr(core_model, "language_model"): + return prefix, core_model + raise AttributeError("Unable to resolve Kimi-K2.5 core model with a `language_model` module.") + + +__all__ = ["KimiK25QModel"] diff --git a/tests/test_kimi_k25_support.py b/tests/test_kimi_k25_support.py new file mode 100644 index 000000000..48db9835e --- /dev/null +++ b/tests/test_kimi_k25_support.py @@ -0,0 +1,63 @@ +from types import SimpleNamespace + +from torch import nn + +from gptqmodel.models import auto +from gptqmodel.models.definitions.kimi_k25 import KimiK25QModel + + +def test_kimi_k25_model_type_selects_definition(monkeypatch): + fake_config = SimpleNamespace(model_type="kimi_k25") + + monkeypatch.setattr(auto, "resolve_trust_remote_code", lambda path, trust_remote_code=False: trust_remote_code) + monkeypatch.setattr(auto.AutoConfig, "from_pretrained", lambda *args, **kwargs: fake_config) + + assert auto.check_and_get_model_definition("/tmp/kimi-k2.5", trust_remote_code=True) is KimiK25QModel + + +def test_kimi_k25_quantizes_language_model_and_keeps_multimodal_modules_in_base(): + class _LanguageModel(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Module() + + class _KimiCore(nn.Module): + def __init__(self): + super().__init__() + self.vision_tower = nn.Identity() + self.mm_projector = nn.Identity() + self.language_model = _LanguageModel() + + class _KimiWrapper(nn.Module): + def __init__(self): + super().__init__() + self.model = _KimiCore() + + base_modules = set(KimiK25QModel.get_base_modules(_KimiWrapper())) + + assert KimiK25QModel.require_trust_remote_code is True + assert KimiK25QModel.require_load_processor is True + assert KimiK25QModel.pre_lm_head_norm_module == "language_model.model.norm" + assert "model.vision_tower" in base_modules + assert "model.mm_projector" in base_modules + assert "model.language_model" not in base_modules + + +def test_kimi_k25_module_tree_targets_deepseek_v3_text_backbone(): + layer_modules = KimiK25QModel.simple_layer_modules( + model_config=SimpleNamespace(n_routed_experts=2), + quantize_config=SimpleNamespace(dynamic=None), + ) + flat_modules = {name for block in layer_modules for name in block} + + assert KimiK25QModel.layer_modules_strict is False + assert KimiK25QModel.dynamic_expert_index == 
"n_routed_experts" + assert KimiK25QModel.extract_layers_node() == ["language_model.model.layers"] + assert "self_attn.q_a_proj" in flat_modules + assert "self_attn.kv_a_proj_with_mqa" in flat_modules + assert "self_attn.q_b_proj" in flat_modules + assert "self_attn.kv_b_proj" in flat_modules + assert "self_attn.o_proj" in flat_modules + assert "mlp.gate_proj" in flat_modules + assert "mlp.experts.0.gate_proj" in flat_modules + assert "mlp.shared_experts.up_proj" in flat_modules From 5ea1f14bb1ca49e92a7126c4c0d9a88baebb2f55 Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Wed, 6 May 2026 17:07:13 +0800 Subject: [PATCH 2/9] move tests to models/ --- tests/{ => models}/test_kimi_k25_support.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{ => models}/test_kimi_k25_support.py (100%) diff --git a/tests/test_kimi_k25_support.py b/tests/models/test_kimi_k25_support.py similarity index 100% rename from tests/test_kimi_k25_support.py rename to tests/models/test_kimi_k25_support.py From f4e005e9aeda1dffd9e9c262b40a08d9a87057e7 Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Wed, 6 May 2026 17:25:54 +0800 Subject: [PATCH 3/9] use new test template --- tests/models/test_kimi_k25_support.py | 131 +++++++++++++++----------- 1 file changed, 74 insertions(+), 57 deletions(-) diff --git a/tests/models/test_kimi_k25_support.py b/tests/models/test_kimi_k25_support.py index 48db9835e..076e270c5 100644 --- a/tests/models/test_kimi_k25_support.py +++ b/tests/models/test_kimi_k25_support.py @@ -1,63 +1,80 @@ -from types import SimpleNamespace +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium -from torch import nn +import os -from gptqmodel.models import auto -from gptqmodel.models.definitions.kimi_k25 import KimiK25QModel +from model_test import ModelTest -def test_kimi_k25_model_type_selects_definition(monkeypatch): - fake_config = SimpleNamespace(model_type="kimi_k25") - - monkeypatch.setattr(auto, "resolve_trust_remote_code", lambda path, trust_remote_code=False: trust_remote_code) - monkeypatch.setattr(auto.AutoConfig, "from_pretrained", lambda *args, **kwargs: fake_config) - - assert auto.check_and_get_model_definition("/tmp/kimi-k2.5", trust_remote_code=True) is KimiK25QModel - - -def test_kimi_k25_quantizes_language_model_and_keeps_multimodal_modules_in_base(): - class _LanguageModel(nn.Module): - def __init__(self): - super().__init__() - self.model = nn.Module() - - class _KimiCore(nn.Module): - def __init__(self): - super().__init__() - self.vision_tower = nn.Identity() - self.mm_projector = nn.Identity() - self.language_model = _LanguageModel() - - class _KimiWrapper(nn.Module): - def __init__(self): - super().__init__() - self.model = _KimiCore() - - base_modules = set(KimiK25QModel.get_base_modules(_KimiWrapper())) - - assert KimiK25QModel.require_trust_remote_code is True - assert KimiK25QModel.require_load_processor is True - assert KimiK25QModel.pre_lm_head_norm_module == "language_model.model.norm" - assert "model.vision_tower" in base_modules - assert "model.mm_projector" in base_modules - assert "model.language_model" not in base_modules - - -def test_kimi_k25_module_tree_targets_deepseek_v3_text_backbone(): - layer_modules = KimiK25QModel.simple_layer_modules( - model_config=SimpleNamespace(n_routed_experts=2), - quantize_config=SimpleNamespace(dynamic=None), +class Test(ModelTest): + # Keep one stable saved 
checkpoint so eval-only repro runs can reuse the exact post-quant model. + SAVE_PATH = os.environ.get( + "GPTQMODEL_KIMI_2_5_SAVE_PATH", + "/tmp/kimi_2_5_gptq_saved_ckpt", ) - flat_modules = {name for block in layer_modules for name in block} + DELETE_QUANTIZED_MODEL = False + NATIVE_MODEL_ID = "/monster/data/model/Kimi-K2.5" # moonshotai/Kimi-K2.5 + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + TRUST_REMOTE_CODE = True + # TODO, update scores + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.3987, + "floor_pct": 0.04, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3234, # 0.3294 4096, 0.3242 2048 + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.3643, # 0.3558 4096, 0.3635 2048 + "floor_pct": 0.04, + }, + }, + } + EVAL_TASKS_FAST = { + "gsm8k_platinum_cot": { + "chat_template": True, + "evalution_use_model_path": True, + "evalution_batch_size": "auto", + "evalution_model_args": { + "dtype": "bfloat16", + "attn_implementation": "paged|flash_attention_2", + "device": "cuda:0", + }, + "evalution_suite_kwargs": { + "batch_size": 32, + "max_new_tokens": 256, + "stream": True, + }, + "acc,num": { + "value": 0.390625, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3166, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + "acc_norm": { + "value": 0.3430, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + } - assert KimiK25QModel.layer_modules_strict is False - assert KimiK25QModel.dynamic_expert_index == "n_routed_experts" - assert KimiK25QModel.extract_layers_node() == ["language_model.model.layers"] - assert "self_attn.q_a_proj" in flat_modules - assert "self_attn.kv_a_proj_with_mqa" in flat_modules - assert "self_attn.q_b_proj" in flat_modules - assert "self_attn.kv_b_proj" in flat_modules - assert "self_attn.o_proj" in flat_modules - assert "mlp.gate_proj" in flat_modules - assert "mlp.experts.0.gate_proj" in flat_modules - assert "mlp.shared_experts.up_proj" in flat_modules + def test(self): + self.quantize_and_evaluate() From e8108ebcb56cdc64bf8d087a7fe6d0cf8dffacac Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Wed, 6 May 2026 19:15:37 +0800 Subject: [PATCH 4/9] USE_FLASH_ATTN = False --- tests/models/test_kimi_k25_support.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_kimi_k25_support.py b/tests/models/test_kimi_k25_support.py index 076e270c5..5acc3e790 100644 --- a/tests/models/test_kimi_k25_support.py +++ b/tests/models/test_kimi_k25_support.py @@ -19,6 +19,7 @@ class Test(ModelTest): EVAL_BATCH_SIZE = 64 DATASET_CONCAT_SIZE = 2048 TRUST_REMOTE_CODE = True + USE_FLASH_ATTN = False # TODO, update scores EVAL_TASKS_SLOW = { "gsm8k_platinum_cot": { From b148eb676d6f4fd98dc406bcf2ec5c99d98abf32 Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Thu, 7 May 2026 09:38:19 +0800 Subject: [PATCH 5/9] use default --- tests/models/test_kimi_k25_support.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_kimi_k25_support.py b/tests/models/test_kimi_k25_support.py index 5acc3e790..8673831e6 100644 --- a/tests/models/test_kimi_k25_support.py +++ b/tests/models/test_kimi_k25_support.py @@ -48,7 +48,6 @@ class Test(ModelTest): "evalution_batch_size": "auto", "evalution_model_args": { "dtype": "bfloat16", - "attn_implementation": "paged|flash_attention_2", "device": "cuda:0", }, "evalution_suite_kwargs": { From ffdb725220a1d0bb2fa320746501549c99f586af Mon Sep 17 00:00:00 2001 From: 
CSY-ModelCloud Date: Thu, 7 May 2026 17:05:48 +0800 Subject: [PATCH 6/9] fix flash attn check --- gptqmodel/models/definitions/kimi_k25.py | 8 +++-- gptqmodel/models/loader.py | 43 +++++++++++++++++++++++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/gptqmodel/models/definitions/kimi_k25.py b/gptqmodel/models/definitions/kimi_k25.py index 8bf91884c..abeee1ed0 100644 --- a/gptqmodel/models/definitions/kimi_k25.py +++ b/gptqmodel/models/definitions/kimi_k25.py @@ -3,9 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from ...utils.model import get_module from ..base import BaseQModel from ..moe_lifecycle import GateUpDownMoELifecycleHooks +from ...utils.model import get_module class KimiK25QModel(BaseQModel): @@ -44,11 +44,13 @@ class KimiK25QModel(BaseQModel): @classmethod def get_base_modules(cls, model): + base_modules = super().get_base_modules(model) prefix, core_model = cls._resolve_multimodal_layout(model) - base_modules = [] for name, _ in core_model.named_children(): if name != "language_model": - base_modules.append(f"{prefix}.{name}" if prefix else name) + module_name = f"{prefix}.{name}" if prefix else name + if module_name not in base_modules: + base_modules.append(module_name) return base_modules @classmethod diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index ce750eda2..76c8d4611 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -116,6 +116,47 @@ def _supports_flash_attn_2(config: PretrainedConfig) -> bool: return False +def _iter_nested_pretrained_configs(config: PretrainedConfig): + """Yield config and all nested PretrainedConfig nodes once.""" + + stack = [config] + visited = set() + + while stack: + cur = stack.pop() + if not isinstance(cur, PretrainedConfig): + continue + + node_id = id(cur) + if node_id in visited: + continue + visited.add(node_id) + yield cur + + for value in vars(cur).values(): + if isinstance(value, PretrainedConfig): + stack.append(value) + elif isinstance(value, dict): + for sub in value.values(): + if isinstance(sub, PretrainedConfig): + stack.append(sub) + elif isinstance(value, (list, tuple, set)): + for sub in value: + if isinstance(sub, PretrainedConfig): + stack.append(sub) + + +def _override_attn_implementation(config: PretrainedConfig, attn_implementation: str) -> None: + """Apply attention implementation override to root and nested configs.""" + + for sub_config in _iter_nested_pretrained_configs(config): + try: + sub_config._attn_implementation = attn_implementation + except Exception: + # Some remote configs may expose read-only wrappers; ignore safely. + pass + + def _is_accelerated_attention_device(device: object) -> bool: """Return True when the selected device can run CUDA/ROCm flash attention.""" @@ -459,7 +500,7 @@ def from_pretrained( if atten_impl is not None and atten_impl != "auto": log.info(f"Loader: overriding attn_implementation in config to `{atten_impl}`") - config._attn_implementation = atten_impl + _override_attn_implementation(config, atten_impl) resolved_device = normalize_device_device_map(device, device_map) resolved_device = auto_select_device(resolved_device, backend) From 08b59018255693d6042396f21ffa147285b29c5c Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Thu, 7 May 2026 17:09:32 +0800 Subject: [PATCH 7/9] fix NotImplementedError: Cannot copy out of meta tensor; no data! 
--- gptqmodel/models/writer.py | 109 +++++++++++++++++++++++++- gptqmodel/nn_modules/hooked_linear.py | 50 ++++++++---- tests/models/test_kimi_k25_support.py | 4 + 3 files changed, 144 insertions(+), 19 deletions(-) diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index e1510b16b..5694a63a0 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -20,6 +20,7 @@ from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin from transformers.models.auto.tokenization_auto import get_tokenizer_config +from ._const import DEFAULT_MAX_SHARD_SIZE, DEVICE from ..adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora from ..adapter.peft import LoraConfig from ..quantization.config import ( @@ -60,11 +61,9 @@ make_quant, streaming_state_dict_to_shards, ) -from ..utils.structure import alias_all_from_turtle_if_meta +from ..utils.structure import alias_all_from_turtle_if_meta, alias_from_turtle_for_submodule from ..utils.torch import torch_empty_cache from ..version import __version__ -from ._const import DEFAULT_MAX_SHARD_SIZE, DEVICE - log = setup_logger() @@ -103,6 +102,104 @@ def _parse_split_by(value: Optional[str]) -> Optional[str]: return normalized +def _materialize_remaining_meta_params_from_turtle(model: torch.nn.Module, turtle_model) -> int: + """Best-effort fallback for meta params that survive normal turtle sync.""" + + if ( + turtle_model is None + or not hasattr(turtle_model, "_resolve_checkpoint_tensor_source") + or not hasattr(turtle_model, "_weight_map") + or not hasattr(turtle_model, "model_local_path") + ): + return 0 + + restored = 0 + pending_by_shard: Dict[str, List[tuple[str, str, str, torch.nn.Parameter, Optional[int], Optional[int], Optional[int]]]] = {} + + for full_name, param in list(model.named_parameters()): + if not (getattr(param, "is_meta", False) or param.device.type == "meta"): + continue + + module_path, leaf = full_name.rsplit(".", 1) + resolved_name, expert_index, split_index, split_dim = turtle_model._resolve_checkpoint_tensor_source(module_path, leaf) + if resolved_name is None: + continue + shard = turtle_model._weight_map.get(resolved_name) + if shard is None: + continue + pending_by_shard.setdefault(shard, []).append( + (resolved_name, module_path, leaf, param, expert_index, split_index, split_dim) + ) + + for shard, entries in pending_by_shard.items(): + shard_path = os.path.join(turtle_model.model_local_path, shard) + unique_names = {name for name, _module_path, _leaf, _param, _expert_index, _split_index, _split_dim in entries} + + try: + with safe_open(shard_path, framework="pt", device="cpu") as handler: + tensors = {name: handler.get_tensor(name) for name in unique_names} + except RuntimeError as exc: + log.warn("Model save: skipping shard `%s` during meta materialization due to runtime error: %s", shard, exc) + continue + + for tensor_name, module_path, leaf, param, expert_index, split_index, split_dim in entries: + source = tensors.get(tensor_name) + if source is None: + continue + target = source + if expert_index is not None: + if expert_index >= target.shape[0]: + continue + target = target.narrow(0, expert_index, 1).squeeze(0) + if split_index is not None and split_dim is not None: + if target.shape[split_dim] % 2 != 0: + continue + chunk = target.shape[split_dim] // 2 + target = target.narrow(split_dim, split_index * chunk, chunk) + if target.dtype != param.dtype: + target = target.to(dtype=param.dtype) + if tuple(target.shape) != tuple(param.shape): + continue + 
module = model.get_submodule(module_path) + replacement = torch.nn.Parameter(target.detach().clone(), requires_grad=param.requires_grad) + setattr(module, leaf, replacement) + restored += 1 + + return restored + + +def _materialize_meta_layers_from_turtle(model: torch.nn.Module, turtle_model) -> int: + if turtle_model is None or not hasattr(turtle_model, "materialize_submodule"): + return 0 + + layer_paths = set() + for full_name, param in model.named_parameters(): + if not (getattr(param, "is_meta", False) or param.device.type == "meta"): + continue + parts = full_name.split(".") + if "layers" in parts: + i = parts.index("layers") + if i + 1 < len(parts): + layer_paths.add(".".join(parts[: i + 2])) + + materialized = 0 + for path in sorted(layer_paths): + try: + submodule = model.get_submodule(path) + alias_from_turtle_for_submodule( + target_model=model, + turtle_model=turtle_model, + target_submodule=submodule, + device=torch.device("cpu"), + non_blocking=False, + ) + materialized += 1 + except Exception as exc: + log.warn("Model save: failed to materialize meta layer `%s` from turtle: %s", path, exc) + + return materialized + + def _cleanup_saved_weight_files( save_dir: str, expected_files: List[str], @@ -658,6 +755,12 @@ def debug_saved_config(path): # Due to shell/turtle state, we need to sync the modules from turtle to shell if not self.load_quantized_model: alias_all_from_turtle_if_meta(shell_model=self.model, turtle_model=self.turtle_model) + materialized_layers = _materialize_meta_layers_from_turtle(self.model, self.turtle_model) + if materialized_layers: + log.info("Model save: materialized %s meta layer modules from turtle source.", materialized_layers) + restored_meta = _materialize_remaining_meta_params_from_turtle(self.model, self.turtle_model) + if restored_meta: + log.info("Model save: materialized %s remaining meta params from turtle source.", restored_meta) offload_root = self.quantize_config.offload_to_disk_path if getattr(self.quantize_config, "offload_to_disk", False) else None state_dict = get_state_dict_for_save(self.model, offload_root=offload_root) diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py index c956a5573..a1af8421b 100644 --- a/gptqmodel/nn_modules/hooked_linear.py +++ b/gptqmodel/nn_modules/hooked_linear.py @@ -7,12 +7,12 @@ import torch import transformers +from accelerate.utils import has_offloaded_params from torch import nn from ..utils.device_telemetry import emit_device_telemetry from ..utils.logger import setup_logger - log = setup_logger() @@ -22,6 +22,29 @@ class StopForward(Exception): STOP_FORWARD_EXCEPTION = StopForward("Forwarding stopped") + +def _restore_output_device(output: torch.Tensor, original_device: torch.device) -> torch.Tensor: + if output.device == original_device: + return output + if original_device.type == "meta": + # Meta tensors are placeholders with no backing storage; never move real + # outputs back to meta during hook replay. 
+ return output + return output.to(device=original_device) + + +def _materialize_if_meta_weight(module: nn.Module, *, input_device: torch.device) -> None: + weight = getattr(module, "weight", None) + if weight is None: + return + if getattr(weight, "is_meta", False) or weight.device.type == "meta": + if has_offloaded_params(module): + from ..utils.offload import undo_offload_to_disk + + restore_device = input_device if input_device.type != "meta" else torch.device("cpu") + undo_offload_to_disk(module, device=restore_device, include_buffers=False) + + # Models using conv1d: gpt2 class HookedConv1D(transformers.Conv1D): def __init__(self, nf: int, nx: int) -> None: @@ -42,6 +65,7 @@ def from_conv1d(m: transformers.Conv1D): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: original_device = input.device + _materialize_if_meta_weight(self, input_device=original_device) target_device = self.weight.data.device if original_device != target_device: input = input.to(device=target_device) @@ -52,9 +76,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.forward_hook_last: raise STOP_FORWARD_EXCEPTION.with_traceback(None) - if output.device != original_device: - output = output.to(device=original_device) - return output + return _restore_output_device(output, original_device) class HookedConv1d(torch.nn.Conv1d): def __init__( @@ -106,6 +128,7 @@ def from_conv1d(m: torch.nn.Conv1d): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: original_device = input.device + _materialize_if_meta_weight(self, input_device=original_device) target_device = self.weight.data.device if original_device != target_device: input = input.to(device=target_device) @@ -114,9 +137,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.forward_hook(self, (input,), output) if self.forward_hook_last: raise STOP_FORWARD_EXCEPTION.with_traceback(None) - if output.device != original_device: - output = output.to(device=original_device) - return output + return _restore_output_device(output, original_device) # Models using conv2d: ovis class HookedConv2d(torch.nn.Conv2d): @@ -169,6 +190,7 @@ def from_conv2d(m: torch.nn.Conv2d): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: original_device = input.device + _materialize_if_meta_weight(self, input_device=original_device) target_device = self.weight.data.device if original_device != target_device: input = input.to(device=target_device) @@ -177,9 +199,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.forward_hook(self, (input,), output) if self.forward_hook_last: raise STOP_FORWARD_EXCEPTION.with_traceback(None) - if output.device != original_device: - output = output.to(device=original_device) - return output + return _restore_output_device(output, original_device) # Models using transformers.conv1d: gpt2 class HookedTransformerConv1D(transformers.Conv1D): @@ -200,6 +220,7 @@ def from_conv1d(conv1d: transformers.Conv1D): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: original_device = input.device + _materialize_if_meta_weight(self, input_device=original_device) target_device = self.weight.data.device if original_device != target_device: input = input.to(device=target_device) @@ -208,9 +229,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.forward_hook(self, (input,), output) if self.forward_hook_last: raise STOP_FORWARD_EXCEPTION.with_traceback(None) - if output.device != original_device: - output = 
output.to(device=original_device) - return output + return _restore_output_device(output, original_device) class HookedLinear(torch.nn.Linear): def __init__(self, in_features: int, out_features: int) -> None: @@ -233,6 +252,7 @@ def from_linear(linear: torch.nn.Linear): @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: original_device = input.device + _materialize_if_meta_weight(self, input_device=original_device) target_device = self.weight.data.device module_name = getattr(self, "module_name", None) or getattr(self, "full_name", None) or getattr(self, "name", None) or "unknown" emit_device_telemetry( @@ -248,9 +268,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.forward_hook(self, (input,), output) if self.forward_hook_last: raise STOP_FORWARD_EXCEPTION.with_traceback(None) - if output.device != original_device: - output = output.to(device=original_device) - return output + return _restore_output_device(output, original_device) def _replace_module(module, child, name, level: int = 0, debug: bool = False) -> bool: diff --git a/tests/models/test_kimi_k25_support.py b/tests/models/test_kimi_k25_support.py index 8673831e6..a499269b8 100644 --- a/tests/models/test_kimi_k25_support.py +++ b/tests/models/test_kimi_k25_support.py @@ -9,6 +9,8 @@ class Test(ModelTest): + # Isolate this test from global fast-layer env overrides. + FAST_LAYER_COUNT_ENV = "GPTQMODEL_FAST_LAYER_COUNT_KIMI_K25" # Keep one stable saved checkpoint so eval-only repro runs can reuse the exact post-quant model. SAVE_PATH = os.environ.get( "GPTQMODEL_KIMI_2_5_SAVE_PATH", @@ -20,6 +22,7 @@ class Test(ModelTest): DATASET_CONCAT_SIZE = 2048 TRUST_REMOTE_CODE = True USE_FLASH_ATTN = False + OFFLOAD_TO_DISK = False # TODO, update scores EVAL_TASKS_SLOW = { "gsm8k_platinum_cot": { @@ -48,6 +51,7 @@ class Test(ModelTest): "evalution_batch_size": "auto", "evalution_model_args": { "dtype": "bfloat16", + "attn_implementation": "auto", "device": "cuda:0", }, "evalution_suite_kwargs": { From 4c17ddf81d705cee1a77b3eaed8bee2eb927369a Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Thu, 7 May 2026 17:10:02 +0800 Subject: [PATCH 8/9] add todo --- tests/models/test_kimi_k25_support.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_kimi_k25_support.py b/tests/models/test_kimi_k25_support.py index a499269b8..213c259f2 100644 --- a/tests/models/test_kimi_k25_support.py +++ b/tests/models/test_kimi_k25_support.py @@ -22,6 +22,7 @@ class Test(ModelTest): DATASET_CONCAT_SIZE = 2048 TRUST_REMOTE_CODE = True USE_FLASH_ATTN = False + # TODO, offload is unable to fix by now OFFLOAD_TO_DISK = False # TODO, update scores EVAL_TASKS_SLOW = { From b78447fa123f29231744281a008a1999d7878fb6 Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Fri, 8 May 2026 14:23:21 +0800 Subject: [PATCH 9/9] add retry to fix file not found. 
--- gptqmodel/models/loader.py | 51 +++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 76c8d4611..1546c1d06 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -7,6 +7,7 @@ import copy import os +import shutil import time from importlib.metadata import PackageNotFoundError, version from itertools import chain @@ -239,6 +240,43 @@ def _is_meta_shell_build_error(exc: Exception) -> bool: return "cannot be called on meta tensors" in message and ".item()" in message +def _is_broken_transformers_dynamic_module_error(exc: Exception) -> bool: + if not isinstance(exc, FileNotFoundError): + return False + missing_path = str(getattr(exc, "filename", "") or exc) + return "transformers_modules" in missing_path and missing_path.endswith(".py") + + +def _hf_loader_from_pretrained_with_dynamic_module_retry(loader, model_local_path: str, **kwargs): + try: + return loader.from_pretrained(model_local_path, **kwargs) + except Exception as exc: + if not _is_broken_transformers_dynamic_module_error(exc): + raise + + missing_path = str(getattr(exc, "filename", "") or "") + missing_name = os.path.basename(missing_path) + source_path = os.path.join(model_local_path, missing_name) + if missing_path and os.path.isfile(source_path): + os.makedirs(os.path.dirname(missing_path), exist_ok=True) + shutil.copy2(source_path, missing_path) + log.warn( + "Loader: repaired missing dynamic-module file by copying `%s` -> `%s`.", + source_path, + missing_path, + ) + + retry_kwargs = dict(kwargs) + retry_kwargs["force_download"] = True + log.warn( + "Loader: detected broken transformers dynamic-module cache while loading `%s`; " + "retrying once with force_download=True: %s", + model_local_path, + exc, + ) + return loader.from_pretrained(model_local_path, **retry_kwargs) + + def _coerce_quantized_awq_dtype(*, backend: BACKEND, qcfg: QuantizeConfig, dtype): if qcfg.quant_method not in (METHOD.AWQ, METHOD.PARO): return dtype @@ -578,7 +616,12 @@ def from_pretrained( hf_model_init_kwargs[ATTN_IMPLEMENTATION] = "flash_attention_2" log.info("Loader: Auto enabling flash_attention_2 for dense Bonsai PROFILE.%s.", effective_profile.name) # Load a non-quantized model, but do not perform quantization. For example, for evaluation. - model = cls.loader.from_pretrained(model_local_path, config=config, **hf_model_init_kwargs) + model = _hf_loader_from_pretrained_with_dynamic_module_retry( + cls.loader, + model_local_path, + config=config, + **hf_model_init_kwargs, + ) model._model_init_kwargs = hf_model_init_kwargs _maybe_print_module_tree(model=model) @@ -676,7 +719,8 @@ def skip(*args, **kwargs): fallback_init_kwargs = model_init_kwargs_without_internal.copy() fallback_init_kwargs.pop("device_map", None) fallback_init_kwargs["low_cpu_mem_usage"] = False - model = cls.loader.from_pretrained( + model = _hf_loader_from_pretrained_with_dynamic_module_retry( + cls.loader, model_local_path, config=config, **fallback_init_kwargs, @@ -716,7 +760,8 @@ def skip(*args, **kwargs): ) else: log.info("Loader: loading model directly to CPU (not using meta device or turtle_model)") - model = cls.loader.from_pretrained( + model = _hf_loader_from_pretrained_with_dynamic_module_retry( + cls.loader, model_local_path, config=config, **model_init_kwargs_without_internal,
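Usage sketch: a minimal example of how the new kimi_k25 definition is expected to be exercised end to end, assuming the standard GPTQModel.load / quantize / save flow used elsewhere in the project; the checkpoint path, output directory, and calibration text below are placeholders rather than values taken from these patches.

    from gptqmodel import GPTQModel, QuantizeConfig

    # Kimi-K2.5 ships custom modeling code, so remote code must be trusted
    # (KimiK25QModel sets require_trust_remote_code = True).
    model_id = "/path/to/Kimi-K2.5"            # placeholder local checkpoint
    quant_path = "/path/to/Kimi-K2.5-gptq"     # placeholder output directory

    quant_config = QuantizeConfig(bits=4, group_size=128)

    model = GPTQModel.load(model_id, quant_config, trust_remote_code=True)

    # Any small set of representative text samples works as calibration data.
    calibration = [
        "Kimi-K2.5 pairs a DeepSeek-V3 text backbone with a vision tower and projector.",
    ]

    # Only the language_model branch is quantized; the vision tower and projector
    # remain in the base module set (see KimiK25QModel.get_base_modules).
    model.quantize(calibration)
    model.save(quant_path)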