3 changes: 3 additions & 0 deletions gptqmodel/models/auto.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
# ruff: noqa: I001

from __future__ import annotations

@@ -117,6 +118,7 @@
from .definitions.internlm2 import InternLM2QModel # noqa: E402
from .definitions.internvl_chat import InternVLChatQModel # noqa: E402
from .definitions.klear import KlearQModel # noqa: E402
from .definitions.kimi_k25 import KimiK25QModel # noqa: E402
from .definitions.laguna import LagunaQModel # noqa: E402
from .definitions.lfm2_moe import LFM2MoeQModel # noqa: E402
from .definitions.llada2 import LLaDA2MoeQModel
@@ -183,6 +185,7 @@
"brumby": BrumbyQModel,
"gpt_neo": GptNeoQModel,
"kimi_k2": DeepSeekV3QModel, # 100% DeepSeekV3QModel clone
"kimi_k25": KimiK25QModel,
"klear": KlearQModel,
"laguna": LagunaQModel,
"gpt_neox": GPTNeoXQModel,
67 changes: 67 additions & 0 deletions gptqmodel/models/definitions/kimi_k25.py
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from ..base import BaseQModel
from ..moe_lifecycle import GateUpDownMoELifecycleHooks
from ...utils.model import get_module


class KimiK25QModel(BaseQModel):
    # Kimi-K2.5 wraps a DeepSeek-V3 text backbone with a vision tower and
    # projector. Quantize the language model; keep the vision path in the
    # base (non-quantized) modules.
require_trust_remote_code = True

require_load_processor = True

pre_lm_head_norm_module = "language_model.model.norm"

dynamic_expert_index = "n_routed_experts"

layer_modules_strict = False

moe_lifecycle_hooks = GateUpDownMoELifecycleHooks()

module_tree = [
"language_model",
"model",
"layers",
"#",
{
"input_layernorm": ("input_layernorm:!",),
"self_attn": ("q_proj:0", "q_a_proj:0", "kv_a_proj_with_mqa:0", "q_b_proj:1", "kv_b_proj:1", "o_proj:2"),
"post_attention_layernorm": ("post_attention_layernorm:!",),
"mlp:moe": {
"": ("gate_proj:0", "up_proj:0", "down_proj:1"),
"experts": {
"#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
},
"shared_experts": ("gate_proj:0", "up_proj:0", "down_proj:1"),
},
},
]

@classmethod
def get_base_modules(cls, model):
base_modules = super().get_base_modules(model)
prefix, core_model = cls._resolve_multimodal_layout(model)
for name, _ in core_model.named_children():
if name != "language_model":
module_name = f"{prefix}.{name}" if prefix else name
if module_name not in base_modules:
base_modules.append(module_name)
return base_modules

@classmethod
def _resolve_multimodal_layout(cls, model):
for prefix in ("model", ""):
core_model = get_module(model, prefix) if prefix else model
if core_model is None:
continue
if hasattr(core_model, "language_model"):
return prefix, core_model
raise AttributeError("Unable to resolve Kimi-K2.5 core model with a `language_model` module.")


__all__ = ["KimiK25QModel"]
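
A minimal usage sketch for the new definition, assuming the standard GPTQModel quantize flow; the hub id and calibration samples below are illustrative placeholders, not part of this PR:

from gptqmodel import GPTQModel, QuantizeConfig

# Hypothetical hub id; the class above forces trust_remote_code=True anyway.
model_id = "moonshotai/Kimi-K2.5"
calibration_dataset = ["GPTQ calibration sample one.", "Sample two."]

quant_config = QuantizeConfig(bits=4, group_size=128)
model = GPTQModel.load(model_id, quant_config, trust_remote_code=True)
model.quantize(calibration_dataset)  # vision tower stays in the base modules
model.save("kimi-k2.5-gptq-4bit")
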
94 changes: 90 additions & 4 deletions gptqmodel/models/loader.py
@@ -7,6 +7,7 @@

import copy
import os
import shutil
import time
from importlib.metadata import PackageNotFoundError, version
from itertools import chain
@@ -116,6 +117,47 @@ def _supports_flash_attn_2(config: PretrainedConfig) -> bool:
return False


def _iter_nested_pretrained_configs(config: PretrainedConfig):
"""Yield config and all nested PretrainedConfig nodes once."""

stack = [config]
visited = set()

while stack:
cur = stack.pop()
if not isinstance(cur, PretrainedConfig):
continue

node_id = id(cur)
if node_id in visited:
continue
visited.add(node_id)
yield cur

for value in vars(cur).values():
if isinstance(value, PretrainedConfig):
stack.append(value)
elif isinstance(value, dict):
for sub in value.values():
if isinstance(sub, PretrainedConfig):
stack.append(sub)
elif isinstance(value, (list, tuple, set)):
for sub in value:
if isinstance(sub, PretrainedConfig):
stack.append(sub)


def _override_attn_implementation(config: PretrainedConfig, attn_implementation: str) -> None:
"""Apply attention implementation override to root and nested configs."""

for sub_config in _iter_nested_pretrained_configs(config):
try:
sub_config._attn_implementation = attn_implementation
except Exception:
# Some remote configs may expose read-only wrappers; ignore safely.
pass
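
A quick sanity sketch of the nested walk above; `vision_config` and the dict entry are illustrative stand-ins for whatever sub-configs a multimodal model nests:

from transformers import PretrainedConfig

cfg = PretrainedConfig()
cfg.vision_config = PretrainedConfig()          # nested attribute
cfg.sub_configs = {"text": PretrainedConfig()}  # nested dict value

_override_attn_implementation(cfg, "sdpa")
assert cfg._attn_implementation == "sdpa"
assert cfg.vision_config._attn_implementation == "sdpa"
assert cfg.sub_configs["text"]._attn_implementation == "sdpa"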


def _is_accelerated_attention_device(device: object) -> bool:
"""Return True when the selected device can run CUDA/ROCm flash attention."""

@@ -198,6 +240,43 @@ def _is_meta_shell_build_error(exc: Exception) -> bool:
return "cannot be called on meta tensors" in message and ".item()" in message


def _is_broken_transformers_dynamic_module_error(exc: Exception) -> bool:
if not isinstance(exc, FileNotFoundError):
return False
missing_path = str(getattr(exc, "filename", "") or exc)
return "transformers_modules" in missing_path and missing_path.endswith(".py")


def _hf_loader_from_pretrained_with_dynamic_module_retry(loader, model_local_path: str, **kwargs):
try:
return loader.from_pretrained(model_local_path, **kwargs)
except Exception as exc:
if not _is_broken_transformers_dynamic_module_error(exc):
raise

missing_path = str(getattr(exc, "filename", "") or "")
missing_name = os.path.basename(missing_path)
source_path = os.path.join(model_local_path, missing_name)
if missing_path and os.path.isfile(source_path):
os.makedirs(os.path.dirname(missing_path), exist_ok=True)
shutil.copy2(source_path, missing_path)
log.warn(
"Loader: repaired missing dynamic-module file by copying `%s` -> `%s`.",
source_path,
missing_path,
)

retry_kwargs = dict(kwargs)
retry_kwargs["force_download"] = True
log.warn(
"Loader: detected broken transformers dynamic-module cache while loading `%s`; "
"retrying once with force_download=True: %s",
model_local_path,
exc,
)
return loader.from_pretrained(model_local_path, **retry_kwargs)
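
Illustration of the failure shape the predicate above targets (paths are illustrative): a FileNotFoundError pointing into the `transformers_modules` dynamic-module cache qualifies for the repair/retry, while any other missing file re-raises immediately.

exc = FileNotFoundError(2, "No such file or directory")
exc.filename = "~/.cache/huggingface/modules/transformers_modules/x/modeling_x.py"
assert _is_broken_transformers_dynamic_module_error(exc)

other = FileNotFoundError(2, "No such file or directory")
other.filename = "/tmp/missing.bin"
assert not _is_broken_transformers_dynamic_module_error(other)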


def _coerce_quantized_awq_dtype(*, backend: BACKEND, qcfg: QuantizeConfig, dtype):
if qcfg.quant_method not in (METHOD.AWQ, METHOD.PARO):
return dtype
@@ -459,7 +538,7 @@ def from_pretrained(

if atten_impl is not None and atten_impl != "auto":
log.info(f"Loader: overriding attn_implementation in config to `{atten_impl}`")
config._attn_implementation = atten_impl
_override_attn_implementation(config, atten_impl)

resolved_device = normalize_device_device_map(device, device_map)
resolved_device = auto_select_device(resolved_device, backend)
@@ -537,7 +616,12 @@
hf_model_init_kwargs[ATTN_IMPLEMENTATION] = "flash_attention_2"
log.info("Loader: Auto enabling flash_attention_2 for dense Bonsai PROFILE.%s.", effective_profile.name)
# Load a non-quantized model, but do not perform quantization. For example, for evaluation.
model = cls.loader.from_pretrained(model_local_path, config=config, **hf_model_init_kwargs)
model = _hf_loader_from_pretrained_with_dynamic_module_retry(
cls.loader,
model_local_path,
config=config,
**hf_model_init_kwargs,
)
model._model_init_kwargs = hf_model_init_kwargs
_maybe_print_module_tree(model=model)

@@ -635,7 +719,8 @@ def skip(*args, **kwargs):
fallback_init_kwargs = model_init_kwargs_without_internal.copy()
fallback_init_kwargs.pop("device_map", None)
fallback_init_kwargs["low_cpu_mem_usage"] = False
model = cls.loader.from_pretrained(
model = _hf_loader_from_pretrained_with_dynamic_module_retry(
cls.loader,
model_local_path,
config=config,
**fallback_init_kwargs,
@@ -675,7 +760,8 @@ def skip(*args, **kwargs):
)
else:
log.info("Loader: loading model directly to CPU (not using meta device or turtle_model)")
model = cls.loader.from_pretrained(
model = _hf_loader_from_pretrained_with_dynamic_module_retry(
cls.loader,
model_local_path,
config=config,
**model_init_kwargs_without_internal,
109 changes: 106 additions & 3 deletions gptqmodel/models/writer.py
@@ -20,6 +20,7 @@
from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin
from transformers.models.auto.tokenization_auto import get_tokenizer_config

from ._const import DEFAULT_MAX_SHARD_SIZE, DEVICE
from ..adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora
from ..adapter.peft import LoraConfig
from ..quantization.config import (
@@ -60,11 +61,9 @@
make_quant,
streaming_state_dict_to_shards,
)
from ..utils.structure import alias_all_from_turtle_if_meta
from ..utils.structure import alias_all_from_turtle_if_meta, alias_from_turtle_for_submodule
from ..utils.torch import torch_empty_cache
from ..version import __version__
from ._const import DEFAULT_MAX_SHARD_SIZE, DEVICE


log = setup_logger()

@@ -103,6 +102,104 @@ def _parse_split_by(value: Optional[str]) -> Optional[str]:
return normalized


def _materialize_remaining_meta_params_from_turtle(model: torch.nn.Module, turtle_model) -> int:
"""Best-effort fallback for meta params that survive normal turtle sync."""

if (
turtle_model is None
or not hasattr(turtle_model, "_resolve_checkpoint_tensor_source")
or not hasattr(turtle_model, "_weight_map")
or not hasattr(turtle_model, "model_local_path")
):
return 0

restored = 0
pending_by_shard: Dict[str, List[tuple[str, str, str, torch.nn.Parameter, Optional[int], Optional[int], Optional[int]]]] = {}

for full_name, param in list(model.named_parameters()):
if not (getattr(param, "is_meta", False) or param.device.type == "meta"):
continue

module_path, leaf = full_name.rsplit(".", 1)
resolved_name, expert_index, split_index, split_dim = turtle_model._resolve_checkpoint_tensor_source(module_path, leaf)
if resolved_name is None:
continue
shard = turtle_model._weight_map.get(resolved_name)
if shard is None:
continue
pending_by_shard.setdefault(shard, []).append(
(resolved_name, module_path, leaf, param, expert_index, split_index, split_dim)
)

for shard, entries in pending_by_shard.items():
shard_path = os.path.join(turtle_model.model_local_path, shard)
unique_names = {name for name, _module_path, _leaf, _param, _expert_index, _split_index, _split_dim in entries}

try:
with safe_open(shard_path, framework="pt", device="cpu") as handler:
tensors = {name: handler.get_tensor(name) for name in unique_names}
except RuntimeError as exc:
log.warn("Model save: skipping shard `%s` during meta materialization due to runtime error: %s", shard, exc)
continue

for tensor_name, module_path, leaf, param, expert_index, split_index, split_dim in entries:
source = tensors.get(tensor_name)
if source is None:
continue
target = source
if expert_index is not None:
if expert_index >= target.shape[0]:
continue
target = target.narrow(0, expert_index, 1).squeeze(0)
if split_index is not None and split_dim is not None:
if target.shape[split_dim] % 2 != 0:
continue
chunk = target.shape[split_dim] // 2
target = target.narrow(split_dim, split_index * chunk, chunk)
if target.dtype != param.dtype:
target = target.to(dtype=param.dtype)
if tuple(target.shape) != tuple(param.shape):
continue
module = model.get_submodule(module_path)
replacement = torch.nn.Parameter(target.detach().clone(), requires_grad=param.requires_grad)
setattr(module, leaf, replacement)
restored += 1

return restored
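
A worked sketch of the slicing performed above, with illustrative shapes: `expert_index` selects one expert from a stacked tensor, and `split_index`/`split_dim` take one half of a fused gate/up projection.

import torch

stacked = torch.randn(8, 4, 6)               # [num_experts, out, in]
expert = stacked.narrow(0, 2, 1).squeeze(0)  # expert_index=2 -> shape [4, 6]

fused = torch.randn(10, 6)                   # fused gate_up along split_dim=0
chunk = fused.shape[0] // 2
gate = fused.narrow(0, 0 * chunk, chunk)     # split_index=0 -> shape [5, 6]
up = fused.narrow(0, 1 * chunk, chunk)       # split_index=1 -> shape [5, 6]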


def _materialize_meta_layers_from_turtle(model: torch.nn.Module, turtle_model) -> int:
if turtle_model is None or not hasattr(turtle_model, "materialize_submodule"):
return 0

layer_paths = set()
for full_name, param in model.named_parameters():
if not (getattr(param, "is_meta", False) or param.device.type == "meta"):
continue
parts = full_name.split(".")
if "layers" in parts:
i = parts.index("layers")
if i + 1 < len(parts):
layer_paths.add(".".join(parts[: i + 2]))

materialized = 0
for path in sorted(layer_paths):
try:
submodule = model.get_submodule(path)
alias_from_turtle_for_submodule(
target_model=model,
turtle_model=turtle_model,
target_submodule=submodule,
device=torch.device("cpu"),
non_blocking=False,
)
materialized += 1
except Exception as exc:
log.warn("Model save: failed to materialize meta layer `%s` from turtle: %s", path, exc)

return materialized
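
Both helpers detect unmaterialized weights the same way; a minimal sketch of what the `is_meta` check flags, assuming torch >= 2.0 for the device context manager:

import torch

with torch.device("meta"):  # allocate parameters without storage
    layer = torch.nn.Linear(4, 4)
assert all(p.is_meta for p in layer.parameters())
assert layer.weight.device.type == "meta"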


def _cleanup_saved_weight_files(
save_dir: str,
expected_files: List[str],
@@ -658,6 +755,12 @@ def debug_saved_config(path):
# Due to shell/turtle state, we need to sync the modules from turtle to shell
if not self.load_quantized_model:
alias_all_from_turtle_if_meta(shell_model=self.model, turtle_model=self.turtle_model)
materialized_layers = _materialize_meta_layers_from_turtle(self.model, self.turtle_model)
if materialized_layers:
log.info("Model save: materialized %s meta layer modules from turtle source.", materialized_layers)
restored_meta = _materialize_remaining_meta_params_from_turtle(self.model, self.turtle_model)
if restored_meta:
log.info("Model save: materialized %s remaining meta params from turtle source.", restored_meta)

offload_root = self.quantize_config.offload_to_disk_path if getattr(self.quantize_config, "offload_to_disk", False) else None
state_dict = get_state_dict_for_save(self.model, offload_root=offload_root)