Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion conversion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
sentence_transformers_dense_modules: bool = False,
target_model_dir: Path | None = None,
fuse_gate_up_exps: bool = False,
fp8_as_q8: bool = False):
fp8_as_q8: bool = False,
fuse_qkv: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
Expand All @@ -145,6 +146,13 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.fuse_gate_up_exps = fuse_gate_up_exps
self._gate_exp_buffer: dict[int, Tensor] = {}
self._up_exp_buffer: dict[int, Tensor] = {}
self.fuse_qkv = fuse_qkv
self._q_buffer: dict[int, Tensor] = {}
self._k_buffer: dict[int, Tensor] = {}
self._v_buffer: dict[int, Tensor] = {}
self._q_bias_buffer: dict[int, Tensor] = {}
self._k_bias_buffer: dict[int, Tensor] = {}
self._v_bias_buffer: dict[int, Tensor] = {}
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
Expand Down Expand Up @@ -637,6 +645,35 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
return []

# Handle Q/K/V tensor fusion if enabled
if self.fuse_qkv and bid is not None:
is_bias = name.endswith('.bias')
suffix = 'bias' if is_bias else 'weight'
buf_q = self._q_bias_buffer if is_bias else self._q_buffer
buf_k = self._k_bias_buffer if is_bias else self._k_buffer
buf_v = self._v_bias_buffer if is_bias else self._v_buffer

if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_Q, bid):
buf_q[bid] = data_torch
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_K, bid):
buf_k[bid] = data_torch
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_V, bid):
buf_v[bid] = data_torch

if bid in buf_q and bid in buf_k and bid in buf_v:
q_data = buf_q.pop(bid)
k_data = buf_k.pop(bid)
v_data = buf_v.pop(bid)
fused_data = torch.cat([q_data, k_data, v_data], dim=0)
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_QKV, bid, suffix=suffix)
logger.info(f"Fused Q, K, V {suffix} into QKV for layer {bid}")
return [(fused_name, fused_data)]

if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_Q, bid) or \
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_K, bid) or \
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.ATTN_V, bid):
return []

return [(new_name, data_torch)]

def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
Expand Down
3 changes: 3 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def parse_args() -> argparse.Namespace:
)

parser.add_argument(
"--fuse-qkv", action="store_true",
help="Fuse separate Q, K, V weight tensors into a single QKV tensor.",
"--target-model-dir", type=str, default=None,
help=(
"path to the target model directory; required when converting a standalone draft model "
Expand Down Expand Up @@ -281,6 +283,7 @@ def main() -> None:
target_model_dir=Path(args.target_model_dir) if args.target_model_dir else None,
fuse_gate_up_exps=args.fuse_gate_up_exps,
fp8_as_q8=args.fp8_as_q8,
fuse_qkv=args.fuse_qkv,
)

if args.vocab_only:
Expand Down
Loading