From a26ba0fe5f6a5de69a3d7f3bb89816a8f3ea2df0 Mon Sep 17 00:00:00 2001 From: kangletian Date: Sat, 23 May 2026 14:28:22 +0800 Subject: [PATCH] convert : add --fuse-qkv flag to fuse Q/K/V into QKV during HF-to-GGUF conversion --- conversion/base.py | 39 +++++++++++++++++- convert_hf_to_gguf.py | 6 +++ gguf-py/gguf/constants.py | 83 ++++++++++++++++++++++++++++++++++++++ src/llama-graph.cpp | 5 +++ src/llama-model.cpp | 6 +++ src/models/gemma3n.cpp | 9 ++++- src/models/gemma4.cpp | 30 +++++++++++--- src/models/kimi-linear.cpp | 19 +++++++-- src/models/minimax-m2.cpp | 24 ++++++++--- src/models/olmo2.cpp | 24 ++++++++--- src/models/olmoe.cpp | 24 ++++++++--- src/models/qwen35.cpp | 27 +++++++++---- src/models/qwen35moe.cpp | 27 +++++++++---- src/models/qwen3next.cpp | 27 +++++++++---- src/models/step35.cpp | 21 ++++++++-- 15 files changed, 317 insertions(+), 54 deletions(-) diff --git a/conversion/base.py b/conversion/base.py index f861f8b5296c..d6643bf13b71 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -120,7 +120,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, disable_mistral_community_chat_template: bool = False, sentence_transformers_dense_modules: bool = False, fuse_gate_up_exps: bool = False, - fp8_as_q8: bool = False): + fp8_as_q8: bool = False, + fuse_qkv: bool = False): if type(self) is ModelBase or \ type(self) is TextModel or \ type(self) is MmprojModel: @@ -142,6 +143,13 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.fuse_gate_up_exps = fuse_gate_up_exps self._gate_exp_buffer: dict[int, Tensor] = {} self._up_exp_buffer: dict[int, Tensor] = {} + self.fuse_qkv = fuse_qkv + self._q_buffer: dict[int, Tensor] = {} + self._k_buffer: dict[int, Tensor] = {} + self._v_buffer: dict[int, Tensor] = {} + self._q_bias_buffer: dict[int, Tensor] = {} + self._k_bias_buffer: dict[int, Tensor] = {} + self._v_bias_buffer: dict[int, Tensor] = {} self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id) self.metadata_override = metadata_override @@ -634,6 +642,35 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid): return [] + # Handle Q/K/V tensor fusion if enabled + if self.fuse_qkv and bid is not None: + is_bias = name.endswith('.bias') + suffix = 'bias' if is_bias else 'weight' + buf_q = self._q_bias_buffer if is_bias else self._q_buffer + buf_k = self._k_bias_buffer if is_bias else self._k_buffer + buf_v = self._v_bias_buffer if is_bias else self._v_buffer + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_Q, bid): + buf_q[bid] = data_torch + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_K, bid): + buf_k[bid] = data_torch + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_V, bid): + buf_v[bid] = data_torch + + if bid in buf_q and bid in buf_k and bid in buf_v: + q_data = buf_q.pop(bid) + k_data = buf_k.pop(bid) + v_data = buf_v.pop(bid) + fused_data = torch.cat([q_data, k_data, v_data], dim=0) + fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_QKV, bid, suffix=suffix) + logger.info(f"Fused Q, K, V {suffix} into QKV for layer {bid}") + return [(fused_name, fused_data)] + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_Q, bid) or \ + self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_K, bid) or \ + self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_V, bid): + return [] + return [(new_name, data_torch)] def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 85527553563d..07744e9498a3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -153,6 +153,11 @@ def parse_args() -> argparse.Namespace: help="Store tensors dequantized from FP8 as Q8_0 instead of BF16/F16.", ) + parser.add_argument( + "--fuse-qkv", action="store_true", + help="Fuse separate Q, K, V weight tensors into a single QKV tensor.", + ) + args = parser.parse_args() if not args.print_supported_models and args.model is None: parser.error("the following arguments are required: model") @@ -270,6 +275,7 @@ def main() -> None: sentence_transformers_dense_modules=args.sentence_transformers_dense_modules, fuse_gate_up_exps=args.fuse_gate_up_exps, fp8_as_q8=args.fp8_as_q8, + fuse_qkv=args.fuse_qkv, ) if args.vocab_only: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 92578490cb3d..f4e14908644e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1587,6 +1587,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1607,6 +1608,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1630,6 +1632,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1650,6 +1653,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1695,6 +1699,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1797,6 +1802,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_TYPES, MODEL_TENSOR.ATTN_NORM_2, MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -1825,6 +1831,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1854,6 +1861,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1866,6 +1874,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1893,6 +1902,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1924,6 +1934,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1939,6 +1950,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1954,6 +1966,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1968,6 +1981,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1982,6 +1996,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2002,6 +2017,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2018,6 +2034,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2067,6 +2084,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2083,6 +2101,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2171,6 +2190,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2304,6 +2324,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2319,6 +2340,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2337,6 +2359,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FACTORS_LONG, MODEL_TENSOR.ROPE_FACTORS_SHORT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2374,6 +2397,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2386,6 +2410,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2402,6 +2427,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2420,6 +2446,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2456,6 +2483,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2491,6 +2519,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.DENSE_2_OUT, MODEL_TENSOR.DENSE_3_OUT, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2511,6 +2540,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2674,6 +2704,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2703,6 +2734,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2717,6 +2749,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2731,6 +2764,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2755,6 +2789,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.OLMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2767,6 +2802,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2782,6 +2818,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.SEED_OSS: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2798,6 +2835,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2829,6 +2867,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2850,6 +2889,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2873,6 +2913,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_A, MODEL_TENSOR.ATTN_Q_B, @@ -2905,6 +2946,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_A, MODEL_TENSOR.ATTN_Q_B, @@ -2977,6 +3019,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3053,6 +3096,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3120,6 +3164,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.BITNET: [ + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3198,6 +3244,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3212,6 +3259,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3233,6 +3281,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3252,6 +3301,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3276,6 +3326,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3291,6 +3342,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3309,6 +3361,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3340,6 +3393,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3354,6 +3408,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3379,6 +3434,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3402,6 +3458,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3441,6 +3498,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3488,6 +3546,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3513,6 +3572,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3528,6 +3588,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3554,6 +3615,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3568,6 +3630,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3585,6 +3648,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_NORM, # Attention components + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, # Query projection MODEL_TENSOR.ATTN_K, # Key projection MODEL_TENSOR.ATTN_V, # Value projection @@ -3617,6 +3681,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3637,6 +3702,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3653,6 +3719,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3670,6 +3737,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3686,6 +3754,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3709,6 +3778,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_NORM, # operator_norm MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3729,6 +3799,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_NORM, # operator_norm MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3744,6 +3815,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3763,6 +3835,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3780,6 +3853,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3797,6 +3871,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3817,6 +3892,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3852,6 +3928,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3869,6 +3946,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3884,6 +3962,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3960,6 +4039,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3986,6 +4066,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4005,6 +4086,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -4021,6 +4103,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 5bca8230b9b0..0a53c1b2648f 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1169,6 +1169,11 @@ llm_graph_qkv llm_graph_context::build_qkv( if (layer.wqkv_b) { qkv = ggml_add(ctx0, qkv, layer.wqkv_b); cb(qkv, "wqkv_b", il); + } else if (layer.wq_b && layer.wk_b && layer.wv_b) { + // fused weights but separate biases (from --fuse-qkv conversion) + ggml_tensor * qkv_b = ggml_concat(ctx0, ggml_concat(ctx0, layer.wq_b, layer.wk_b, 0), layer.wv_b, 0); + qkv = ggml_add(ctx0, qkv, qkv_b); + cb(qkv, "wqkv_b", il); } if (hparams.f_clamp_kqv > 0.0f) { qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a8323c8fb1e4..006836f705f1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2578,6 +2578,12 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); if (layer.wqkv) { layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + // fallback: when --fuse-qkv only fused weights, biases may still be stored separately + if (!layer.wqkv_b) { + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + } } else { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); diff --git a/src/models/gemma3n.cpp b/src/models/gemma3n.cpp index 6ec3a006081f..a6d9d6a9346d 100644 --- a/src/models/gemma3n.cpp +++ b/src/models/gemma3n.cpp @@ -176,7 +176,14 @@ llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_par hparams.f_attention_scale, il); } else { // reuse KV cache of earlier layers - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + const int64_t q_dim = n_embd_head * n_head; + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + } cb(Qcur, "Qcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); diff --git a/src/models/gemma4.cpp b/src/models/gemma4.cpp index 4f9d8b18bc72..ba0c3c5bc10b 100644 --- a/src/models/gemma4.cpp +++ b/src/models/gemma4.cpp @@ -195,9 +195,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para // Q projection (shared for both non-KV and KV layers) // this is to mirror Gemma4Attention in pytorch code + ggml_tensor * qkv_fused = nullptr; ggml_tensor * Qcur; - { + if (model.layers[il].wqkv) { + qkv_fused = build_lora_mm(model.layers[il].wqkv, cur, model.layers[il].wqkv_s); + cb(qkv_fused, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head; + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv_fused, q_dim, n_tokens, qkv_fused->nb[1], 0)); + } else { Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); + } + { cb(Qcur, "Qcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -212,12 +220,22 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para // self-attention if (hparams.has_kv(il)) { - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (qkv_fused) { + const int64_t q_dim = n_embd_head * n_head; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv_fused); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv_fused, k_dim, n_tokens, qkv_fused->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv_fused, v_dim, n_tokens, qkv_fused->nb[1], (q_dim + k_dim) * esize)); + } else { + Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + Vcur = model.layers[il].wv + ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) + : Kcur; // if v_proj is not present, use Kcur as Vcur + } cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = model.layers[il].wv - ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) - : Kcur; // if v_proj is not present, use Kcur as Vcur cb(Vcur, "Vcur", il); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index ecffb105496b..a185e0802f8a 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -195,7 +195,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t // Causal Conv1d function for Q,K,V // When qkv is 0, it is Q, 1 is K, 2 is V // Step 1: Q, K, V projections -> [d_inner, n_tokens] - ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x); + ggml_tensor * x_proj = proj_w ? ggml_mul_mat(ctx0, proj_w, x) : x; // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs); @@ -295,9 +295,20 @@ llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); cb(conv_states_all, "conv_states_all", il); ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs); - ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); - ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); - ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + ggml_tensor * q_in = cur, * k_in = cur, * v_in = cur; + ggml_tensor * q_w = layer.wq, * k_w = layer.wk, * v_w = layer.wv; + if (layer.wqkv) { + ggml_tensor * qkv = ggml_mul_mat(ctx0, layer.wqkv, cur); + const int64_t d_inner = head_dim * n_head; + const size_t esize = ggml_element_size(qkv); + q_in = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, d_inner, n_tokens, qkv->nb[1], 0)); + k_in = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, d_inner, n_tokens, qkv->nb[1], d_inner * esize)); + v_in = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, d_inner, n_tokens, qkv->nb[1], 2 * d_inner * esize)); + q_w = nullptr; k_w = nullptr; v_w = nullptr; + } + ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, q_in, q_w, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, k_in, k_w, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, v_in, v_w, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias) ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur); diff --git a/src/models/minimax-m2.cpp b/src/models/minimax-m2.cpp index 22e291d73a33..32c5745c8c7c 100644 --- a/src/models/minimax-m2.cpp +++ b/src/models/minimax-m2.cpp @@ -69,14 +69,26 @@ llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, diff --git a/src/models/olmo2.cpp b/src/models/olmo2.cpp index 7cc262f55046..95e1f7c13af3 100644 --- a/src/models/olmo2.cpp +++ b/src/models/olmo2.cpp @@ -93,14 +93,26 @@ llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph // self_attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, diff --git a/src/models/olmoe.cpp b/src/models/olmoe.cpp index 7976ae44a51c..b6e806d73ed3 100644 --- a/src/models/olmoe.cpp +++ b/src/models/olmoe.cpp @@ -78,14 +78,26 @@ llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_param // self_attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index ba63ae441df5..60895cd862a0 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -269,8 +269,27 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur, model.layers[il].wqkv_s); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head * 2; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur_full = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); + } cb(Qcur_full, "Qcur_full", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, ggml_element_size(Qcur_full) * n_embd_head * 2, @@ -281,12 +300,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 4f87d55d9112..5829ae22cb40 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -292,8 +292,27 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur, model.layers[il].wqkv_s); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head * 2; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur_full = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); + } cb(Qcur_full, "Qcur_full", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, ggml_element_size(Qcur_full) * n_embd_head * 2, @@ -304,12 +323,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 1d873427db5e..248b17dfc2ed 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -214,8 +214,27 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur_full; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head * 2; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur_full = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur_full = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur_full, "Qcur_full", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); @@ -230,12 +249,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn( Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); cb(gate, "gate", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); diff --git a/src/models/step35.cpp b/src/models/step35.cpp index 3b68e68707ae..f656752d7ad1 100644 --- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -129,9 +129,24 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para { cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head_k * n_head_l; + const int64_t k_dim = n_embd_head_k * n_head_kv_l; + const int64_t v_dim = n_embd_head_v * n_head_kv_l; + const size_t esize = ggml_element_size(qkv); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il);