diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index f98cb0876f19..474129df2c4c 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -954,4 +954,50 @@ value keyword_argument_expression::execute_impl(context & ctx) { return mk_val(k, v); } +std::string runtime::debug_dump_program(const program & prog, const std::string & src) { + std::ostringstream oss; + size_t lvl = 0; + context ctx; + ctx.src.reset(new std::string(src)); + + auto indent = [](size_t lvl) -> std::string { + return std::string(lvl * 2, ' '); + }; + + ctx.visitor = [&](bool is_leaf, statement * node, std::vector children) { + oss << indent(lvl) << node->type() << ":\n"; + lvl++; + if (is_leaf) { + const auto & pos = node->pos; + oss << indent(lvl) << "(leaf) at " << get_line_col(src, pos) << " in source:\n"; + std::string snippet = peak_source(src, pos); + string_replace_all(snippet, "\n", "\n" + indent(lvl)); + oss << indent(lvl) << snippet << "\n"; + } else { + for (auto & [label, children_vec] : children) { + oss << indent(lvl) << label << ":\n"; + lvl++; + if (children_vec.empty()) { + oss << indent(lvl) << "\n\n"; + } else { + for (auto * child : children_vec) { + if (!child) { + continue; + } + child->visit(ctx); + } + } + lvl--; + } + } + lvl--; + }; + + for (const auto & stmt : prog.body) { + stmt->visit(ctx); + } + + return oss.str(); +} + } // namespace jinja diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h index 37b4c35cac8f..0884a15922bb 100644 --- a/common/jinja/runtime.h +++ b/common/jinja/runtime.h @@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) { // not thread-safe void enable_debug(bool enable); +// for visiting AST nodes +// function signature: void(bool is_leaf, statement * node, pair of ) +using visitor_pair = std::pair>; +using visitor_fn = std::function)>; + struct context { std::shared_ptr src; // for debugging; use shared_ptr to avoid copying on scope creation std::time_t current_time; // for functions that need current time bool is_get_stats = false; // whether to collect stats + visitor_fn visitor; + // src is optional, used for error reporting context(std::string src = "") : src(std::make_shared(std::move(src))) { env = mk_val(); @@ -99,6 +106,15 @@ struct context { value_object env; }; +// utils for visiting AST nodes +static std::vector stmts_to_ptr(const statements & stmts) { + std::vector children; + for (const auto & stmt : stmts) { + children.push_back(stmt.get()); + } + return children; +} + /** * Base class for all nodes in the AST. */ @@ -106,6 +122,7 @@ struct statement { size_t pos; // position in source, for debugging virtual ~statement() = default; virtual std::string type() const { return "Statement"; } + virtual void visit(context & ctx) { ctx.visitor(true, this, {}); } // execute_impl must be overridden by derived classes virtual value execute_impl(context &) { throw_exec_error(); } @@ -166,6 +183,13 @@ struct if_statement : public statement { std::string type() const override { return "If"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"test", {test.get()}}, + {"body", stmts_to_ptr(body)}, + {"alternate", stmts_to_ptr(alternate)} + }); + } }; struct identifier; @@ -190,6 +214,14 @@ struct for_statement : public statement { std::string type() const override { return "For"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"loopvar", {loopvar.get()}}, + {"iterable", {iterable.get()}}, + {"body", stmts_to_ptr(body)}, + {"default_block", stmts_to_ptr(default_block)} + }); + } }; struct break_statement : public statement { @@ -241,6 +273,13 @@ struct set_statement : public statement { std::string type() const override { return "Set"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"assignee", {assignee.get()}}, + {"value", {val.get()}}, + {"body", stmts_to_ptr(body)} + }); + } }; struct macro_statement : public statement { @@ -256,6 +295,13 @@ struct macro_statement : public statement { std::string type() const override { return "Macro"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"name", {name.get()}}, + {"args", stmts_to_ptr(args)}, + {"body", stmts_to_ptr(body)} + }); + } }; struct comment_statement : public statement { @@ -289,6 +335,12 @@ struct member_expression : public expression { } std::string type() const override { return "MemberExpression"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"object", {object.get()}}, + {"property", {property.get()}} + }); + } }; struct call_expression : public expression { @@ -302,6 +354,12 @@ struct call_expression : public expression { } std::string type() const override { return "CallExpression"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"callee", {callee.get()}}, + {"args", stmts_to_ptr(args)} + }); + } }; /** @@ -405,6 +463,12 @@ struct binary_expression : public expression { } std::string type() const override { return "BinaryExpression"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"left", {left.get()}}, + {"right", {right.get()}} + }); + } }; /** @@ -431,6 +495,12 @@ struct filter_expression : public expression { std::string type() const override { return "FilterExpression"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"operand", {operand.get()}}, + {"filter", {filter.get()}} + }); + } }; struct filter_statement : public statement { @@ -443,6 +513,12 @@ struct filter_statement : public statement { } std::string type() const override { return "FilterStatement"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"filter", {filter.get()}}, + {"body", stmts_to_ptr(body)} + }); + } }; /** @@ -468,6 +544,12 @@ struct select_expression : public expression { } return lhs->execute_impl(ctx); } + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"lhs", {lhs.get()}}, + {"test", {test.get()}} + }); + } }; /** @@ -486,6 +568,12 @@ struct test_expression : public expression { } std::string type() const override { return "TestExpression"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"operand", {operand.get()}}, + {"test", {test.get()}} + }); + } }; /** @@ -501,6 +589,11 @@ struct unary_expression : public expression { } std::string type() const override { return "UnaryExpression"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"argument", {argument.get()}} + }); + } }; struct slice_expression : public expression { @@ -518,6 +611,13 @@ struct slice_expression : public expression { [[noreturn]] value execute_impl(context &) override { throw std::runtime_error("must be handled by MemberExpression"); } + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"start_expr", {start_expr.get()}}, + {"stop_expr", {stop_expr.get()}}, + {"step_expr", {step_expr.get()}} + }); + } }; struct keyword_argument_expression : public expression { @@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression { } std::string type() const override { return "KeywordArgumentExpression"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"key", {key.get()}}, + {"val", {val.get()}} + }); + } }; struct spread_expression : public expression { @@ -539,6 +645,11 @@ struct spread_expression : public expression { chk_type(this->argument); } std::string type() const override { return "SpreadExpression"; } + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"argument", {argument.get()}} + }); + } }; struct call_statement : public statement { @@ -553,6 +664,13 @@ struct call_statement : public statement { } std::string type() const override { return "CallStatement"; } value execute_impl(context & ctx) override; + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"call", {call.get()}}, + {"caller_args", stmts_to_ptr(caller_args)}, + {"body", stmts_to_ptr(body)} + }); + } }; struct ternary_expression : public expression { @@ -575,6 +693,13 @@ struct ternary_expression : public expression { return false_expr->execute(ctx); } } + void visit(context & ctx) override { + ctx.visitor(false, this, { + {"condition", {condition.get()}}, + {"true_expr", {true_expr.get()}}, + {"false_expr", {false_expr.get()}} + }); + } }; struct raised_exception : public std::exception { @@ -648,6 +773,8 @@ struct runtime { } return parts; } + + static std::string debug_dump_program(const program & prog, const std::string & src); }; } // namespace jinja diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index d971b23746c7..6a6292cd0151 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -25,7 +25,7 @@ using json = nlohmann::ordered_json; static int main_automated_tests(void); static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false); -static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = ""); +static void run_single(const std::string& contents, json input, bool use_common = false, bool dump_prog = false, const std::string & output_path = ""); static std::string HELP = R"( Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE @@ -35,6 +35,7 @@ Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE --json Path to the JSON input file. --stop-on-first-fail Stop testing on the first failure (default: false). --no-common Use direct Jinja engine instead of common chat templates (default: use common). + --dump-prog Dump the parsed program for debugging (only for single template runs). --output Path to output results (only for single template runs). If PATH_TO_TEMPLATE is a file, runs that single template. If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory. @@ -118,6 +119,7 @@ int main(int argc, char ** argv) { std::string & json_to_use = DEFAULT_JSON; bool stop_on_first_fail = false; bool use_common = true; + bool dump_prog = false; for (size_t i = 1; i < args.size(); i++) { if (args[i] == "--help" || args[i] == "-h") { @@ -136,6 +138,8 @@ int main(int argc, char ** argv) { i++; } else if (args[i] == "--no-common") { use_common = false; + } else if (args[i] == "--dump-prog") { + dump_prog = true; } else if (tmpl_path.empty()) { tmpl_path = args[i]; } else { @@ -172,7 +176,7 @@ int main(int argc, char ** argv) { std::string contents = std::string( std::istreambuf_iterator(infile), std::istreambuf_iterator()); - run_single(contents, input_json, use_common, output_path); + run_single(contents, input_json, use_common, dump_prog, output_path); } else { std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n"; return 1; @@ -276,11 +280,21 @@ static jinja::value_string format_using_direct_engine( } -void run_single(const std::string& contents, json input, bool use_common, const std::string & output_path) { +void run_single(const std::string& contents, json input, bool use_common, bool dump_prog, const std::string & output_path) { jinja::enable_debug(true); jinja::value_string output_parts; + if (dump_prog) { + jinja::lexer lexer; + auto lexer_res = lexer.tokenize(contents); + jinja::program ast = jinja::parse_from_tokens(lexer_res); + std::string prog_dump = jinja::runtime::debug_dump_program(ast, contents); + std::cout << "\n=== DUMPED PROGRAM ===\n"; + std::cout << prog_dump << "\n"; + return; + } + if (use_common) { std::string bos_token = ""; std::string eos_token = "";