diff --git a/common/speculative.cpp b/common/speculative.cpp index 8bd76c108..b1cf61c93 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -434,7 +434,14 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { auto * ctx_dft = this->params.ctx_dft; GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set"); - n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); + // The MTP hidden-state I/O width is the *backbone* width, which can differ + // from the drafter's internal embedding width (gemma4-assistant: internal + // n_embd=256, but it reads/writes target hidden states of width + // n_embd_out=1536). Size every embd row against the output width and assert + // it matches the target's hidden width that we feed in. + n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft)); + GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) && + "MTP input row width must match the target hidden width"); const int32_t n_b = (int32_t) llama_n_batch(ctx_dft); batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1); diff --git a/include/llama.h b/include/llama.h index cd020b09b..b601995c5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -392,6 +392,12 @@ extern "C" { // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) struct llama_sampler_seq_config * samplers; size_t n_samplers; + + // sibling context whose model supplies shared inputs / KV cache. + // required for arches that read from another model (gemma4-assistant MTP + // drafter reads the target model's token embeddings and shares its KV). + // null for all other arches. (the caller keeps this context alive.) + struct llama_context * ctx_other; }; struct llama_model_tensor_override { @@ -551,7 +557,8 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); - LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API struct llama_context * llama_get_ctx_other(struct llama_context * ctx); LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4bd4cbf32..1d4945ec8 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -59,6 +59,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA4, "gemma4" }, + { LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -461,6 +462,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, + { LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" }, + { LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" }, + { LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" }, + { LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" }, { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, @@ -582,6 +587,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer) {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + {LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 1dea33689..2990e8c54 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -63,6 +63,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA4, + LLM_ARCH_GEMMA4_ASSISTANT, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -569,6 +570,11 @@ enum llm_tensor { // EAGLE3 draft-model tensors (upstream PR #18039) LLM_TENSOR_EAGLE3_TARGET_FEATURES, LLM_TENSOR_EAGLE3_TARGET_TOK_EMBD, + // gemma4-assistant (MTP next-token drafter) tensors + LLM_TENSOR_NEXTN_PROJ_PRE, + LLM_TENSOR_NEXTN_PROJ_POST, + LLM_TENSOR_MASKED_EMBD_CENTROIDS, + LLM_TENSOR_MASKED_EMBD_ORDERING, }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 609e9aa35..de916a2b4 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -106,6 +106,18 @@ llama_context::llama_context( cparams.ctx_type = params.ctx_type; + cparams.ctx_other = nullptr; + + // gemma4-assistant reads the target model's token embeddings and shares its + // KV cache, so it requires a sibling target context to be supplied. + if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) { + if (params.ctx_other == nullptr) { + throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this warning is normal during memory fitting)"); + } + + cparams.ctx_other = params.ctx_other; + } + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -318,10 +330,11 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, - /*.swa_full =*/ params.swa_full, - /*.ctx_type= */ cparams.ctx_type, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, + /*.ctx_type =*/ cparams.ctx_type, + /*.mem_other =*/ cparams.ctx_other ? llama_get_memory(cparams.ctx_other) : nullptr, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -919,7 +932,7 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { } const int64_t j = output_resolve_row(i); - const uint32_t n_embd = model.hparams.n_embd; + const uint32_t n_embd = model.hparams.n_embd_out(); return embd_pre_norm.data + j*n_embd; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); @@ -1459,12 +1472,15 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - // extract pre-norm embeddings (hidden state before the final output norm) + // extract pre-norm embeddings (hidden state before the final output norm). + // the hidden width is n_embd_out (the backbone width): == n_embd for ordinary + // models, but wider than the drafter's n_embd for gemma4-assistant (where this + // tensor is the n_embd_out-wide nextn hidden state). if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd = hparams.n_embd_out(); GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); } @@ -1910,11 +1926,12 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract pre-norm embeddings (hidden state before the final output norm) // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + // width is n_embd_out (backbone width); wider than n_embd for gemma4-assistant. if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd = hparams.n_embd_out(); float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd; GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); @@ -2006,7 +2023,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; @@ -2025,7 +2041,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { logits.size = has_logits ? n_vocab*n_outputs_max : 0; embd.size = has_embd ? n_embd_out*n_outputs_max : 0; - embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd_out*n_outputs_max : 0; // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); @@ -3358,6 +3374,7 @@ llama_context_params llama_context_default_params() { /*.kv_dynamic =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, + /*.ctx_other =*/ nullptr, }; return result; @@ -3497,6 +3514,10 @@ const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } +llama_context * llama_get_ctx_other(llama_context * ctx) { + return ctx->get_cparams().ctx_other; +} + enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { return ctx->pooling_type(); } diff --git a/src/llama-cparams.h b/src/llama-cparams.h index ad1d1d1c8..29af834c9 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -46,6 +46,11 @@ struct llama_cparams { enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; + // sibling context whose model provides shared inputs/KV (gemma4-assistant MTP + // drafter reads the target model's token embeddings + shares its KV cache). + // null for all other arches. + struct llama_context * ctx_other; + ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; }; diff --git a/src/llama-graph.h b/src/llama-graph.h index 9e55d0a67..38c503526 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -646,6 +646,7 @@ class llm_graph_result { ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } + ggml_tensor * get_h_nextn() const { return t_h_nextn; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -675,6 +676,7 @@ class llm_graph_result { ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm + ggml_tensor * t_h_nextn = nullptr; // [n_embd_backbone, n_outputs] next-token hidden state (gemma4-assistant MTP) std::map t_sampled_logits; std::map t_candidates; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 2239309c8..81fb53149 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -71,6 +71,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const { } uint32_t llama_hparams::n_embd_inp() const { + if (n_embd_inp_impl > 0) { + return n_embd_inp_impl; + } + uint32_t n_embd_inp = n_embd; if (n_deepstack_layers > 0) { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 556cd69d1..cb71a466e 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -188,6 +188,10 @@ struct llama_hparams { // output embedding dimension (0 = use n_embd) uint32_t n_embd_out_impl = 0; + // input/backbone embedding dimension (0 = derive from n_embd; gemma4-assistant + // sets this to the target hidden size so the MTP drafter projects to/from it) + uint32_t n_embd_inp_impl = 0; + // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 26e2cb427..bf0917191 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : hparams(model.hparams), unified(unified) { // chain filters const layer_filter_cb filter_base = [&](int32_t il) { @@ -59,17 +61,24 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + llama_memory_t mem_other_base = nullptr; + llama_memory_t mem_other_swa = nullptr; + if (mem_other) { + mem_other_base = static_cast(mem_other)->get_base(); + mem_other_swa = static_cast(mem_other)->get_swa(); + } + kv_base = std::make_unique( model, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); + 0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type, filter_swa, reuse); + hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share); } void llama_kv_cache_iswa::clear(bool data) { diff --git a/src/llama-kv-cache-iswa.h b/src/llama-kv-cache-iswa.h index 70ab22f0d..dfafc1ef5 100644 --- a/src/llama-kv-cache-iswa.h +++ b/src/llama-kv-cache-iswa.h @@ -25,8 +25,10 @@ class llama_kv_cache_iswa : public llama_memory_i { uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache_iswa() = default; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index cd9666a21..2d08ab908 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -89,12 +89,16 @@ llama_kv_cache::llama_kv_cache( uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, const layer_reuse_cb & reuse, + const layer_share_cb & share, uint32_t kv_size_max) : model(model), hparams(model.hparams), v_trans(v_trans), n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + other = static_cast(mem_other); + // save construction parameters for dynamic resize saved_type_k = type_k; saved_type_v = type_v; @@ -194,6 +198,23 @@ llama_kv_cache::llama_kv_cache( continue; } + if (share && other) { + const int32_t il_share = share(il); + + if (il_share >= 0) { + const auto & layer_share = other->layers[other->map_layer_ids[il_share]]; + + LLAMA_LOG_DEBUG("%s: layer %3d: sharing with sibling layer %d\n", __func__, il, il_share); + + map_layer_ids[il] = layers.size(); + + layers.push_back(layer_share); + layers.back().il = il; + + continue; + } + } + if (n_embd_head_k_all == 0) { n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { @@ -1196,7 +1217,8 @@ bool llama_kv_cache::try_resize() { // NOTE: pass kv_size_max=0 so the constructor does NOT apply // the dynamic start logic (which would shrink back to 256) llama_kv_cache tmp(model, saved_type_k, saved_type_v, saved_v_trans, saved_offload, saved_unified, new_size, - saved_n_seq_max, saved_n_pad, saved_n_swa, saved_swa_type, saved_filter, saved_reuse, + saved_n_seq_max, saved_n_pad, saved_n_swa, saved_swa_type, + /*mem_other=*/nullptr, saved_filter, saved_reuse, /*share=*/nullptr, /*kv_size_max=*/0); // copy existing data diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 60b33b429..23486f322 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -105,8 +105,10 @@ class llama_kv_cache : public llama_memory_i { uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, const layer_reuse_cb & reuse, + const layer_share_cb & share, uint32_t kv_size_max = 0); ~llama_kv_cache() = default; @@ -276,6 +278,10 @@ class llama_kv_cache : public llama_memory_i { std::vector layers; + // sibling KV cache to share K/V tensors with (gemma4-assistant MTP drafter + // shares KV with the target). null for all other arches. + llama_kv_cache * other = nullptr; + // dynamic resize state uint32_t kv_size_cur = 0; uint32_t kv_size_max_val = 0; diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index a59561ea5..99c2bb4c7 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -43,10 +43,12 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_ubatch, n_pad, + nullptr, // mem_other filter_attn == nullptr ? [&](int32_t il) { return !hparams.is_recurrent(il); } : filter_attn, - nullptr + nullptr, // reuse + nullptr // share )), mem_recr(new llama_memory_recurrent( model, diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index e21e42d09..03d23aa26 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -45,10 +45,12 @@ llama_memory_hybrid::llama_memory_hybrid( n_pad, n_swa, swa_type, + nullptr, // mem_other filter_attn == nullptr ? [&](int32_t il) { return !hparams.is_recurrent(il); } : filter_attn, - nullptr, + nullptr, // reuse + nullptr, // share kv_size_max )), mem_recr(new llama_memory_recurrent( diff --git a/src/llama-memory.h b/src/llama-memory.h index 4ad1612e4..527f02de8 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -23,6 +23,10 @@ struct llama_memory_params { bool swa_full; llama_context_type ctx_type; + + // sibling memory whose KV cache is shared per-layer (gemma4-assistant MTP + // drafter shares KV with the target). null for all other arches. + llama_memory_t mem_other; }; enum llama_memory_status { @@ -76,6 +80,10 @@ struct llama_memory_i { // return negative value to indicate that the layer il should not reuse memory using layer_reuse_cb = std::function; + // this callback is used to specify which layer of the sibling (mem_other) cache + // a given layer should share its KV with. return negative to indicate no sharing. + using layer_share_cb = std::function; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 864774635..93ed4d7ca 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -136,6 +136,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_gemma3n(params); case LLM_ARCH_GEMMA4: return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA4_ASSISTANT: + return new llama_model_gemma4_assistant(params); case LLM_ARCH_GEMMA_EMBEDDING: return new llama_model_gemma_embedding(params); case LLM_ARCH_STARCODER2: @@ -2046,6 +2048,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } else { llama_memory_i::layer_reuse_cb reuse = nullptr; llama_kv_cache::layer_filter_cb filter = nullptr; + llama_kv_cache::layer_share_cb share = nullptr; + llama_memory_t mem_other = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2062,6 +2066,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; } + if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { + // The MTP drafter carries no K/V weights; every drafter layer + // shares the target model's last (dense) or second-to-last (SWA) + // KV cache layer via the sibling ctx_other memory. + mem_other = llama_get_memory(cparams.ctx_other); + + share = [&](int32_t il) { + const llama_model * model_other = llama_get_model(cparams.ctx_other); + + if (hparams.is_swa(il)) { + return (int32_t) llama_model_n_layer(model_other) - 2; + } + + return (int32_t) llama_model_n_layer(model_other) - 1; + }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2077,8 +2098,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, + mem_other, filter, - reuse); + reuse, + share); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2094,8 +2117,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, + mem_other, filter, nullptr, + share, cparams.kv_dynamic ? cparams.n_ctx_seq : 0); } } @@ -2334,6 +2359,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: case LLM_ARCH_GEMMA4: + case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: diff --git a/src/llama-model.h b/src/llama-model.h index de0b1bf85..c7c3b3f2d 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -549,6 +549,10 @@ struct llama_model { struct ggml_tensor * dflash_fc = nullptr; struct ggml_tensor * dflash_hidden_norm = nullptr; + // gemma4-assistant (MTP next-token drafter) projections. + struct ggml_tensor * nextn_proj_pre = nullptr; + struct ggml_tensor * nextn_proj_post = nullptr; + struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; diff --git a/src/models/gemma4-assistant.cpp b/src/models/gemma4-assistant.cpp new file mode 100644 index 000000000..fc8280ed5 --- /dev/null +++ b/src/models/gemma4-assistant.cpp @@ -0,0 +1,215 @@ +#include "models.h" + +void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) { + hparams.n_embd_inp_impl = hparams.n_embd_out(); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + + type = LLM_TYPE_E2B; +} + +void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + if (hparams.n_embd_out() == n_embd) { + throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED); + + const int64_t n_embd_backbone = hparams.n_embd_inp(); + nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0); + + int rope_freqs_flag = 0; + + const int n_layer_nextn = (int) hparams.n_layer; + + for (int i = 0; i < n_layer_nextn; ++i) { + auto & layer = layers[i]; + + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_ff = hparams.n_ff(i); + + if (i == 0) { + nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0); + } + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0); + + if (!hparams.is_swa(i)) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0); + } +} + +std::unique_ptr llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_backbone = hparams.n_embd_inp(); + + const int n_layer_nextn = (int) hparams.n_layer; + + ggml_tensor * inp_tokens; + ggml_tensor * inp_h; + { + auto inp = std::make_unique(n_embd_backbone); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + inp_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens); + cb(inp->embd, "inp_h", -1); + ggml_set_input(inp->embd); + inp_h = inp->embd; + res->t_inp_embd = inp->embd; + + res->add_input(std::move(inp)); + } + + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens); + x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); + cb(x, "inp_embd_target", -1); + + ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0); + cb(xh, "inp_xh", -1); + + ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh); + cb(cur, "pre_proj", -1); + + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer_nextn; ++il) { + const bool is_swa = hparams.is_swa(il); + + const int64_t n_embd_head = hparams.n_embd_head_k(il); + const int64_t n_head = hparams.n_head(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur_norm, "attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs; + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + + if (il == n_layer_nextn - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + ggml_tensor * logits = build_lora_mm(model.output, cur); + cb(logits, "result_output", -1); + res->t_logits = logits; + + ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur); + cb(h_next, "h_nextn", -1); + res->t_h_nextn = h_next; + ggml_set_output(res->t_h_nextn); + + // Route the next-token hidden state through the existing pre-norm extraction + // seam so the MTP runtime (common/speculative.cpp) can chain it back as the + // drafter's input hidden for the next speculative step. The pre-norm/nextn + // hidden width is n_embd_out (the backbone width), not the drafter's n_embd. + res->t_h_pre_norm = h_next; + + ggml_build_forward_expand(gf, logits); + ggml_build_forward_expand(gf, h_next); +} diff --git a/src/models/gemma4.cpp b/src/models/gemma4.cpp index 4f9d8b18b..dc309cf11 100644 --- a/src/models/gemma4.cpp +++ b/src/models/gemma4.cpp @@ -376,6 +376,13 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para model.output_norm, nullptr, LLM_NORM_RMS, -1); + // Expose the post-output-norm hidden state (the LM-head input feature) through + // the pre-norm extraction seam so gemma4-assistant MTP draft contexts can read + // it as the recurrent h input. This matches the reference (transformers/vLLM/ + // SGLang), which feeds the drafter the target's post-final-norm hidden state. + cb(cur, "h_nextn", -1); + res->t_h_pre_norm = cur; + cb(cur, "result_norm", -1); res->t_embd = cur; diff --git a/src/models/models.h b/src/models/models.h index bb6c23a75..955caf427 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -813,6 +813,19 @@ struct llama_model_gemma4 : public llama_model_base { }; +struct llama_model_gemma4_assistant : public llama_model_base { + llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_gemma_embedding : public llama_model_base { llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; diff --git a/tools/omnivoice/src/eliza-inference-ffi.cpp b/tools/omnivoice/src/eliza-inference-ffi.cpp index 2821dcd22..a0390cacc 100644 --- a/tools/omnivoice/src/eliza-inference-ffi.cpp +++ b/tools/omnivoice/src/eliza-inference-ffi.cpp @@ -1407,6 +1407,11 @@ static Engine * create_engine( cp.n_rs_seq = 0; // draft ctx rolls back via PART/checkpoint, not RS cp.n_threads = cparams_tgt.n_threads; cp.n_threads_batch = cparams_tgt.n_threads_batch; + // Separate-drafter MTP archs that reach into the target context for its + // token embeddings + hidden state (e.g. gemma4-assistant) require + // `ctx_other` = the target context; the llama-context ctor hard-fails + // without it. Inert for archs that don't consult ctx_other (same-file MTP). + cp.ctx_other = e->ctx_tgt; e->ctx_dft = llama_init_from_model(model_for_dft, cp); if (!e->ctx_dft) { diff --git a/tools/omnivoice/src/voice-classifiers/voice_classifier/voice_diarizer.c b/tools/omnivoice/src/voice-classifiers/voice_classifier/voice_diarizer.c index a10c729c0..c266e5de0 100644 --- a/tools/omnivoice/src/voice-classifiers/voice_classifier/voice_diarizer.c +++ b/tools/omnivoice/src/voice-classifiers/voice_classifier/voice_diarizer.c @@ -196,9 +196,11 @@ static inline float sigmoidf(float x) { } } -/* One-direction LSTM step. Gates packed in I, F, G, O order - * (matches the converter's reorder). `x_dot_W` is the - * pre-computed x @ W_ih^T + b_ih, shape [T, 4H]. */ +/* One-direction LSTM step. The GGUF packs the gates in ONNX LSTM order + * I, O, F, C(=G) — NOT PyTorch's I, F, G, O. Reading them as IFGO scrambled the + * forget/output/cell gates and made the diarizer over-detect overlap on inputs + * near the decision boundary (it passed the small parity-fixture suite by luck); + * see #9460. `x_dot_W` is the pre-computed x @ W_ih^T + b_ih, shape [T, 4H]. */ static void lstm_run_dir(const float *x_dot_W, int T, int H, const float *W_hh, /* [4H, H] */ @@ -223,11 +225,11 @@ static void lstm_run_dir(const float *x_dot_W, gate_buf[g] = acc; } - /* Apply nonlinearities. Gate order I, F, G, O. */ + /* Apply nonlinearities. ONNX LSTM gate order I, O, F, C(=G). */ const float *gi = gate_buf + 0 * H; - const float *gf = gate_buf + 1 * H; - const float *gg = gate_buf + 2 * H; - const float *go = gate_buf + 3 * H; + const float *go = gate_buf + 1 * H; + const float *gf = gate_buf + 2 * H; + const float *gg = gate_buf + 3 * H; for (int j = 0; j < H; ++j) { const float i_t = sigmoidf(gi[j]); const float f_t = sigmoidf(gf[j]); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index cf875391f..c8fc91c43 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -866,23 +866,34 @@ struct server_context_impl { return false; } + const bool spec_mtp = std::find(params_base.speculative.types.begin(), + params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); + // Upstream PR #22660: for SWA draft models, force swa_full on the // draft context so prefix reuse (seq_rm + seq_add) works beyond // the SWA window during speculation. Without this, the draft has // to re-decode from the window edge on every long-context request // and acceptance length degrades sharply. - if (llama_model_n_swa(model_dft.get()) > 0 && !params_dft.swa_full) { + // + // Exception: an MTP drafter that shares the target's KV cache (e.g. + // gemma4-assistant via ctx_other) must size its SWA cache to match + // the target's exactly. Forcing swa_full here would make the drafter + // expect a full-size SWA cache while the shared target tensor is + // small-SWA-sized, overflowing the view in get_k/get_v. + if (!spec_mtp && llama_model_n_swa(model_dft.get()) > 0 && !params_dft.swa_full) { SRV_INF("%s", "draft model uses SWA - enabling swa_full for the draft context\n"); params_dft.swa_full = true; } auto cparams = common_context_params_to_llama(params_dft); - const bool spec_mtp = std::find(params_base.speculative.types.begin(), - params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); if (spec_mtp) { cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + // gemma4-assistant-style MTP drafters read the target model's + // token embeddings and share its KV cache via ctx_other. Setting + // it for other draft arches is harmless (only that arch reads it). + cparams.ctx_other = ctx_tgt; } // note: for small models maybe we can set this to the maximum possible draft from all speculative types @@ -900,8 +911,9 @@ struct server_context_impl { params_base.model.path.c_str()); auto cparams_mtp = common_context_params_to_llama(params_base); - cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; - cparams_mtp.n_rs_seq = 0; + cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + cparams_mtp.n_rs_seq = 0; + cparams_mtp.ctx_other = ctx_tgt; ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp)); if (ctx_dft == nullptr) {