diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index e9520f3d1a3..44428cc9595 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -380,6 +380,7 @@ struct clip_ctx {
         if (backend_cpu != backend) {
             ggml_backend_free(backend_cpu);
         }
+        clip_image_size_free(load_image_size);
     }
 };
 
@@ -1618,6 +1619,12 @@ struct clip_image_f32 * clip_image_f32_init() {
     return new clip_image_f32();
 }
 
+void clip_image_size_free(struct clip_image_size * load_image_size) {
+    if (load_image_size == nullptr) {
+        return;
+    }
+    delete load_image_size;
+}
 void clip_image_u8_free(struct clip_image_u8  * img) { delete img; }
 void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
 void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) {
@@ -2270,6 +2277,9 @@ ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
 }
 
 void clip_free(clip_ctx * ctx) {
+    if (ctx == nullptr) {
+        return;
+    }
     delete ctx;
 }
 
diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index d806465bf68..87aa61574b1 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -77,6 +77,7 @@ CLIP_API struct clip_image_size * clip_image_size_init();
 CLIP_API struct clip_image_u8  * clip_image_u8_init ();
 CLIP_API struct clip_image_f32 * clip_image_f32_init();
 
+CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
 CLIP_API void clip_image_u8_free (struct clip_image_u8  * img);
 CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
 CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 760c3646433..1bf1ee876b4 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1705,6 +1705,8 @@ struct server_queue {
 };
 
 struct server_response {
+    bool running = true;
+
     // for keeping track of all tasks waiting for the result
     std::unordered_set<int> waiting_task_ids;
 
@@ -1759,6 +1761,10 @@ struct server_response {
         while (true) {
            std::unique_lock<std::mutex> lock(mutex_results);
             condition_results.wait(lock, [&]{
+                if (!running) {
+                    SRV_DBG("%s : queue result stop\n", __func__);
+                    std::terminate(); // we cannot return here since the caller is HTTP code
+                }
                 return !queue_results.empty();
             });
 
@@ -1789,6 +1795,10 @@ struct server_response {
             }
 
             std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
+            if (!running) {
+                SRV_DBG("%s : queue result stop\n", __func__);
+                std::terminate(); // we cannot return here since the caller is HTTP code
+            }
             if (cr_res == std::cv_status::timeout) {
                 return nullptr;
             }
@@ -1818,6 +1828,12 @@ struct server_response {
             }
         }
     }
+
+    // terminate the waiting loop
+    void terminate() {
+        running = false;
+        condition_results.notify_all();
+    }
 };
 
 struct server_context {
@@ -4491,9 +4507,10 @@ int main(int argc, char ** argv) {
     svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
 
     // clean up function, to be called before exit
-    auto clean_up = [&svr]() {
+    auto clean_up = [&svr, &ctx_server]() {
         SRV_INF("%s: cleaning up before exit...\n", __func__);
         svr->stop();
+        ctx_server.queue_results.terminate();
         llama_backend_free();
     };
 
@@ -4534,7 +4551,7 @@ int main(int argc, char ** argv) {
 
     if (!ctx_server.load_model(params)) {
         clean_up();
-        // t.join(); // FIXME: see below
+        t.join();
         LOG_ERR("%s: exiting due to model loading error\n", __func__);
         return 1;
     }
@@ -4582,7 +4599,7 @@ int main(int argc, char ** argv) {
     ctx_server.queue_tasks.start_loop();
 
     clean_up();
-    // t.join(); // FIXME: http thread may stuck if there is an on-going request. we don't need to care about this for now as the HTTP connection will already be closed at this point, but it's better to fix this
+    t.join();
 
     return 0;
 }
diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index 8b0eb42b092..0feb452ccfc 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -49,6 +49,26 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+def test_embedding_multiple_with_fa():
+    server = ServerPreset.bert_bge_small_with_fa()
+    server.pooling = 'last'
+    server.start()
+    # one of these should trigger the FA branch (i.e. context size % 256 == 0)
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": [
+            "a "*253,
+            "b "*254,
+            "c "*255,
+            "d "*256,
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 30aa8660950..4dc2062a8e5 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -323,6 +323,21 @@ def bert_bge_small() -> ServerProcess:
         server.server_embeddings = True
         return server
 
+    @staticmethod
+    def bert_bge_small_with_fa() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 1024
+        server.n_batch = 300
+        server.n_ubatch = 300
+        server.n_slots = 2
+        server.fa = True
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
diff --git a/examples/server_embd.py b/examples/server_embd.py
index 0e34c6ceab9..f8b0ffecd8f 100644
--- a/examples/server_embd.py
+++ b/examples/server_embd.py
@@ -15,7 +15,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*1024}
+        json= {"content": "a "*1022}
     ) for i in range(n)])
 
     for response in responses:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 7a8d5ac6fd9..f63656be54f 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float = ggml_get_type_traits(v->type)->to_float;
 
-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
+    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, DV);
-
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                if (v_to_float) {
+                    v_to_float(v_data, V32, DV);
+                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                } else {
+                    // V is F32
+                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                }
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 456e1fd994c..f226826020a 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 32) {
+                // head size == 32 (e.g. bert-bge-small)
+                // TODO: not sure if it is worth adding kernels for this size
+                return false;
+            }
            if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index c3469177e09..cd955d63bc3 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }
 
+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
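Note: the following is a minimal standalone sketch (not part of the patch) of the shutdown pattern that server_response::terminate() relies on: a running flag checked inside the condition-variable predicate, plus a terminate() that clears the flag and notifies all waiters. The names (result_queue, wait_one, push) are illustrative only; unlike the real server, the waiter here simply returns "no result" instead of calling std::terminate(), and the flag is flipped while holding the mutex.

// standalone sketch of the stop-flag + condition_variable shutdown pattern (illustrative names)
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <deque>
#include <mutex>
#include <optional>
#include <thread>

struct result_queue {
    std::mutex              mutex;
    std::condition_variable cond;
    std::deque<int>         results;
    bool                    running = true;

    // block until a result arrives or terminate() is called; nullopt means "shutting down"
    std::optional<int> wait_one() {
        std::unique_lock<std::mutex> lock(mutex);
        cond.wait(lock, [&] { return !running || !results.empty(); });
        if (!running) {
            return std::nullopt;
        }
        int res = results.front();
        results.pop_front();
        return res;
    }

    void push(int res) {
        {
            std::lock_guard<std::mutex> lock(mutex);
            results.push_back(res);
        }
        cond.notify_all();
    }

    // counterpart of server_response::terminate(): clear the flag, then wake every waiter
    void terminate() {
        {
            std::lock_guard<std::mutex> lock(mutex);
            running = false;
        }
        cond.notify_all();
    }
};

int main() {
    result_queue q;
    std::thread waiter([&] {
        if (auto r = q.wait_one()) {
            std::printf("got result %d\n", *r);
        } else {
            std::printf("queue terminated\n");
        }
    });
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    q.terminate(); // unblocks the waiter without producing a result
    waiter.join();
    return 0;
}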