107 changes: 96 additions & 11 deletions convert_hf_to_gguf.py
@@ -11855,7 +11855,7 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("HunYuanDenseV1ForCausalLM", "HunYuanVLForConditionalGeneration")
@ModelBase.register("HunYuanDenseV1ForCausalLM")
class HunYuanModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE

@@ -11994,40 +11994,125 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter


@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanOCRVisionModel(MmprojModel):
class HunyuanVLVisionModel(MmprojModel):
# Handles both HunyuanOCR and HunyuanVL, which share the HF architecture name
# "HunYuanVLForConditionalGeneration" and the `vit.perceive.*` vision layout.
# Each variant maps to a different projector type in clip.cpp so image
# preprocessing follows the correct code path.

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
# HunyuanOCR uses max_image_size instead of image_size
# HunyuanOCR / HunyuanVL use max_image_size instead of image_size
if "image_size" not in self.hparams_vision:
self.hparams_vision["image_size"] = self.hparams_vision.get("max_image_size", 2048)

@staticmethod
def is_ocr_variant(hparams: dict) -> bool:
"""Return True for HunyuanOCR, False for HunyuanVL.

The projector's output dim must equal the text model's hidden_size by
construction (that's what "projector" means). HunyuanOCR pairs a 1B text
backbone (hidden=1024); HunyuanVL pairs a 4B one (hidden=3072). So the
ViT -> LLM projection dim is a hard architectural signature, not a
magic number.
"""
vision_out = int((hparams.get("vision_config") or {}).get("out_hidden_size", 0))
return vision_out == 1024
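# Illustrative only, not part of the conversion path: two made-up minimal
# configs showing what the check keys off.
#   is_ocr_variant({"vision_config": {"out_hidden_size": 1024}})  # -> True  (OCR, 1B text backbone)
#   is_ocr_variant({"vision_config": {"out_hidden_size": 3072}})  # -> False (VL, 4B text backbone)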

def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
hparams = self.hparams_vision
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-5))
self.gguf_writer.add_vision_spatial_merge_size(hparams.get("spatial_merge_size", 2))
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
vcfg = self.hparams_vision

if self.is_ocr_variant(self.global_config):
# --- HunyuanOCR ---
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANOCR)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(vcfg.get("rms_norm_eps", 1e-5))
self.gguf_writer.add_vision_spatial_merge_size(vcfg.get("spatial_merge_size", 2))
self.gguf_writer.add_vision_min_pixels(self.preprocessor_config["min_pixels"])
self.gguf_writer.add_vision_max_pixels(self.preprocessor_config["max_pixels"])
return

# --- HunyuanVL ---
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.HUNYUANVL)
self.gguf_writer.add_vision_use_gelu(str(vcfg["hidden_act"]).lower() == "gelu")
self.gguf_writer.add_vision_attention_layernorm_eps(float(vcfg["rms_norm_eps"]))
self.gguf_writer.add_vision_spatial_merge_size(int(vcfg["spatial_merge_size"]))
self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"]))
self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"]))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith("vit."):
return # skip text tensors
return
# strip CLS token (row 0) from position embeddings so resize_position_embeddings works
if "position_embedding" in name:
data_torch = data_torch[1:] # [n_patches+1, n_embd] -> [n_patches, n_embd]
yield from super().modify_tensors(data_torch, name, bid)

def tensor_force_quant(self, name, new_name, bid, n_dims):
# force conv weights to F32 or F16 to avoid BF16 IM2COL issues on Metal
# Both HunyuanOCR and HunyuanVL emit the ViT -> LLM projection as mm.0/mm.2.
if ("mm.0." in new_name or "mm.2." in new_name) and new_name.endswith(".weight"):
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)


@ModelBase.register("HunYuanVLForConditionalGeneration")
class HunyuanVLTextModel(HunYuanModel):
# The "HunYuanVLForConditionalGeneration" HF architecture covers both HunyuanOCR
# and HunyuanVL. HunyuanOCR reuses the HunYuan-Dense text backbone (standard RoPE),
# while HunyuanVL introduces a new LLM arch with XD-RoPE. Detect the variant from
# the config and pick the matching GGUF architecture.
model_arch = gguf.MODEL_ARCH.HUNYUAN_VL

@staticmethod
def _is_ocr_config(hparams: dict) -> bool:
# OCR pairs a 1B text backbone (hidden=1024) with a ViT projector that
# outputs 1024-d; HunyuanVL uses 3072-d. Keep in sync with
# HunyuanVLVisionModel.is_ocr_variant.
return int((hparams.get("vision_config") or {}).get("out_hidden_size", 0)) == 1024

def __init__(self, dir_model: Path, *args, **kwargs):
raw_hparams = kwargs.get("hparams") or ModelBase.load_hparams(dir_model, is_mistral_format=False)
if self._is_ocr_config(raw_hparams):
self.model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
else:
self.model_arch = gguf.MODEL_ARCH.HUNYUAN_VL
super().__init__(dir_model, *args, **kwargs)

def set_gguf_parameters(self):
super().set_gguf_parameters()

# Only emit XD-RoPE metadata for the HunyuanVL backbone; HunyuanOCR uses
# the HunYuan-Dense arch which already handles standard rope in super().
if self.model_arch != gguf.MODEL_ARCH.HUNYUAN_VL:
return

if self.rope_parameters.get("rope_type") != "xdrope":
return

# defaults for HunyuanVL. The C++ side later computes:
# freq_base = rope_theta * alpha ** (head_dim / (head_dim - 2))
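# worked example (hypothetical values, for illustration only): with
# rope_theta = 10000.0, alpha = 1000.0 and head_dim = 128 this gives
#     freq_base = 10000.0 * 1000.0 ** (128 / 126) ≈ 1.12e7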
self.gguf_writer.add_rope_freq_base(float(self.rope_parameters["rope_theta"]))
self.gguf_writer.add_rope_scaling_alpha(float(self.rope_parameters["alpha"]))
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(float(self.rope_parameters.get("factor", 1)))

ctx_len = int(self.hparams["max_position_embeddings"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(ctx_len)
self.gguf_writer.add_context_length(ctx_len)

self.gguf_writer.add_rope_dimension_sections(list(self.rope_parameters["xdrope_section"]))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors — they are written by HunyuanVLVisionModel
if name.startswith("vit."):
return
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.SMOLLM3
112 changes: 104 additions & 8 deletions examples/speculative-simple/speculative-simple.cpp
@@ -8,8 +8,24 @@
#include <clocale>
#include <cstdio>
#include <cstring>
#include <cinttypes>
#include <string>
#include <vector>
#include <utility>

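// snapshot of the target context state, taken right before a draft is evaluated;
// restored when only part of the draft is accepted and the context does not
// support partial sequence removal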
struct spec_checkpoint {
int64_t n_tokens = 0;

std::vector<uint8_t> data;

size_t size() const {
return data.size();
}

bool empty() const {
return data.empty();
}
};

int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
@@ -46,6 +62,14 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();

// check if the context supports partial sequence removal
const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);

if (use_ckpt) {
LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
}

const llama_vocab * vocab = llama_model_get_vocab(model_tgt);

// load the draft model
@@ -119,7 +143,7 @@
const auto t_enc_start = ggml_time_us();

// target model sampling context
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));

// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
@@ -142,21 +166,61 @@

llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

size_t n_draft = 0;

llama_tokens draft;
spec_checkpoint spec_ckpt;

const auto t_enc_end = ggml_time_us();

const auto t_dec_start = ggml_time_us();

while (true) {
// optionally, generate draft tokens that can be appended to the target batch
// generate or reuse draft tokens
//
// this is the most important part of the speculation. the more probable tokens that are provided here
// the better the performance will be. in theory, this computation can be performed asynchronously and even
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
// from a cache or lookup tables.
//
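// for illustration only (not part of this change): a non-LLM draft source could be as
// simple as finding the last occurrence of id_last in prompt_tgt and proposing the
// tokens that followed it (an n-gram lookup); the verification below then accepts or
// rejects those tokens exactly as it does for model-generated drafts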
llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
if (draft.empty()) {
// generate a new draft
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);

if ((int) draft.size() > params_spec.n_max) {
LOG_WRN("draft size %zu exceeds max %d, truncating\n", draft.size(), params_spec.n_max);
draft.resize(params_spec.n_max);
}

if ((int) draft.size() < params_spec.n_min) {
LOG_DBG("ignoring small draft: %zu < %d\n", draft.size(), params_spec.n_min);
draft.clear();
}

// save the original draft size
n_draft = draft.size();

// save a checkpoint of the target context before evaluating the draft
// this allows us to restore the state if partial draft acceptance occurs
if (!draft.empty() && use_ckpt) {
const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
spec_ckpt.data.resize(ckpt_size);

//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == ckpt_size);

spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
}
} else {
// we have a previous (partial) draft to reuse from checkpoint restoration
if (use_ckpt) {
GGML_ASSERT(!spec_ckpt.empty());
}
}

GGML_ASSERT(n_draft > 0);

// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
Expand All @@ -178,21 +242,51 @@ int main(int argc, char ** argv) {
llama_decode(ctx_tgt, batch_tgt);
}

// only save the sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt) {
smpl_save.reset(common_sampler_clone(smpl.get()));
}

// sample from the full target batch and return the accepted tokens based on the target sampler
//
// for each token to be accepted, the sampler would have to sample that same token
// in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
// available logits from the batch and sample the next token until we run out of logits or the sampler
// disagrees with the draft
//
const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
auto ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft);

//LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());

GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token

// check for partial draft acceptance:
// if the context doesn't support partial sequence removal, restore the checkpoint
// and make the accepted tokens the new partial draft for the next iteration
if (use_ckpt && ids.size() - 1 < draft.size()) {
LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());

draft = std::move(ids);

const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
GGML_ASSERT(n == spec_ckpt.size());

llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);

prompt_tgt.resize(spec_ckpt.n_tokens);
smpl = std::move(smpl_save);

n_past = (int) prompt_tgt.size();

continue;
}

common_speculative_accept(spec, ids.size() - 1);

// full acceptance: consume the draft and commit accepted tokens
n_past += ids.size() - 1;
n_drafted += draft.size(); // note: we ignore the discarded small drafts
n_drafted += n_draft; // note: we ignore the discarded small drafts
n_accept += ids.size() - 1;
n_predict += ids.size();

@@ -222,6 +316,9 @@

LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);

// clear the draft since it has been consumed
draft.clear();

{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);

@@ -254,11 +351,10 @@

LOG_INF("\n");
LOG_INF("target:\n\n");
common_perf_print(ctx_tgt, smpl);
common_perf_print(ctx_tgt, smpl.get());

llama_batch_free(batch_tgt);

common_sampler_free(smpl);
common_speculative_free(spec);

llama_backend_free();
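For reference, the checkpoint round-trip used above reduces to the pattern below. This is a minimal sketch rather than code from this PR: the helper names are hypothetical, it assumes the llama_state_seq_*_ext API and the LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag exactly as they appear in the diff, and the GGML_ASSERT size checks and error handling are omitted.

#include <cstdint>
#include <vector>

#include "llama.h"

// snapshot sequence 0 of the target context before a draft is evaluated
static std::vector<uint8_t> spec_ckpt_save(llama_context * ctx) {
    const size_t size = llama_state_seq_get_size_ext(ctx, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    std::vector<uint8_t> data(size);
    llama_state_seq_get_data_ext(ctx, data.data(), size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    return data;
}

// roll sequence 0 back to the snapshot and drop every cell past n_tokens
static void spec_ckpt_restore(llama_context * ctx, const std::vector<uint8_t> & data, int64_t n_tokens) {
    llama_state_seq_set_data_ext(ctx, data.data(), data.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    llama_memory_seq_rm(llama_get_memory(ctx), 0, n_tokens, -1);
}

The sampler is rolled back separately: the loop clones it with common_sampler_clone before verification and swaps the clone back in when a partial acceptance triggers the restore.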
7 changes: 7 additions & 0 deletions ggml/src/ggml-sycl/common.hpp
@@ -28,6 +28,13 @@

namespace syclexp = sycl::ext::oneapi::experimental;

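// bfloat16 support is detected here, in the shared header, so that any SYCL source
// that includes common.hpp (e.g. convert.cpp) can key off GGML_SYCL_HAS_BF16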
#if defined(__INTEL_LLVM_COMPILER) && __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
#include <sycl/ext/oneapi/bfloat16.hpp>
#ifndef GGML_SYCL_HAS_BF16
#define GGML_SYCL_HAS_BF16
#endif
#endif

#if GGML_SYCL_DNNL
#include "dnnl.hpp"
#include "dnnl_sycl.hpp"
23 changes: 16 additions & 7 deletions ggml/src/ggml-sycl/convert.cpp
@@ -2,13 +2,6 @@
#include "dequantize.hpp"
#include "presets.hpp"

#if defined(__INTEL_LLVM_COMPILER)
#if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
#include <sycl/ext/oneapi/bfloat16.hpp>
#define GGML_SYCL_HAS_BF16
#endif
#endif

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) {
@@ -767,6 +760,22 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
}


#ifdef GGML_SYCL_HAS_BF16
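// returns a conversion kernel that produces bfloat16 output for the given source type,
// analogous to ggml_get_to_fp32_sycl() above; compiled only when common.hpp detected
// sycl::ext::oneapi::bfloat16 support (GGML_SYCL_HAS_BF16)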
to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
case GGML_TYPE_BF16:
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
default:
GGML_ABORT("fatal error: unsupported data type=%s\n", ggml_type_name(type));
return nullptr;
}
}
#endif

to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
switch (type) {
case GGML_TYPE_F32: