Skip to content

Commit 6ce863c

Browse files
authored
server: prevent data race from HTTP threads (ggml-org#18263)
* server: prevent data race from HTTP threads
* fix params
* fix default_generation_settings
* nits: make handle_completions_impl looks less strange
* stricter const
* fix GGML_ASSERT(idx < states.size())
* move index to be managed by server_response_reader
* http: make sure req & res lifecycle are tied together
* fix compile
* fix index handling buggy
* fix data race for lora endpoint
* nits: fix shadow variable
* nits: revert redundant changes
* nits: correct naming for json_webui_settings
1 parent 3997c78 commit 6ce863c

11 files changed

Lines changed: 459 additions & 366 deletions

tools/cli/cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ int main(int argc, char ** argv) {
216216
ctx_cli.ctx_server.start_loop();
217217
});
218218

219-
auto inf = ctx_cli.ctx_server.get_info();
219+
auto inf = ctx_cli.ctx_server.get_meta();
220220
std::string modalities = "text";
221221
if (inf.has_inp_image) {
222222
modalities += ", vision";

tools/server/server-common.cpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -115,26 +115,14 @@ bool lora_should_clear_cache(
115115
!lora_all_alora(next));
116116
}
117117

118-
std::vector<common_adapter_lora_info> parse_lora_request(
119-
const std::vector<common_adapter_lora_info> & lora_base,
120-
const json & data) {
121-
std::vector<common_adapter_lora_info> lora(lora_base);
122-
int max_idx = lora.size();
123-
124-
// clear existing value
125-
for (auto & entry : lora) {
126-
entry.scale = 0.0f;
127-
}
118+
std::map<int, float> parse_lora_request(const json & data) {
119+
std::map<int, float> lora;
128120

129121
// set value
130122
for (const auto & entry : data) {
131123
int id = json_value(entry, "id", -1);
132124
float scale = json_value(entry, "scale", 0.0f);
133-
if (0 <= id && id < max_idx) {
134-
lora[id].scale = scale;
135-
} else {
136-
throw std::runtime_error("invalid adapter id");
137-
}
125+
lora[id] = scale;
138126
}
139127

140128
return lora;
@@ -1435,7 +1423,7 @@ std::string safe_json_to_str(const json & data) {
14351423

14361424
// TODO: reuse llama_detokenize
14371425
template <class Iter>
1438-
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
1426+
static std::string tokens_to_str(const llama_vocab * ctx, Iter begin, Iter end) {
14391427
std::string ret;
14401428
for (; begin != end; ++begin) {
14411429
ret += common_token_to_piece(ctx, *begin);
@@ -1445,7 +1433,12 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
14451433
}
14461434

14471435
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) {
1448-
return tokens_to_str(ctx, tokens.begin(), tokens.end());
1436+
auto model = llama_get_model(ctx);
1437+
return tokens_to_str(llama_model_get_vocab(model), tokens.begin(), tokens.end());
1438+
}
1439+
1440+
std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens) {
1441+
return tokens_to_str(vocab, tokens.begin(), tokens.end());
14491442
}
14501443

14511444
// format incomplete utf-8 multibyte character for output

tools/server/server-common.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ bool lora_should_clear_cache(
107107
const std::vector<common_adapter_lora_info> & current,
108108
const std::vector<common_adapter_lora_info> & next);
109109

110-
std::vector<common_adapter_lora_info> parse_lora_request(
111-
const std::vector<common_adapter_lora_info> & lora_base,
112-
const json & data);
110+
std::map<int, float> parse_lora_request(const json & data);
113111

114112
bool are_lora_equal(
115113
const std::vector<common_adapter_lora_info> & l1,
@@ -325,6 +323,7 @@ std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int i
325323
std::string safe_json_to_str(const json & data);
326324

327325
std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens);
326+
std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens);
328327

329328
// format incomplete utf-8 multibyte character for output
330329
std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token);

0 commit comments

Comments (0)