Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions common/jinja/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,4 +954,50 @@ value keyword_argument_expression::execute_impl(context & ctx) {
return mk_val<value_kwarg>(k, v);
}

std::string runtime::debug_dump_program(const program & prog, const std::string & src) {
std::ostringstream oss;
size_t lvl = 0;
context ctx;
ctx.src.reset(new std::string(src));

auto indent = [](size_t lvl) -> std::string {
return std::string(lvl * 2, ' ');
};

ctx.visitor = [&](bool is_leaf, statement * node, std::vector<visitor_pair> children) {
oss << indent(lvl) << node->type() << ":\n";
lvl++;
if (is_leaf) {
const auto & pos = node->pos;
oss << indent(lvl) << "(leaf) at " << get_line_col(src, pos) << " in source:\n";
std::string snippet = peak_source(src, pos);
string_replace_all(snippet, "\n", "\n" + indent(lvl));
oss << indent(lvl) << snippet << "\n";
} else {
for (auto & [label, children_vec] : children) {
oss << indent(lvl) << label << ":\n";
lvl++;
if (children_vec.empty()) {
oss << indent(lvl) << "<empty>\n\n";
} else {
for (auto * child : children_vec) {
if (!child) {
continue;
}
child->visit(ctx);
}
}
lvl--;
}
}
lvl--;
};

for (const auto & stmt : prog.body) {
stmt->visit(ctx);
}

return oss.str();
}

} // namespace jinja
127 changes: 127 additions & 0 deletions common/jinja/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) {
// not thread-safe
void enable_debug(bool enable);

// for visiting AST nodes
// function signature: void(bool is_leaf, statement * node, pair of <label, children>)
using visitor_pair = std::pair<std::string, std::vector<statement *>>;
using visitor_fn = std::function<void(bool, statement *, std::vector<visitor_pair>)>;

struct context {
std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
std::time_t current_time; // for functions that need current time

bool is_get_stats = false; // whether to collect stats

visitor_fn visitor;

// src is optional, used for error reporting
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
env = mk_val<value_object>();
Expand Down Expand Up @@ -99,13 +106,23 @@ struct context {
value_object env;
};

// utils for visiting AST nodes
static std::vector<statement *> stmts_to_ptr(const statements & stmts) {
std::vector<statement *> children;
for (const auto & stmt : stmts) {
children.push_back(stmt.get());
}
return children;
}

/**
* Base class for all nodes in the AST.
*/
struct statement {
size_t pos; // position in source, for debugging
virtual ~statement() = default;
virtual std::string type() const { return "Statement"; }
virtual void visit(context & ctx) { ctx.visitor(true, this, {}); }

// execute_impl must be overridden by derived classes
virtual value execute_impl(context &) { throw_exec_error(); }
Expand Down Expand Up @@ -166,6 +183,13 @@ struct if_statement : public statement {

std::string type() const override { return "If"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"test", {test.get()}},
{"body", stmts_to_ptr(body)},
{"alternate", stmts_to_ptr(alternate)}
});
}
};

struct identifier;
Expand All @@ -190,6 +214,14 @@ struct for_statement : public statement {

std::string type() const override { return "For"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"loopvar", {loopvar.get()}},
{"iterable", {iterable.get()}},
{"body", stmts_to_ptr(body)},
{"default_block", stmts_to_ptr(default_block)}
});
}
};

struct break_statement : public statement {
Expand Down Expand Up @@ -241,6 +273,13 @@ struct set_statement : public statement {

std::string type() const override { return "Set"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"assignee", {assignee.get()}},
{"value", {val.get()}},
{"body", stmts_to_ptr(body)}
});
}
};

struct macro_statement : public statement {
Expand All @@ -256,6 +295,13 @@ struct macro_statement : public statement {

std::string type() const override { return "Macro"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"name", {name.get()}},
{"args", stmts_to_ptr(args)},
{"body", stmts_to_ptr(body)}
});
}
};

struct comment_statement : public statement {
Expand Down Expand Up @@ -289,6 +335,12 @@ struct member_expression : public expression {
}
std::string type() const override { return "MemberExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"object", {object.get()}},
{"property", {property.get()}}
});
}
};

struct call_expression : public expression {
Expand All @@ -302,6 +354,12 @@ struct call_expression : public expression {
}
std::string type() const override { return "CallExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"callee", {callee.get()}},
{"args", stmts_to_ptr(args)}
});
}
};

