Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build-cross.yml
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ jobs:
apt-get install -y --no-install-recommends \
build-essential \
glslc \
spirv-headers \
gcc-14-loongarch64-linux-gnu \
g++-14-loongarch64-linux-gnu \
libvulkan-dev:loong64
Expand Down
2 changes: 1 addition & 1 deletion common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2334,7 +2334,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
? input
: params.generation_prompt + input;

LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
//LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());

common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
if (params.debug) {
Expand Down
4 changes: 2 additions & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <sstream>
#include <string>
#include <string_view>
#include <variant>
#include <vector>
#include <map>

Expand Down Expand Up @@ -303,7 +302,7 @@ struct common_params_speculative {
// general-purpose speculative decoding parameters

int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)

Expand All @@ -312,6 +311,7 @@ struct common_params_speculative {
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models

std::shared_ptr<common_ngram_mod> ngram_mod;

Expand Down
8 changes: 4 additions & 4 deletions common/ngram-map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ void common_ngram_map_begin(
count_keys, count_keys_del, count_values_del, count_map_entries_upd);
}

map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
map.idx_last_check = size_begin;
map.size_last_begin = size_begin;
}

Expand All @@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
}

if (map.idx_last_check > cur_len) {
if (map.idx_last_check > cur_len) {
// Should not happen because of common_ngram_map_begin().
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
}
Expand Down Expand Up @@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());

map.last_draft_created = false;
map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;
Expand Down Expand Up @@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.

// update the value statistics
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
n_accepted, curr_value.n_accepted);
curr_value.n_accepted = n_accepted;
}
162 changes: 139 additions & 23 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <cstring>
#include <iomanip>
#include <map>
#include <cinttypes>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
Expand Down Expand Up @@ -144,10 +145,28 @@ struct common_speculative_state {
virtual void accept(uint16_t n_accepted) = 0;
};

struct common_speculative_checkpoint {
llama_pos pos_min = 0;
llama_pos pos_max = 0;

int64_t n_tokens = 0;

std::vector<uint8_t> data;

size_t size() const {
return data.size();
}

size_t ckpt_size = 0;
};

struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;

struct common_speculative_checkpoint ckpt;
bool use_checkpoint;

common_sampler * smpl;

llama_batch batch;
Expand All @@ -160,10 +179,12 @@ struct common_speculative_state_draft : public common_speculative_state {
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_dft,
const std::vector<std::pair<std::string, std::string>> & replacements)
const std::vector<std::pair<std::string, std::string>> & replacements,
bool use_checkpoint)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft)
, use_checkpoint(use_checkpoint)
{
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
smpl = nullptr;
Expand Down Expand Up @@ -218,7 +239,48 @@ struct common_speculative_state_draft : public common_speculative_state {
}

void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
if (use_checkpoint && ckpt.size() > 0) {
// delete checkpoint
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
ckpt.pos_min = 0;
ckpt.pos_max = 0;
ckpt.n_tokens = 0;
ckpt.ckpt_size = 0;
ckpt.data.clear();
}
}

size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
int slot_id = 0;
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
ckpt.data.resize(checkpoint_size);

const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}

LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
return n;
}

size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
int slot_id = 0;
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_part_expected) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
}
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);

return n;
}

void draft(
Expand All @@ -236,8 +298,8 @@ struct common_speculative_state_draft : public common_speculative_state {

auto * mem_dft = llama_get_memory(ctx_dft);

int reuse_i = 0;
int reuse_n = 0;
int reuse_i = 0; // index of part to be reused in prompt_dft
int reuse_n = 0; // length of part to be reused in prompt_dft

const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;

Expand Down Expand Up @@ -287,18 +349,26 @@ struct common_speculative_state_draft : public common_speculative_state {
}
}

LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
__func__, reuse_i, reuse_n);
reuse_i = 0;
reuse_n = 0;
}

result.clear();
result.reserve(params.n_max);

if (reuse_n == 0) {
bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
llama_memory_clear(mem_dft, false);
prompt_dft.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
result.push_back(prompt_dft[i]);

Expand All @@ -310,19 +380,50 @@ struct common_speculative_state_draft : public common_speculative_state {
return;
}

bool do_restore = false;
if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
// This can happen after a partial acceptance (speculative decoding with checkpoints)
LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
__func__, prompt_dft.size(), prompt_cur.size());
prompt_dft.resize(prompt_cur.size());
do_restore = true;
}

if (reuse_i > 0) {
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
}
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);

prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
}

if (reuse_n < (int) prompt_dft.size()) {
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
if (reuse_n < (int) prompt_dft.size() || do_restore) {
if (use_checkpoint) {
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
}
draft_restore_checkpoint(ckpt.ckpt_size);
reuse_n = ckpt.n_tokens;
prompt_dft.resize(reuse_n);
needs_ckpt = false;
} else {
bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
__func__, reuse_n, prompt_dft.size());
}
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
}
}
}

if (needs_ckpt) {
ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
}

// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);

Expand All @@ -337,7 +438,11 @@ struct common_speculative_state_draft : public common_speculative_state {
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

llama_decode(ctx_dft, batch);
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
__func__, ret, prompt_cur.size());
}
}

const llama_pos n_past = prompt_dft.size();
Expand All @@ -351,7 +456,11 @@ struct common_speculative_state_draft : public common_speculative_state {

LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());

llama_decode(ctx_dft, batch);
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
__func__, ret, prompt_cur.size(), prompt_dft.size());
}

common_sampler_reset(smpl);

Expand Down Expand Up @@ -387,7 +496,11 @@ struct common_speculative_state_draft : public common_speculative_state {
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch);
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
__func__, i, ret, prompt_cur.size(), prompt_dft.size());
}

prompt_dft.push_back(id);
}
Expand Down Expand Up @@ -739,6 +852,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {

struct common_speculative {
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states

common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
};

Expand Down Expand Up @@ -798,13 +912,13 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
return it->second;
}

bool common_speculative_is_compat(llama_context * ctx_tgt) {
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
auto * mem = llama_get_memory(ctx_tgt);
if (mem == nullptr) {
return false;
return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
}

bool res = true;
common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;

llama_memory_clear(mem, true);

Expand All @@ -816,14 +930,14 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
if (ret != 0) {
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
res = false;
res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
goto done;
}

// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
goto done;
}

Expand Down Expand Up @@ -909,9 +1023,10 @@ common_speculative * common_speculative_init(
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements,
/* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
));
break;
}
Expand Down Expand Up @@ -966,7 +1081,8 @@ common_speculative * common_speculative_init(
}

auto * result = new common_speculative {
/* .impls = */ std::move(impls)
/* .impls = */ std::move(impls),
/* .curr_impl = */ nullptr,
};

return result;
Expand Down
Loading
Loading