10 changes: 5 additions & 5 deletions common/arg.cpp
@@ -292,7 +292,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
         hf_tag = "default";
     }
 
-    std::string model_endpoint = get_model_endpoint();
+    std::string model_endpoint = common_get_model_endpoint();
     auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
 
     // prepare local path for caching
@@ -1316,13 +1316,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
     add_opt(common_arg(
-        {"--clear-idle"},
-        {"--no-clear-idle"},
+        {"--cache-idle-slots"},
+        {"--no-cache-idle-slots"},
         "save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)",
         [](common_params & params, bool value) {
-            params.clear_idle = value;
+            params.cache_idle_slots = value;
         }
-    ).set_env("LLAMA_ARG_CLEAR_IDLE").set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_env("LLAMA_ARG_CACHE_IDLE_SLOTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--context-shift"},
         {"--no-context-shift"},
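For reference, a hypothetical server invocation using the renamed flag and environment variable (the model path is a placeholder, and treating "0" as falsy is assumed to follow the usual common_arg boolean env parsing):

    llama-server -m model.gguf --no-cache-idle-slots
    LLAMA_ARG_CACHE_IDLE_SLOTS=0 llama-server -m model.gguf

Both forms would disable the save-and-clear behavior for idle slots; the option otherwise defaults to enabled and requires unified KV and cache-ram.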
38 changes: 37 additions & 1 deletion common/common.cpp
@@ -1382,7 +1382,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
 
 common_init_result::~common_init_result() = default;
 
-std::string get_model_endpoint() {
+std::string common_get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
     const char * hf_endpoint_env = getenv("HF_ENDPOINT");
@@ -1397,6 +1397,42 @@ std::string get_model_endpoint() {
     return model_endpoint;
 }
 
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
+    auto * mem = llama_get_memory(ctx);
+    if (mem == nullptr) {
+        return COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+    }
+
+    common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART;
+
+    llama_memory_clear(mem, true);
+
+    // eval 2 tokens to check if the context is compatible
+    std::vector<llama_token> tmp;
+    tmp.push_back(0);
+    tmp.push_back(0);
+
+    int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
+    if (ret != 0) {
+        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+        goto done;
+    }
+
+    // try to remove the last tokens
+    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
+        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+        goto done;
+    }
+
+done:
+    llama_memory_clear(mem, true);
+    llama_synchronize(ctx);
+
+    return res;
+}
+
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
     std::vector<llama_adapter_lora *> loras;
     std::vector<float> scales;
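To show how a caller might act on the probe result, a minimal sketch (assuming ctx is a freshly created llama_context; the branch bodies are illustrative, not from this PR):

    // probe once at startup, before any real decoding:
    // common_context_can_seq_rm() clears the context memory as a side effect
    switch (common_context_can_seq_rm(ctx)) {
        case COMMON_CONTEXT_SEQ_RM_TYPE_PART:
            break; // partial removal works: a rejected tail of tokens can be dropped in place
        case COMMON_CONTEXT_SEQ_RM_TYPE_FULL:
            break; // only whole sequences can be removed: rewind via checkpoints instead
        case COMMON_CONTEXT_SEQ_RM_TYPE_NO:
            break; // no usable memory module: avoid features that need seq_rm
    }

This mirrors how common_speculative_init() below uses the FULL result to decide whether the draft path needs checkpoints.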
27 changes: 21 additions & 6 deletions common/common.h
@@ -308,10 +308,9 @@ struct common_params_speculative {
 
     // ngram-based speculative decoding
 
-    uint16_t ngram_size_n    = 12;    // ngram size for lookup
-    uint16_t ngram_size_m    = 48;    // mgram size for speculative tokens
-    uint16_t ngram_min_hits  = 1;     // minimum hits at ngram/mgram lookup for mgram to be proposed
-    bool     use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
+    uint16_t ngram_size_n   = 12; // ngram size for lookup
+    uint16_t ngram_size_m   = 48; // mgram size for speculative tokens
+    uint16_t ngram_min_hits = 1;  // minimum hits at ngram/mgram lookup for mgram to be proposed
 
     std::shared_ptr<common_ngram_mod> ngram_mod;
 
@@ -567,7 +566,7 @@ struct common_params {
     int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
     bool cache_prompt = true; // whether to enable prompt caching
-    bool clear_idle = true; // save and clear idle slots upon starting a new task
+    bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
     int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
     int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill
     int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
@@ -847,7 +846,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
-std::string get_model_endpoint();
+// model endpoint from env
+std::string common_get_model_endpoint();
+
+//
+// Context utils
+//
+
+enum common_context_seq_rm_type {
+    COMMON_CONTEXT_SEQ_RM_TYPE_NO   = 0, // seq_rm not supported (e.g. no memory module)
+    COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
+    COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
+};
+
+// check if the llama_context can remove sequences
+// note: clears the memory of the context
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
+
 
 //
 // Batch utils
4 changes: 2 additions & 2 deletions common/hf-cache.cpp
@@ -230,7 +230,7 @@ static nl::json api_get(const std::string & url,
 static std::string get_repo_commit(const std::string & repo_id,
                                    const std::string & token) {
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
 
         if (!json.is_object() ||
@@ -308,7 +308,7 @@ hf_files get_repo_files(const std::string & repo_id,
     hf_files files;
 
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
 
         if (!json.is_array()) {
62 changes: 14 additions & 48 deletions common/speculative.cpp
@@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state {
     llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
     llama_context * ctx_dft;
 
+    bool use_ckpt = false;
     struct common_speculative_checkpoint ckpt;
-    bool use_checkpoint;
 
     common_sampler * smpl;
 
@@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state {
             llama_context * ctx_tgt,
             llama_context * ctx_dft,
             const std::vector<std::pair<std::string, std::string>> & replacements,
-            bool use_checkpoint)
+            bool use_ckpt)
         : common_speculative_state(type)
         , ctx_tgt(ctx_tgt)
         , ctx_dft(ctx_dft)
-        , use_checkpoint(use_checkpoint)
+        , use_ckpt(use_ckpt)
     {
         batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
         smpl = nullptr;
@@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state {
     }
 
     void begin(const llama_tokens & prompt) override {
-        if (use_checkpoint && ckpt.size() > 0) {
+        if (use_ckpt && ckpt.size() > 0) {
             // delete checkpoint
             LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
                     __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
@@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state {
 
         LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
                 __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
-        if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
+        if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
             LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
                     __func__, reuse_i, reuse_n);
             reuse_i = 0;
@@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state {
         result.clear();
         result.reserve(params.n_max);
 
-        bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
-        if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
+        bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
+        if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
             llama_memory_clear(mem_dft, false);
             prompt_dft.clear();
         } else {
@@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state {
         }
 
         if (reuse_n < (int) prompt_dft.size() || do_restore) {
-            if (use_checkpoint) {
+            if (use_ckpt) {
                 if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
                     LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
                             __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
@@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
     return it->second;
 }
 
-common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
-    auto * mem = llama_get_memory(ctx_tgt);
-    if (mem == nullptr) {
-        return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
-    }
-
-    common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
-
-    llama_memory_clear(mem, true);
-
-    // eval 2 tokens to check if the context is compatible
-    std::vector<llama_token> tmp;
-    tmp.push_back(0);
-    tmp.push_back(0);
-
-    int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
-    if (ret != 0) {
-        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
-        res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
-        goto done;
-    }
-
-    // try to remove the last tokens
-    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
-        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
-        res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
-        goto done;
-    }
-
-done:
-    llama_memory_clear(mem, true);
-    llama_synchronize(ctx_tgt);
-
-    return res;
-}
-
 // initialization of the speculative decoding system
 //
 common_speculative * common_speculative_init(
@@ -1022,11 +986,13 @@ common_speculative * common_speculative_init(
         case COMMON_SPECULATIVE_TYPE_NONE:
             break;
         case COMMON_SPECULATIVE_TYPE_DRAFT: {
+            const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+
             impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
-                /* .ctx_tgt = */ ctx_tgt,
-                /* .ctx_dft = */ ctx_dft,
-                /* .replacements = */ params.replacements,
-                /* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
+                /* .ctx_tgt      = */ ctx_tgt,
+                /* .ctx_dft      = */ ctx_dft,
+                /* .replacements = */ params.replacements,
+                /* .use_ckpt     = */ use_ckpt
             ));
             break;
         }
10 changes: 0 additions & 10 deletions common/speculative.h
@@ -14,16 +14,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
 // convert type to string
 std::string common_speculative_type_to_str(enum common_speculative_type type);
 
-enum common_speculative_compat_type {
-    COMMON_SPECULATIVE_COMPAT_TYPE_NO   = 0,
-    COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1,
-    COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2,
-};
-
-// check if the llama_context is compatible for speculative decoding
-// note: clears the memory of the context
-common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt);
-
 common_speculative * common_speculative_init(
         common_params_speculative & params,
         llama_context * ctx_tgt);
24 changes: 12 additions & 12 deletions ggml/src/ggml-sycl/mmvq.cpp
@@ -537,9 +537,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
 static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
                                                const int nrows, dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK4_0 == 0);
-    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
-    constexpr size_t num_subgroups = 16;
-    GGML_ASSERT(block_num_y % num_subgroups == 0);
+    // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
+    constexpr size_t num_subgroups = WARP_SIZE;
+    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
 
     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
@@ -682,9 +682,9 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
 static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
                                                const int nrows, dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK8_0 == 0);
-    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
-    constexpr size_t num_subgroups = 16;
-    GGML_ASSERT(block_num_y % num_subgroups == 0);
+    // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
+    constexpr size_t num_subgroups = WARP_SIZE;
+    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
 
     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
@@ -798,9 +798,9 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
                                                const int nrows, dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
 
-    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
-    constexpr size_t num_subgroups = 16;
-    GGML_ASSERT(block_num_y % num_subgroups == 0);
+    // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
+    constexpr size_t num_subgroups = WARP_SIZE;
+    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
 
     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
@@ -842,9 +842,9 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
 static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
                                                const int nrows, dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
-    constexpr size_t num_subgroups = 16;
-    GGML_ASSERT(block_num_y % num_subgroups == 0);
+    // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
+    constexpr size_t num_subgroups = WARP_SIZE;
+    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
 
     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
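For intuition, a worked example with hypothetical values (nrows = 1000, GGML_SYCL_MMV_Y = 1, and WARP_SIZE = 32 are assumed here, not taken from this PR):

    // old: block_num_y = ceil_div(1000, 1) = 1000, and 1000 % 16 != 0 -> GGML_ASSERT fires
    // new: block_num_y = ceil_div(1000, 1 * 32) * 32 = 32 * 32 = 1024
    // the 24 padding rows (1024 - 1000) fall out of range and are skipped inside the kernel

The new launch geometry trades the hard divisibility requirement for a small amount of wasted work on the padded rows.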