/**
Expand Down Expand Up @@ -405,6 +463,12 @@ struct binary_expression : public expression {
}
std::string type() const override { return "BinaryExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"left", {left.get()}},
{"right", {right.get()}}
});
}
};

/**
Expand All @@ -431,6 +495,12 @@ struct filter_expression : public expression {

std::string type() const override { return "FilterExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"operand", {operand.get()}},
{"filter", {filter.get()}}
});
}
};

struct filter_statement : public statement {
Expand All @@ -443,6 +513,12 @@ struct filter_statement : public statement {
}
std::string type() const override { return "FilterStatement"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"filter", {filter.get()}},
{"body", stmts_to_ptr(body)}
});
}
};

/**
Expand All @@ -468,6 +544,12 @@ struct select_expression : public expression {
}
return lhs->execute_impl(ctx);
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"lhs", {lhs.get()}},
{"test", {test.get()}}
});
}
};

/**
Expand All @@ -486,6 +568,12 @@ struct test_expression : public expression {
}
std::string type() const override { return "TestExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"operand", {operand.get()}},
{"test", {test.get()}}
});
}
};

/**
Expand All @@ -501,6 +589,11 @@ struct unary_expression : public expression {
}
std::string type() const override { return "UnaryExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"argument", {argument.get()}}
});
}
};

struct slice_expression : public expression {
Expand All @@ -518,6 +611,13 @@ struct slice_expression : public expression {
[[noreturn]] value execute_impl(context &) override {
throw std::runtime_error("must be handled by MemberExpression");
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"start_expr", {start_expr.get()}},
{"stop_expr", {stop_expr.get()}},
{"step_expr", {step_expr.get()}}
});
}
};

struct keyword_argument_expression : public expression {
Expand All @@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression {
}
std::string type() const override { return "KeywordArgumentExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"key", {key.get()}},
{"val", {val.get()}}
});
}
};

struct spread_expression : public expression {
Expand All @@ -539,6 +645,11 @@ struct spread_expression : public expression {
chk_type<expression>(this->argument);
}
std::string type() const override { return "SpreadExpression"; }
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"argument", {argument.get()}}
});
}
};

struct call_statement : public statement {
Expand All @@ -553,6 +664,13 @@ struct call_statement : public statement {
}
std::string type() const override { return "CallStatement"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"call", {call.get()}},
{"caller_args", stmts_to_ptr(caller_args)},
{"body", stmts_to_ptr(body)}
});
}
};

struct ternary_expression : public expression {
Expand All @@ -575,6 +693,13 @@ struct ternary_expression : public expression {
return false_expr->execute(ctx);
}
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"condition", {condition.get()}},
{"true_expr", {true_expr.get()}},
{"false_expr", {false_expr.get()}}
});
}
};

struct raised_exception : public std::exception {
Expand Down Expand Up @@ -648,6 +773,8 @@ struct runtime {
}
return parts;
}

static std::string debug_dump_program(const program & prog, const std::string & src);
};

} // namespace jinja
20 changes: 17 additions & 3 deletions tests/test-chat-template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using json = nlohmann::ordered_json;
static int main_automated_tests(void);

static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false);
static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = "");
static void run_single(const std::string& contents, json input, bool use_common = false, bool dump_prog = false, const std::string & output_path = "");

static std::string HELP = R"(
Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
Expand All @@ -35,6 +35,7 @@ Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
--json <path> Path to the JSON input file.
--stop-on-first-fail Stop testing on the first failure (default: false).
--no-common Use direct Jinja engine instead of common chat templates (default: use common).
--dump-prog Dump the parsed program for debugging (only for single template runs).
--output <path> Path to output results (only for single template runs).
If PATH_TO_TEMPLATE is a file, runs that single template.
If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
Expand Down Expand Up @@ -118,6 +119,7 @@ int main(int argc, char ** argv) {
std::string & json_to_use = DEFAULT_JSON;
bool stop_on_first_fail = false;
bool use_common = true;
bool dump_prog = false;

for (size_t i = 1; i < args.size(); i++) {
if (args[i] == "--help" || args[i] == "-h") {
Expand All @@ -136,6 +138,8 @@ int main(int argc, char ** argv) {
i++;
} else if (args[i] == "--no-common") {
use_common = false;
} else if (args[i] == "--dump-prog") {
dump_prog = true;
} else if (tmpl_path.empty()) {
tmpl_path = args[i];
} else {
Expand Down Expand Up @@ -172,7 +176,7 @@ int main(int argc, char ** argv) {
std::string contents = std::string(
std::istreambuf_iterator<char>(infile),
std::istreambuf_iterator<char>());
run_single(contents, input_json, use_common, output_path);
run_single(contents, input_json, use_common, dump_prog, output_path);
} else {
std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n";
return 1;
Expand Down Expand Up @@ -276,11 +280,21 @@ static jinja::value_string format_using_direct_engine(
}


void run_single(const std::string& contents, json input, bool use_common, const std::string & output_path) {
void run_single(const std::string& contents, json input, bool use_common, bool dump_prog, const std::string & output_path) {
jinja::enable_debug(true);

jinja::value_string output_parts;

if (dump_prog) {
jinja::lexer lexer;
auto lexer_res = lexer.tokenize(contents);
jinja::program ast = jinja::parse_from_tokens(lexer_res);
std::string prog_dump = jinja::runtime::debug_dump_program(ast, contents);
std::cout << "\n=== DUMPED PROGRAM ===\n";
std::cout << prog_dump << "\n";
return;
}

if (use_common) {
std::string bos_token = "<s>";
std::string eos_token = "</s>";
Expand Down
Loading