Skip to content
9 changes: 8 additions & 1 deletion common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,14 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");

n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
// The MTP hidden-state I/O width is the *backbone* width, which can differ
// from the drafter's internal embedding width (gemma4-assistant: internal
// n_embd=256, but it reads/writes target hidden states of width
// n_embd_out=1536). Size every embd row against the output width and assert
// it matches the target's hidden width that we feed in.
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target hidden width");

const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
Expand Down
9 changes: 8 additions & 1 deletion include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,12 @@ extern "C" {
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;

// sibling context whose model supplies shared inputs / KV cache.
// required for arches that read from another model (gemma4-assistant MTP
// drafter reads the target model's token embeddings and shares its KV).
// null for all other arches. (the caller keeps this context alive.)
struct llama_context * ctx_other;
};

struct llama_model_tensor_override {
Expand Down Expand Up @@ -551,7 +557,8 @@ extern "C" {

DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");

LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API struct llama_context * llama_get_ctx_other(struct llama_context * ctx);
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type

Expand Down
9 changes: 9 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_GEMMA4, "gemma4" },
{ LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" },
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
Expand Down Expand Up @@ -461,6 +462,10 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
{ LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" },
{ LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" },
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
{ LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
{ LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
{ LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" },
Expand Down Expand Up @@ -582,6 +587,10 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer)
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
{LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
Expand Down
6 changes: 6 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_GEMMA4,
LLM_ARCH_GEMMA4_ASSISTANT,
LLM_ARCH_GEMMA_EMBEDDING,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
Expand Down Expand Up @@ -569,6 +570,11 @@ enum llm_tensor {
// EAGLE3 draft-model tensors (upstream PR #18039)
LLM_TENSOR_EAGLE3_TARGET_FEATURES,
LLM_TENSOR_EAGLE3_TARGET_TOK_EMBD,
// gemma4-assistant (MTP next-token drafter) tensors
LLM_TENSOR_NEXTN_PROJ_PRE,
LLM_TENSOR_NEXTN_PROJ_POST,
LLM_TENSOR_MASKED_EMBD_CENTROIDS,
LLM_TENSOR_MASKED_EMBD_ORDERING,
};

enum llm_tensor_layer {
Expand Down
41 changes: 31 additions & 10 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ llama_context::llama_context(

cparams.ctx_type = params.ctx_type;

cparams.ctx_other = nullptr;

// gemma4-assistant reads the target model's token embeddings and shares its
// KV cache, so it requires a sibling target context to be supplied.
if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) {
if (params.ctx_other == nullptr) {
throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this warning is normal during memory fitting)");
}

cparams.ctx_other = params.ctx_other;
}

// Initialize backend samplers here so they are part of the sampling graph
// before the reserve passes run later in this function. This avoids a later
// re-reserve when graph nodes change.
Expand Down Expand Up @@ -318,10 +330,11 @@ llama_context::llama_context(
// init the memory module
if (!hparams.vocab_only) {
llama_memory_params params_mem = {
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type= */ cparams.ctx_type,
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type =*/ cparams.ctx_type,
/*.mem_other =*/ cparams.ctx_other ? llama_get_memory(cparams.ctx_other) : nullptr,
};

memory.reset(model.create_memory(params_mem, cparams));
Expand Down Expand Up @@ -919,7 +932,7 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
}

const int64_t j = output_resolve_row(i);
const uint32_t n_embd = model.hparams.n_embd;
const uint32_t n_embd = model.hparams.n_embd_out();
return embd_pre_norm.data + j*n_embd;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
Expand Down Expand Up @@ -1459,12 +1472,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}

// extract pre-norm embeddings (hidden state before the final output norm)
// extract pre-norm embeddings (hidden state before the final output norm).
// the hidden width is n_embd_out (the backbone width): == n_embd for ordinary
// models, but wider than the drafter's n_embd for gemma4-assistant (where this
// tensor is the n_embd_out-wide nextn hidden state).
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);

const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
}
Expand Down Expand Up @@ -1910,11 +1926,12 @@ int llama_context::decode(const llama_batch & batch_inp) {

// extract pre-norm embeddings (hidden state before the final output norm)
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
// width is n_embd_out (backbone width); wider than n_embd for gemma4-assistant.
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);

const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;

GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
Expand Down Expand Up @@ -2006,7 +2023,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {

const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();

bool has_logits = true;
Expand All @@ -2025,7 +2041,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {

logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd_out*n_outputs_max : 0;

// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
Expand Down Expand Up @@ -3358,6 +3374,7 @@ llama_context_params llama_context_default_params() {
/*.kv_dynamic =*/ false,
/*.sampler =*/ nullptr,
/*.n_sampler =*/ 0,
/*.ctx_other =*/ nullptr,
};

return result;
Expand Down Expand Up @@ -3497,6 +3514,10 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model();
}

// Returns the sibling ("other") context stored in this context's cparams.
// The pointer is handed back exactly as stored; it may be null.
llama_context * llama_get_ctx_other(llama_context * ctx) {
    const auto & cparams = ctx->get_cparams();
    return cparams.ctx_other;
}

// Returns the pooling type reported by the given context.
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
    // the elaborated `enum` specifier is required here: inside this scope the
    // plain name refers to this function, not to the enum type
    const enum llama_pooling_type result = ctx->pooling_type();
    return result;
}
Expand Down
5 changes: 5 additions & 0 deletions src/llama-cparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ struct llama_cparams {
enum llama_context_type ctx_type;
enum llama_pooling_type pooling_type;

// sibling context whose model provides shared inputs/KV (gemma4-assistant MTP
// drafter reads the target model's token embeddings + shares its KV cache).
// null for all other arches.
struct llama_context * ctx_other;

ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
};
2 changes: 2 additions & 0 deletions src/llama-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ class llm_graph_result {
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; }
ggml_tensor * get_h_nextn() const { return t_h_nextn; }

ggml_cgraph * get_gf() const { return gf; }
ggml_context * get_ctx() const { return ctx_compute.get(); }
Expand Down Expand Up @@ -675,6 +676,7 @@ class llm_graph_result {
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
ggml_tensor * t_h_pre_norm = nullptr; // [n_embd_out, n_outputs] hidden state before final output norm (n_embd_out == n_embd for ordinary models)
ggml_tensor * t_h_nextn = nullptr; // [n_embd_backbone, n_outputs] next-token hidden state (gemma4-assistant MTP)

std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
std::map<llama_seq_id, ggml_tensor*> t_candidates;
Expand Down
4 changes: 4 additions & 0 deletions src/llama-hparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const {
}

uint32_t llama_hparams::n_embd_inp() const {
if (n_embd_inp_impl > 0) {
return n_embd_inp_impl;
}

uint32_t n_embd_inp = n_embd;

if (n_deepstack_layers > 0) {
Expand Down
4 changes: 4 additions & 0 deletions src/llama-hparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ struct llama_hparams {
// output embedding dimension (0 = use n_embd)
uint32_t n_embd_out_impl = 0;

// input/backbone embedding dimension (0 = derive from n_embd; gemma4-assistant
// sets this to the target hidden size so the MTP drafter projects to/from it)
uint32_t n_embd_inp_impl = 0;

// llama4 smallthinker
uint32_t n_moe_layer_step = 0;
uint32_t n_no_rope_layer_step = 4;
Expand Down
15 changes: 12 additions & 3 deletions src/llama-kv-cache-iswa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
const layer_reuse_cb & reuse,
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {

// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
Expand Down Expand Up @@ -59,17 +61,24 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(

LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

llama_memory_t mem_other_base = nullptr;
llama_memory_t mem_other_swa = nullptr;
if (mem_other) {
mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base();
mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa();
}

kv_base = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share);

LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

kv_swa = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share);
}

void llama_kv_cache_iswa::clear(bool data) {
Expand Down
4 changes: 3 additions & 1 deletion src/llama-kv-cache-iswa.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ class llama_kv_cache_iswa : public llama_memory_i {
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);

~llama_kv_cache_iswa() = default;

Expand Down
24 changes: 23 additions & 1 deletion src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,16 @@ llama_kv_cache::llama_kv_cache(
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse,
const layer_share_cb & share,
uint32_t kv_size_max) :
model(model), hparams(model.hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

other = static_cast<llama_kv_cache *>(mem_other);

// save construction parameters for dynamic resize
saved_type_k = type_k;
saved_type_v = type_v;
Expand Down Expand Up @@ -194,6 +198,23 @@ llama_kv_cache::llama_kv_cache(
continue;
}

if (share && other) {
const int32_t il_share = share(il);

if (il_share >= 0) {
const auto & layer_share = other->layers[other->map_layer_ids[il_share]];

LLAMA_LOG_DEBUG("%s: layer %3d: sharing with sibling layer %d\n", __func__, il, il_share);

map_layer_ids[il] = layers.size();

layers.push_back(layer_share);
layers.back().il = il;

continue;
}
}

if (n_embd_head_k_all == 0) {
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
Expand Down Expand Up @@ -1196,7 +1217,8 @@ bool llama_kv_cache::try_resize() {
// NOTE: pass kv_size_max=0 so the constructor does NOT apply
// the dynamic start logic (which would shrink back to 256)
llama_kv_cache tmp(model, saved_type_k, saved_type_v, saved_v_trans, saved_offload, saved_unified, new_size,
saved_n_seq_max, saved_n_pad, saved_n_swa, saved_swa_type, saved_filter, saved_reuse,
saved_n_seq_max, saved_n_pad, saved_n_swa, saved_swa_type,
/*mem_other=*/nullptr, saved_filter, saved_reuse, /*share=*/nullptr,
/*kv_size_max=*/0);

// copy existing data
Expand Down
6 changes: 6 additions & 0 deletions src/llama-kv-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ class llama_kv_cache : public llama_memory_i {
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse,
const layer_share_cb & share,
uint32_t kv_size_max = 0);

~llama_kv_cache() = default;
Expand Down Expand Up @@ -276,6 +278,10 @@ class llama_kv_cache : public llama_memory_i {

std::vector<kv_layer> layers;

// sibling KV cache to share K/V tensors with (gemma4-assistant MTP drafter
// shares KV with the target). null for all other arches.
llama_kv_cache * other = nullptr;

// dynamic resize state
uint32_t kv_size_cur = 0;
uint32_t kv_size_max_val = 0;
Expand Down
4 changes: 3 additions & 1 deletion src/llama-memory-hybrid-iswa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
n_seq_max,
n_ubatch,
n_pad,
nullptr, // mem_other
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
nullptr, // reuse
nullptr // share
)),
mem_recr(new llama_memory_recurrent(
model,
Expand Down
Loading
Loading