diff --git a/conversion/base.py b/conversion/base.py index 0421aa4bc4d3..4a5c5f8fbde5 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -122,7 +122,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, sentence_transformers_dense_modules: bool = False, target_model_dir: Path | None = None, fuse_gate_up_exps: bool = False, - fp8_as_q8: bool = False): + fp8_as_q8: bool = False, + fuse_qkv: bool = False): if type(self) is ModelBase or \ type(self) is TextModel or \ type(self) is MmprojModel: @@ -145,6 +146,13 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.fuse_gate_up_exps = fuse_gate_up_exps self._gate_exp_buffer: dict[int, Tensor] = {} self._up_exp_buffer: dict[int, Tensor] = {} + self.fuse_qkv = fuse_qkv + self._q_buffer: dict[int, Tensor] = {} + self._k_buffer: dict[int, Tensor] = {} + self._v_buffer: dict[int, Tensor] = {} + self._q_bias_buffer: dict[int, Tensor] = {} + self._k_bias_buffer: dict[int, Tensor] = {} + self._v_bias_buffer: dict[int, Tensor] = {} self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id) self.metadata_override = metadata_override @@ -637,6 +645,35 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid): return [] + # Handle Q/K/V tensor fusion if enabled + if self.fuse_qkv and bid is not None: + is_bias = name.endswith('.bias') + suffix = 'bias' if is_bias else 'weight' + buf_q = self._q_bias_buffer if is_bias else self._q_buffer + buf_k = self._k_bias_buffer if is_bias else self._k_buffer + buf_v = self._v_bias_buffer if is_bias else self._v_buffer + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_Q, bid): + buf_q[bid] = data_torch + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_K, bid): + buf_k[bid] = data_torch + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_V, bid): + buf_v[bid] = data_torch + + if bid in buf_q and bid in buf_k and bid in buf_v: + q_data = buf_q.pop(bid) + k_data = buf_k.pop(bid) + v_data = buf_v.pop(bid) + fused_data = torch.cat([q_data, k_data, v_data], dim=0) + fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_QKV, bid, suffix=suffix) + logger.info(f"Fused Q, K, V {suffix} into QKV for layer {bid}") + return [(fused_name, fused_data)] + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_Q, bid) or \ + self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_K, bid) or \ + self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_V, bid): + return [] + return [(new_name, data_torch)] def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3b23d5ebc0d3..641a8b074193 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -154,6 +154,8 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( + "--fuse-qkv", action="store_true", + help="Fuse separate Q, K, V weight tensors into a single QKV tensor.", "--target-model-dir", type=str, default=None, help=( "path to the target model directory; required when converting a standalone draft model " @@ -281,6 +283,7 @@ def main() -> None: target_model_dir=Path(args.target_model_dir) if args.target_model_dir else None, fuse_gate_up_exps=args.fuse_gate_up_exps, fp8_as_q8=args.fp8_as_q8, + fuse_qkv=args.fuse_qkv, ) if args.vocab_only: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cd4cdef8991f..b242303937c3 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1762,6 +1762,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1782,6 +1783,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1805,6 +1807,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1825,6 +1828,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1870,6 +1874,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -1972,6 +1977,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_TYPES, MODEL_TENSOR.ATTN_NORM_2, MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2000,6 +2006,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2029,6 +2036,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2041,6 +2049,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2068,6 +2077,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2099,6 +2109,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2114,6 +2125,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2129,6 +2141,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2143,6 +2156,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2157,6 +2171,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2177,6 +2192,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2193,6 +2209,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2242,6 +2259,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2258,6 +2276,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2346,6 +2365,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2479,6 +2499,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2494,6 +2515,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2512,6 +2534,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FACTORS_LONG, MODEL_TENSOR.ROPE_FACTORS_SHORT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2549,6 +2572,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2561,6 +2585,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2577,6 +2602,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2595,6 +2621,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2631,6 +2658,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2686,6 +2714,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.DENSE_2_OUT, MODEL_TENSOR.DENSE_3_OUT, MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -2706,6 +2735,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2869,6 +2899,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2898,6 +2929,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2912,6 +2944,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2926,6 +2959,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2977,6 +3011,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.OLMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -2989,6 +3024,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3004,6 +3040,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.SEED_OSS: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3020,6 +3057,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3051,6 +3089,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3072,6 +3111,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3095,6 +3135,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_A, MODEL_TENSOR.ATTN_Q_B, @@ -3127,6 +3168,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_A, MODEL_TENSOR.ATTN_Q_B, @@ -3242,6 +3284,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3318,6 +3361,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3385,6 +3429,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.BITNET: [ + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3463,6 +3509,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3477,6 +3524,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3498,6 +3546,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3517,6 +3566,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3541,6 +3591,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3556,6 +3607,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3581,6 +3633,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3612,6 +3665,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3626,6 +3680,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3651,6 +3706,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3674,6 +3730,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3713,6 +3770,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3760,6 +3818,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3785,6 +3844,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3800,6 +3860,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3826,6 +3887,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3840,6 +3902,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3857,6 +3920,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_NORM, # Attention components + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, # Query projection MODEL_TENSOR.ATTN_K, # Key projection MODEL_TENSOR.ATTN_V, # Value projection @@ -3889,6 +3953,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3909,6 +3974,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3925,6 +3991,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -3942,6 +4009,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3958,6 +4026,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -3981,6 +4050,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_NORM, # operator_norm MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4001,6 +4071,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_NORM, # operator_norm MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4016,6 +4087,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4035,6 +4107,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4052,6 +4125,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4069,6 +4143,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -4089,6 +4164,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -4124,6 +4200,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -4141,6 +4218,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4156,6 +4234,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4266,6 +4345,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -4299,6 +4379,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -4318,6 +4399,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K, @@ -4334,6 +4416,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 4c86e43c1f74..8f4dda478879 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1494,6 +1494,11 @@ llm_graph_qkv llm_graph_context::build_qkv( if (layer.wqkv_b) { qkv = ggml_add(ctx0, qkv, layer.wqkv_b); cb(qkv, "wqkv_b", il); + } else if (layer.wq_b && layer.wk_b && layer.wv_b) { + // fused weights but separate biases (from --fuse-qkv conversion) + ggml_tensor * qkv_b = ggml_concat(ctx0, ggml_concat(ctx0, layer.wq_b, layer.wk_b, 0), layer.wv_b, 0); + qkv = ggml_add(ctx0, qkv, qkv_b); + cb(qkv, "wqkv_b", il); } if (hparams.f_clamp_kqv > 0.0f) { qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d58ebac28b9b..ee8b28472b73 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2725,6 +2725,12 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); if (layer.wqkv) { layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + // fallback: when --fuse-qkv only fused weights, biases may still be stored separately + if (!layer.wqkv_b) { + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + } } else { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); diff --git a/src/models/gemma3n.cpp b/src/models/gemma3n.cpp index 83eb8250aa94..ea616db3ba3c 100644 --- a/src/models/gemma3n.cpp +++ b/src/models/gemma3n.cpp @@ -176,7 +176,14 @@ llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_par hparams.f_attention_scale, il); } else { // reuse KV cache of earlier layers - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + const int64_t q_dim = n_embd_head * n_head; + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + } cb(Qcur, "Qcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); diff --git a/src/models/gemma4.cpp b/src/models/gemma4.cpp index 6a96979cebde..de6236c1f769 100644 --- a/src/models/gemma4.cpp +++ b/src/models/gemma4.cpp @@ -224,9 +224,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para // Q projection (shared for both non-KV and KV layers) // this is to mirror Gemma4Attention in pytorch code + ggml_tensor * qkv_fused = nullptr; ggml_tensor * Qcur; - { + if (model.layers[il].wqkv) { + qkv_fused = build_lora_mm(model.layers[il].wqkv, cur, model.layers[il].wqkv_s); + cb(qkv_fused, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head; + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv_fused, q_dim, n_tokens, qkv_fused->nb[1], 0)); + } else { Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); + } + { cb(Qcur, "Qcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -241,12 +249,22 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para // self-attention if (hparams.has_kv(il)) { - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (qkv_fused) { + const int64_t q_dim = n_embd_head * n_head; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv_fused); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv_fused, k_dim, n_tokens, qkv_fused->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv_fused, v_dim, n_tokens, qkv_fused->nb[1], (q_dim + k_dim) * esize)); + } else { + Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + Vcur = model.layers[il].wv + ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) + : Kcur; // if v_proj is not present, use Kcur as Vcur + } cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = model.layers[il].wv - ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) - : Kcur; // if v_proj is not present, use Kcur as Vcur cb(Vcur, "Vcur", il); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 367f6990d1fb..a8a9f19d6042 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -195,7 +195,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t // Causal Conv1d function for Q,K,V // When qkv is 0, it is Q, 1 is K, 2 is V // Step 1: Q, K, V projections -> [d_inner, n_tokens] - ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x); + ggml_tensor * x_proj = proj_w ? ggml_mul_mat(ctx0, proj_w, x) : x; // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs); @@ -295,9 +295,20 @@ llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); cb(conv_states_all, "conv_states_all", il); ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs); - ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); - ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); - ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + ggml_tensor * q_in = cur, * k_in = cur, * v_in = cur; + ggml_tensor * q_w = layer.wq, * k_w = layer.wk, * v_w = layer.wv; + if (layer.wqkv) { + ggml_tensor * qkv = ggml_mul_mat(ctx0, layer.wqkv, cur); + const int64_t d_inner = head_dim * n_head; + const size_t esize = ggml_element_size(qkv); + q_in = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, d_inner, n_tokens, qkv->nb[1], 0)); + k_in = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, d_inner, n_tokens, qkv->nb[1], d_inner * esize)); + v_in = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, d_inner, n_tokens, qkv->nb[1], 2 * d_inner * esize)); + q_w = nullptr; k_w = nullptr; v_w = nullptr; + } + ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, q_in, q_w, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, k_in, k_w, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, v_in, v_w, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias) ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur); diff --git a/src/models/minimax-m2.cpp b/src/models/minimax-m2.cpp index b25435e4d97c..3f5a862c9eb6 100644 --- a/src/models/minimax-m2.cpp +++ b/src/models/minimax-m2.cpp @@ -69,14 +69,26 @@ llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, diff --git a/src/models/olmo2.cpp b/src/models/olmo2.cpp index cb52cdef7204..9bd1911b99d3 100644 --- a/src/models/olmo2.cpp +++ b/src/models/olmo2.cpp @@ -93,14 +93,26 @@ llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph // self_attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, diff --git a/src/models/olmoe.cpp b/src/models/olmoe.cpp index 1e2baeb207ff..8f8481c905a5 100644 --- a/src/models/olmoe.cpp +++ b/src/models/olmoe.cpp @@ -79,14 +79,26 @@ llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_param // self_attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index d8ffe43ae76c..480fadc2fd7d 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -266,8 +266,27 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur, model.layers[il].wqkv_s); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head * 2; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur_full = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); + } cb(Qcur_full, "Qcur_full", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, ggml_element_size(Qcur_full) * n_embd_head * 2, @@ -278,12 +297,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 7b0876cbb04b..4996d0b30ab6 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -290,8 +290,27 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur, model.layers[il].wqkv_s); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head * 2; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur_full = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); + } cb(Qcur_full, "Qcur_full", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, ggml_element_size(Qcur_full) * n_embd_head * 2, @@ -302,12 +321,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 97200a44072f..63e5bf6f9148 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -214,8 +214,27 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn( // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur_full; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head * n_head * 2; + const int64_t k_dim = n_embd_head * n_head_kv; + const int64_t v_dim = n_embd_head * n_head_kv; + const size_t esize = ggml_element_size(qkv); + Qcur_full = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur_full = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur_full, "Qcur_full", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); @@ -230,12 +249,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn( Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); cb(gate, "gate", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); diff --git a/src/models/step35.cpp b/src/models/step35.cpp index 9b7b18a3678b..289bd280bf9c 100644 --- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -216,9 +216,24 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para { cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + if (model.layers[il].wqkv) { + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + const int64_t q_dim = n_embd_head_k * n_head_l; + const int64_t k_dim = n_embd_head_k * n_head_kv_l; + const int64_t v_dim = n_embd_head_v * n_head_kv_l; + const size_t esize = ggml_element_size(qkv); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, q_dim, n_tokens, qkv->nb[1], 0)); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, k_dim, n_tokens, qkv->nb[1], q_dim * esize)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, v_dim, n_tokens, qkv->nb[1], (q_dim + k_dim) * esize)); + } else { + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il);