diff --git a/common/arg.cpp b/common/arg.cpp
index c289ff713da0..0fc94e553215 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3296,6 +3296,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.reasoning_budget_message = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
+    add_opt(common_arg(
+        {"--reasoning-preserve"},
+        {"--no-reasoning-preserve"},
+        "preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
+        "compatible with certain templates having 'supports_preserve_reasoning' capability\n"
+        "example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
+        [](common_params & params, bool value) {
+            if (value) {
+                params.default_template_kwargs["preserve_reasoning"] = "true";
+            } else {
+                params.default_template_kwargs["preserve_reasoning"] = "false";
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
diff --git a/common/chat.cpp b/common/chat.cpp
index 0cee80434ece..efba80e45914 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -912,6 +912,11 @@ static std::string common_chat_template_direct_apply_impl(
     if (inputs.add_generation_prompt) {
         inp["add_generation_prompt"] = true;
     }
+    // translate the generic "preserve_reasoning" kwarg into the variables set by caps_apply_preserve_reasoning()
+    if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
+        bool enabled = inp["preserve_reasoning"].get<bool>();
+        jinja::caps_apply_preserve_reasoning(ctx, enabled);
+    }
 
     jinja::global_from_json(ctx, inp, inputs.mark_input);
 
diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp
index ead864763e1d..ae378ebd4fd0 100644
--- a/common/jinja/caps.cpp
+++ b/common/jinja/caps.cpp
@@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
 namespace jinja {
 
 using caps_json_fn = std::function<json()>;
-using caps_analyze_fn = std::function<void(bool, value &, value &)>;
+using caps_ctx_fn = std::function<void(context &)>;
+using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;
+
+void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
+    ctx.set_val("preserve_thinking", mk_val(enabled));
+    ctx.set_val("clear_thinking", mk_val(!enabled));
+    ctx.set_val("truncate_history_thinking", mk_val(!enabled));
+}
 
 static void caps_try_execute(jinja::program & prog,
                              const caps_json_fn & messages_fn,
+                             const caps_ctx_fn & ctx_fn,
                              const caps_json_fn & tools_fn,
                              const caps_analyze_fn & analyze_fn) {
     context ctx;
     ctx.is_get_stats = true;
     jinja::global_from_json(ctx, json{
         {"messages", messages_fn()},
-        {"tools", tools_fn()},
+        {"tools", tools_fn ? tools_fn() : json::array()},
         {"bos_token", ""},
         {"eos_token", ""},
         {"add_generation_prompt", true}
     }, true);
 
+    if (ctx_fn) {
+        ctx_fn(ctx);
+    }
+
     auto messages = ctx.get_val("messages");
     auto tools = ctx.get_val("tools");
 
@@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
         // ignore exceptions during capability analysis
     }
 
-    analyze_fn(success, messages, tools);
+    analyze_fn(success, messages, tools, result);
 }
 
 // for debugging only
@@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
                 }
             });
         },
-        [&]() {
-            // tools
-            return json{nullptr};
-        },
-        [&](bool success, value & messages, value &) {
+        nullptr, // ctx_fn
+        nullptr, // tools_fn
+        [&](bool success, value & messages, value &, const std::string &) {
             auto & content = messages->at(0)->at("content");
             caps_print_stats(content, "messages[0].content");
             if (has_op(content, "selectattr") || has_op(content, "array_access")) {
@@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
                 },
             });
         },
-        [&]() {
-            // tools
-            return json::array();
-        },
-        [&](bool, value & messages, value &) {
+        nullptr, // ctx_fn
+        nullptr, // tools_fn
+        [&](bool, value & messages, value &, const std::string &) {
             auto & content = messages->at(0)->at("content");
             caps_print_stats(content, "messages[0].content");
             if (!content->stats.used) {
@@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
                 },
             });
         },
+        nullptr, // ctx_fn
         [&]() {
             // tools
             return json::array({
@@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
                 },
             });
         },
-        [&](bool success, value & messages, value & tools) {
+        [&](bool success, value & messages, value & tools, const std::string &) {
             if (!success) {
                 return; // Nothing can be inferred
             }
@@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
                 },
             });
         },
+        nullptr, // ctx_fn
         [&]() {
             // tools
             return json::array({
@@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
                 },
             });
         },
-        [&](bool success, value & messages, value & tools) {
+        [&](bool success, value & messages, value & tools, const std::string &) {
             if (!success) {
                 result.supports_tool_calls = false;
                 result.supports_tools = false;
@@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
                 },
             });
         },
+        nullptr, // ctx_fn
         [&]() {
             // tools
             return json::array({
@@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
                 },
             });
         },
-        [&](bool success, value & messages, value & /*tools*/) {
+        [&](bool success, value & messages, value &, const std::string &) {
             if (!success) {
                 result.supports_parallel_tool_calls = false;
                 return;
@@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
     JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
 
     // case: preserve reasoning content in chat history
+    const std::string reasoning_placeholder = "__REASONING_PLACEHOLDER__"; // must be non-empty: find("") matches any output
     caps_try_execute(
         prog,
         [&]() {
             // messages
             return json::array({
+                {
+                    {"role", "user"},
+                    {"content", "User message"}
+                },
+                {
+                    {"role", "assistant"},
+                    {"content", "Assistant message"},
+                    // check of reasoning_content deeper in the history, not just the last assistant message
+                    {"reasoning_content", reasoning_placeholder}
+                },
                 {
                     {"role", "user"},
                     {"content", "User message"}
@@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
                 },
             });
         },
-        [&]() {
-            // tools
-            return json::array();
+        [&](context & ctx) {
+            caps_apply_preserve_reasoning(ctx, true);
         },
-        [&](bool, value & messages, value &) {
-            auto & content = messages->at(1)->at("reasoning_content");
-            caps_print_stats(content, "messages[1].reasoning_content");
-            if (content->stats.used) {
+        nullptr, // tools_fn
+        [&](bool, value &, value &, const std::string & output) {
+            // note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
+            if (output.find(reasoning_placeholder) != std::string::npos) {
                 result.supports_preserve_reasoning = true;
             }
         }
diff --git a/common/jinja/caps.h b/common/jinja/caps.h
index 93a7fe09260e..a290cd7da627 100644
--- a/common/jinja/caps.h
+++ b/common/jinja/caps.h
@@ -12,7 +12,9 @@ struct caps {
     bool supports_tool_calls = true;
     bool supports_system_role = true;
     bool supports_parallel_tool_calls = true;
-    bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
+
+    // supports preserving the reasoning trace in the full history, not just the last assistant message
+    bool supports_preserve_reasoning = false;
 
     // one of the 2 content capabilities must be true
     bool supports_string_content = true;
@@ -29,4 +31,9 @@ struct caps {
 
 caps caps_get(jinja::program & prog);
 
+// set the template variables that control reasoning preservation on ctx:
+//   preserve_thinking         = enabled
+//   clear_thinking            = !enabled
+//   truncate_history_thinking = !enabled
+void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);
+
 } // namespace jinja
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index bb1c236cbf77..ed8b036b6495 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1538,6 +1538,22 @@ struct server_context_impl {
                 /* media_path */ params_base.media_path,
                 /* force_pure_content */ params_base.force_pure_content_parser
             };
+
+            {
+                // report a mismatch between the template capability and the user's request;
+                // --reasoning-preserve / --no-reasoning-preserve store "true" / "false", so presence alone does not mean "enabled"
+                const auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
+                const auto it = params_base.default_template_kwargs.find("preserve_reasoning");
+                const bool supported = caps.at("supports_preserve_reasoning");
+                const bool specified = it != params_base.default_template_kwargs.end();
+                const bool enabled = specified && it->second == "true";
+                if (supported && !specified) {
+                    SRV_INF("%s", "chat template supports preserving reasoning, consider enabling it via --reasoning-preserve\n");
+                }
+                if (!supported && enabled) {
+                    SRV_WRN("%s", "chat template does NOT support preserving reasoning, --reasoning-preserve has no effect\n");
+                }
+            }
         }
 
         return true;