Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3296,6 +3296,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--reasoning-preserve"},
{"--no-reasoning-preserve"},
"preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
"compatible with certain templates having 'supports_preserve_reasoning' capability\n"
"example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
[](common_params & params, bool value) {
if (value) {
params.default_template_kwargs["preserve_reasoning"] = "true";
} else {
params.default_template_kwargs["preserve_reasoning"] = "false";
}
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
Expand Down
4 changes: 4 additions & 0 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,10 @@ static std::string common_chat_template_direct_apply_impl(
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
bool enabled = inp["preserve_reasoning"].get<bool>();
jinja::caps_apply_preserve_reasoning(ctx, enabled);
}

jinja::global_from_json(ctx, inp, inputs.mark_input);

Expand Down
67 changes: 44 additions & 23 deletions common/jinja/caps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
namespace jinja {

using caps_json_fn = std::function<json()>;
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
using caps_ctx_fn = std::function<void(context &)>;
using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;

void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
ctx.set_val("preserve_thinking", mk_val<value_bool>(enabled));
ctx.set_val("clear_thinking", mk_val<value_bool>(!enabled));
ctx.set_val("truncate_history_thinking", mk_val<value_bool>(!enabled));
}

static void caps_try_execute(jinja::program & prog,
const caps_json_fn & messages_fn,
const caps_ctx_fn & ctx_fn,
const caps_json_fn & tools_fn,
const caps_analyze_fn & analyze_fn) {
context ctx;
ctx.is_get_stats = true;
jinja::global_from_json(ctx, json{
{"messages", messages_fn()},
{"tools", tools_fn()},
{"tools", tools_fn ? tools_fn() : json::array()},
{"bos_token", ""},
{"eos_token", ""},
{"add_generation_prompt", true}
}, true);

if (ctx_fn) {
ctx_fn(ctx);
}

auto messages = ctx.get_val("messages");
auto tools = ctx.get_val("tools");

Expand All @@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
// ignore exceptions during capability analysis
}

analyze_fn(success, messages, tools);
analyze_fn(success, messages, tools, result);
}

// for debugging only
Expand Down Expand Up @@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
}
});
},
[&]() {
// tools
return json{nullptr};
},
[&](bool success, value & messages, value &) {
nullptr, // ctx_fn
nullptr, // tools_fn
[&](bool success, value & messages, value &, const std::string &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
Expand Down Expand Up @@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&]() {
// tools
return json::array();
},
[&](bool, value & messages, value &) {
nullptr, // ctx_fn
nullptr, // tools_fn
[&](bool, value & messages, value &, const std::string &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (!content->stats.used) {
Expand Down Expand Up @@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
Expand All @@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & tools) {
[&](bool success, value & messages, value & tools, const std::string &) {
if (!success) {
return; // Nothing can be inferred
}
Expand Down Expand Up @@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
Expand All @@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & tools) {
[&](bool success, value & messages, value & tools, const std::string &) {
if (!success) {
result.supports_tool_calls = false;
result.supports_tools = false;
Expand Down Expand Up @@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
Expand All @@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & /*tools*/) {
[&](bool success, value & messages, value &, const std::string &) {
if (!success) {
result.supports_parallel_tool_calls = false;
return;
Expand All @@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");

// case: preserve reasoning content in chat history
const std::string reasoning_placeholder = "<REASONING_CONTENT_PLACEHOLDER>";
caps_try_execute(
prog,
[&]() {
// messages
return json::array({
{
{"role", "user"},
{"content", "User message"}
},
{
{"role", "assistant"},
{"content", "Assistant message"},
// check of reasoning_content deeper in the history, not just the last assistant message
{"reasoning_content", reasoning_placeholder}
},
{
{"role", "user"},
{"content", "User message"}
Expand All @@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&]() {
// tools
return json::array();
[&](context & ctx) {
caps_apply_preserve_reasoning(ctx, true);
},
[&](bool, value & messages, value &) {
auto & content = messages->at(1)->at("reasoning_content");
caps_print_stats(content, "messages[1].reasoning_content");
if (content->stats.used) {
nullptr, // tools_fn
[&](bool, value &, value &, const std::string & output) {
// note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
if (output.find(reasoning_placeholder) != std::string::npos) {
result.supports_preserve_reasoning = true;
}
}
Expand Down
6 changes: 5 additions & 1 deletion common/jinja/caps.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ struct caps {
bool supports_tool_calls = true;
bool supports_system_role = true;
bool supports_parallel_tool_calls = true;
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content

// supports preserve reasoning trace in the full history, not just the last assistant message
bool supports_preserve_reasoning = false;

// one of the 2 content capabilities must be true
bool supports_string_content = true;
Expand All @@ -29,4 +31,6 @@ struct caps {

caps caps_get(jinja::program & prog);

void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);

} // namespace jinja
13 changes: 13 additions & 0 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1538,6 +1538,19 @@ struct server_context_impl {
/* media_path */ params_base.media_path,
/* force_pure_content */ params_base.force_pure_content_parser
};

{
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
auto it = params_base.default_template_kwargs.find("preserve_reasoning");
bool supported = caps.at("supports_preserve_reasoning");
bool enabled = it != params_base.default_template_kwargs.end();
if (supported && !enabled) {
SRV_INF("%s", "chat template supports preserving reasoning, consider enabling it via --reasoning-preserve\n");
}
Comment on lines +1547 to +1549

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I suppose it would be useful if this reflected the default state.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all known templates have it disabled by default, so I think we are temporary ok here

if (!supported && enabled) {
SRV_WRN("%s", "chat template does NOT support preserving reasoning, --reasoning-preserve has no effect\n");
}
}
}

return true;
Expand Down