Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions apps/tundra_shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
#include "TundraQLParser.h"
#include "arrow/map_union_types.hpp"
#include "common/constants.hpp"
#include "main/database.hpp"
#include "linenoise.h"
#include "common/logger.hpp"
#include "common/types.hpp"
#include "common/utils.hpp"
#include "linenoise.h"
#include "main/database.hpp"

// Tee stream class that outputs to both console and file
class TeeStream : public std::ostream {
Expand Down Expand Up @@ -525,7 +525,7 @@ class TundraQLVisitorImpl : public tundraql::TundraQLBaseVisitor {
node_type = node_alias;
}

auto query_builder = tundradb::Query::from(node_alias + ":" + node_type);
auto query_builder = tundradb::Query::match(node_alias + ":" + node_type);

for (size_t i = 0; i < edges.size(); i++) {
auto edge = edges[i];
Expand Down Expand Up @@ -807,7 +807,7 @@ class TundraQLVisitorImpl : public tundraql::TundraQLBaseVisitor {
schema_name = alias;
}

auto query_builder = tundradb::Query::from(alias + ":" + schema_name);
auto query_builder = tundradb::Query::match(alias + ":" + schema_name);

if (ctx->whereClause()) {
processWhereClause(query_builder, ctx->whereClause());
Expand Down Expand Up @@ -1082,7 +1082,7 @@ class TundraQLVisitorImpl : public tundraql::TundraQLBaseVisitor {
return qb;
}
// Mode 2: single nodePattern → build a trivial query
return tundradb::Query::from(alias + ":" + schema_name);
return tundradb::Query::match(alias + ":" + schema_name);
}();

// Build alias→schema map from the query builder's pattern
Expand Down Expand Up @@ -1456,7 +1456,7 @@ class TundraQLVisitorImpl : public tundraql::TundraQLBaseVisitor {

try {
// Build a query to find matching nodes
auto query_builder = tundradb::Query::from("n:" + node_type);
auto query_builder = tundradb::Query::match("n:" + node_type);

// Add WHERE conditions for each property
for (const auto& [prop_name, prop_value] : properties) {
Expand Down Expand Up @@ -2211,4 +2211,4 @@ int main(int argc, char* argv[]) {
g_tee_stream.reset();

return 0;
}
}
111 changes: 62 additions & 49 deletions bench/tundra_runner.cpp
Original file line number Diff line number Diff line change
@@ -1,44 +1,45 @@
#include <arrow/api.h>
#include <arrow/csv/api.h>
#include <arrow/io/api.h>

#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>

#include "common/types.hpp"
#include "main/database.hpp"
#include "query/query.hpp"
#include "common/types.hpp"

using namespace tundradb;

static arrow::Result<std::shared_ptr<arrow::Table>> read_csv(const std::string& path) {
static arrow::Result<std::shared_ptr<arrow::Table>> read_csv(
const std::string& path) {
ARROW_ASSIGN_OR_RAISE(auto input, arrow::io::ReadableFile::Open(path));
auto read_options = arrow::csv::ReadOptions::Defaults();
auto parse_options = arrow::csv::ParseOptions::Defaults();
auto convert_options = arrow::csv::ConvertOptions::Defaults();
ARROW_ASSIGN_OR_RAISE(auto reader,
arrow::csv::TableReader::Make(arrow::io::default_io_context(), input,
read_options, parse_options, convert_options));
ARROW_ASSIGN_OR_RAISE(
auto reader, arrow::csv::TableReader::Make(
arrow::io::default_io_context(), input, read_options,
parse_options, convert_options));
return reader->Read();
}

void load_data(Database& db, const std::string& users_csv,
const std::string& companies_csv,
const std::string& friend_csv,
void load_data(Database& db, const std::string& users_csv,
const std::string& companies_csv, const std::string& friend_csv,
const std::string& works_at_csv) {
auto load_start = std::chrono::high_resolution_clock::now();

// Define schemas (must include "id" field)
auto user_schema = arrow::schema({
arrow::field("name", arrow::utf8()),
auto user_schema = arrow::schema({arrow::field("name", arrow::utf8()),
arrow::field("age", arrow::int64()),
arrow::field("country", arrow::utf8())});
db.get_schema_registry()->create("User", user_schema).ValueOrDie();

auto company_schema = arrow::schema({
arrow::field("name", arrow::utf8()),
arrow::field("industry", arrow::utf8())});
auto company_schema =
arrow::schema({arrow::field("name", arrow::utf8()),
arrow::field("industry", arrow::utf8())});
db.get_schema_registry()->create("Company", company_schema).ValueOrDie();
auto users_tbl = read_csv(users_csv).ValueOrDie();
users_tbl = users_tbl->CombineChunks().ValueOrDie();
Expand All @@ -58,9 +59,12 @@ void load_data(Database& db, const std::string& users_csv,
auto name_idx = users_tbl->schema()->GetFieldIndex("name");
auto age_idx = users_tbl->schema()->GetFieldIndex("age");
auto country_idx = users_tbl->schema()->GetFieldIndex("country");
auto name_arr = std::static_pointer_cast<arrow::StringArray>(users_tbl->column(name_idx)->chunk(0));
auto age_arr = std::static_pointer_cast<arrow::Int64Array>(users_tbl->column(age_idx)->chunk(0));
auto country_arr = std::static_pointer_cast<arrow::StringArray>(users_tbl->column(country_idx)->chunk(0));
auto name_arr = std::static_pointer_cast<arrow::StringArray>(
users_tbl->column(name_idx)->chunk(0));
auto age_arr = std::static_pointer_cast<arrow::Int64Array>(
users_tbl->column(age_idx)->chunk(0));
auto country_arr = std::static_pointer_cast<arrow::StringArray>(
users_tbl->column(country_idx)->chunk(0));
for (int64_t i = 0; i < users_tbl->num_rows(); ++i) {
std::unordered_map<std::string, Value> data;
data["name"] = Value(std::string(name_arr->GetView(i)));
Expand All @@ -73,11 +77,12 @@ void load_data(Database& db, const std::string& users_csv,

// Load Companies (global ids continue after users)


auto cname_idx = companies_tbl->schema()->GetFieldIndex("name");
auto ind_idx = companies_tbl->schema()->GetFieldIndex("industry");
auto cname_arr = std::static_pointer_cast<arrow::StringArray>(companies_tbl->column(cname_idx)->chunk(0));
auto ind_arr = std::static_pointer_cast<arrow::StringArray>(companies_tbl->column(ind_idx)->chunk(0));
auto cname_arr = std::static_pointer_cast<arrow::StringArray>(
companies_tbl->column(cname_idx)->chunk(0));
auto ind_arr = std::static_pointer_cast<arrow::StringArray>(
companies_tbl->column(ind_idx)->chunk(0));
for (int64_t i = 0; i < companies_tbl->num_rows(); ++i) {
std::unordered_map<std::string, Value> data;
data["name"] = Value(std::string(cname_arr->GetView(i)));
Expand All @@ -89,41 +94,47 @@ void load_data(Database& db, const std::string& users_csv,

auto fsrc_idx = friend_tbl->schema()->GetFieldIndex("src");
auto fdst_idx = friend_tbl->schema()->GetFieldIndex("dst");
auto fsrc = std::static_pointer_cast<arrow::Int64Array>(friend_tbl->column(fsrc_idx)->chunk(0));
auto fdst = std::static_pointer_cast<arrow::Int64Array>(friend_tbl->column(fdst_idx)->chunk(0));
auto fsrc = std::static_pointer_cast<arrow::Int64Array>(
friend_tbl->column(fsrc_idx)->chunk(0));
auto fdst = std::static_pointer_cast<arrow::Int64Array>(
friend_tbl->column(fdst_idx)->chunk(0));
for (int64_t i = 0; i < friend_tbl->num_rows(); ++i) {
db.connect(fsrc->Value(i), "FRIEND", fdst->Value(i)).ValueOrDie();
}


auto wsrc_idx = works_tbl->schema()->GetFieldIndex("src");
auto wdst_idx = works_tbl->schema()->GetFieldIndex("dst");
auto wsrc = std::static_pointer_cast<arrow::Int64Array>(works_tbl->column(wsrc_idx)->chunk(0));
auto wdst = std::static_pointer_cast<arrow::Int64Array>(works_tbl->column(wdst_idx)->chunk(0));
auto wsrc = std::static_pointer_cast<arrow::Int64Array>(
works_tbl->column(wsrc_idx)->chunk(0));
auto wdst = std::static_pointer_cast<arrow::Int64Array>(
works_tbl->column(wdst_idx)->chunk(0));
for (int64_t i = 0; i < works_tbl->num_rows(); ++i) {
db.connect(wsrc->Value(i), "WORKS_AT", users_count + wdst->Value(i)).ValueOrDie();
db.connect(wsrc->Value(i), "WORKS_AT", users_count + wdst->Value(i))
.ValueOrDie();
}

db.get_table("User", nullptr).ValueOrDie();
db.get_table("Company", nullptr).ValueOrDie();

auto load_end = std::chrono::high_resolution_clock::now();
auto load_duration = std::chrono::duration_cast<std::chrono::milliseconds>(
load_end - load_start);
std::cerr << "Data load time: " << load_duration.count() << " ms" << std::endl;
std::cerr << "Data load time: " << load_duration.count() << " ms"
<< std::endl;
}

int64_t run_query(Database& db) {
auto query_start_time = std::chrono::high_resolution_clock::now();
Query query = Query::from("u:User")
.where("u.age", CompareOp::Gt, Value(30))
.and_where("u.country", CompareOp::Eq, Value(std::string("US")))
.traverse("u", "FRIEND", "f:User", TraverseType::Inner)
.where("f.age", CompareOp::Gt, Value((int64_t)25))
.select()
.parallel(true)
.inline_where()
.build();
Query query =
Query::match("u:User")
.where("u.age", CompareOp::Gt, Value(30))
.and_where("u.country", CompareOp::Eq, Value(std::string("US")))
.traverse("u", "FRIEND", "f:User", TraverseType::Inner)
.where("f.age", CompareOp::Gt, Value((int64_t)25))
.select()
.parallel(true)
.inline_where()
.build();

auto res = db.query(query);
auto query_end_time = std::chrono::high_resolution_clock::now();
Expand All @@ -134,22 +145,24 @@ int64_t run_query(Database& db) {
std::cerr << "Query failed: " << res.status().ToString() << "\n";
return -1;
}

auto table = res.ValueOrDie()->table();
int64_t row_count = table ? table->num_rows() : 0;

// Output in machine-readable format for Python parser
std::cout << query_duration.count() << std::endl; // Just the time in ms

return row_count;
}

int main(int argc, char** argv) {
if (argc < 5) {
std::cerr << "Usage: " << argv[0] << " <users.csv> <companies.csv> <friend.csv> <works_at.csv> [repetitions]\n";
std::cerr << "Usage: " << argv[0]
<< " <users.csv> <companies.csv> <friend.csv> <works_at.csv> "
"[repetitions]\n";
return 1;
}

std::string users_csv = argv[1];
std::string companies_csv = argv[2];
std::string friend_csv = argv[3];
Expand All @@ -158,16 +171,16 @@ int main(int argc, char** argv) {

// Build in-memory DB
auto config = make_config()
.with_persistence_enabled(false)
.with_shard_capacity(200000)
.with_chunk_size(100000)
.build();
.with_persistence_enabled(false)
.with_shard_capacity(200000)
.with_chunk_size(100000)
.build();

Database db(config);

// Load data once (not timed for benchmark)
load_data(db, users_csv, companies_csv, friend_csv, works_at_csv);

// Run query multiple times and output each timing
int64_t rows = 0;
for (int i = 0; i < repetitions; i++) {
Expand All @@ -177,7 +190,7 @@ int main(int argc, char** argv) {
return 2;
}
}

std::cerr << "rows=" << rows << std::endl;
return 0;
}
}
36 changes: 27 additions & 9 deletions include/main/database.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class Database {
*
* Mode 2 - by MATCH query (alias-qualified SET, multi-schema):
* db.update(UpdateQuery::match(
* Query::from("u:User")
* Query::match("u:User")
* .traverse("u", "WORKS_AT", "c:Company")
* .where("c.name", CompareOp::Eq, Value("Google"))
* .build()
Expand All @@ -177,13 +177,13 @@ class Database {
const std::vector<FieldUpdate> &fields,
UpdateType update_type, UpdateResult &result);

/** Initialize QueryState from query: temporal context, FROM table, prepare.
/** Initialize QueryState from query: temporal context, root table, prepare.
*/
[[nodiscard]] arrow::Status init_query_state(const Query &query,
QueryState &query_state) const;

/** Inline WHERE clauses applicable to the FROM alias. */
[[nodiscard]] arrow::Status inline_from_where(const Query &query,
/** Inline WHERE clauses applicable to the root alias. */
[[nodiscard]] arrow::Status inline_root_where(const Query &query,
QueryState &query_state,
QueryResult &result) const;

Expand All @@ -196,12 +196,30 @@ class Database {
/** Execute a single TRAVERSE clause, updating query_state in-place. */
[[nodiscard]] arrow::Status execute_traverse(
const std::shared_ptr<Traverse> &traverse, QueryState &query_state,
const Query &query, size_t clause_index, QueryResult &result) const;
const Query &query, size_t clause_index, size_t traverse_index,
QueryResult &result) const;

/** Apply a single-variable WHERE filter, or defer to post_where. */
[[nodiscard]] arrow::Status apply_where_filter(
const std::shared_ptr<WhereExpr> &where, QueryState &query_state,
std::vector<std::shared_ptr<WhereExpr>> &post_where) const;
/** High-level action chosen for one WHERE clause in legacy execution mode. */
struct WhereDisposition {
enum class Kind {
Skip,
Defer,
ApplyToAlias,
};

Kind kind = Kind::Skip;
std::string alias;
};

/** Classify a WHERE clause as skipped, deferred, or directly applicable. */
[[nodiscard]] arrow::Result<WhereDisposition> classify_where_filter(
const std::shared_ptr<WhereExpr> &where,
const QueryState &query_state) const;

/** Apply a single-alias WHERE clause to an already materialized alias. */
[[nodiscard]] arrow::Status apply_alias_where(
const std::shared_ptr<WhereExpr> &where, const std::string &alias,
QueryState &query_state) const;

/** Build the final output table: denormalize, populate rows, apply
* deferred WHERE filters, and project via SELECT. */
Expand Down
7 changes: 5 additions & 2 deletions include/query/execution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <vector>

#include "query/query.hpp"
#include "query/where_planner.hpp"
#include "schema/schema.hpp"

namespace tundradb {
Expand Down Expand Up @@ -448,8 +449,9 @@ struct QueryState {
/// Arrow tables keyed by schema alias.
std::unordered_map<std::string, std::shared_ptr<arrow::Table>> tables;

SchemaRef from; ///< Source schema from the FROM clause.
SchemaRef root; ///< Root schema for query execution.
std::vector<Traverse> traversals; ///< Traverse clauses in query order.
std::optional<WhereExecutionPlan> where_plan; ///< Planned WHERE execution.

std::shared_ptr<SchemaRegistry> schema_registry; ///< Node schema registry.
std::shared_ptr<NodeManager> node_manager; ///< Node storage.
Expand Down Expand Up @@ -724,7 +726,8 @@ std::vector<std::shared_ptr<WhereExpr>> get_where_to_inline(
arrow::Result<std::shared_ptr<arrow::Table>> inline_where(
const SchemaRef& ref, std::shared_ptr<arrow::Table> table,
QueryState& query_state,
const std::vector<std::shared_ptr<WhereExpr>>& where_exprs);
const std::vector<std::shared_ptr<WhereExpr>>& where_exprs,
bool mark_inlined = true);

/**
* @brief Prepares a query for execution: registers aliases, resolves fields,
Expand Down
Loading
Loading