diff --git a/include/main/database.hpp b/include/main/database.hpp index 8db834a..bb65393 100644 --- a/include/main/database.hpp +++ b/include/main/database.hpp @@ -177,12 +177,38 @@ class Database { const std::vector &fields, UpdateType update_type, UpdateResult &result); - /** - * Build an alias->schema mapping from a Query's FROM + TRAVERSE clauses. - * Only declarations ("alias:Schema") are recorded; bare references ("alias") - * are skipped. Returns an error if the same alias is bound to two different - * schemas. + /** Initialize QueryState from query: temporal context, FROM table, prepare. + */ + [[nodiscard]] arrow::Status init_query_state(const Query &query, + QueryState &query_state) const; + + /** Inline WHERE clauses applicable to the FROM alias. */ + [[nodiscard]] arrow::Status inline_from_where(const Query &query, + QueryState &query_state, + QueryResult &result) const; + + /** Process all clauses (WHERE + TRAVERSE) and collect deferred expressions. */ + [[nodiscard]] arrow::Result>> + execute_clauses(const Query &query, QueryState &query_state, + QueryResult &result) const; + + /** Execute a single TRAVERSE clause, updating query_state in-place. */ + [[nodiscard]] arrow::Status execute_traverse( + const std::shared_ptr &traverse, QueryState &query_state, + const Query &query, size_t clause_index, QueryResult &result) const; + + /** Apply a single-variable WHERE filter, or defer to post_where. */ + [[nodiscard]] arrow::Status apply_where_filter( + const std::shared_ptr &where, QueryState &query_state, + std::vector> &post_where) const; + + /** Build the final output table: denormalize, populate rows, apply + * deferred WHERE filters, and project via SELECT. 
*/ + [[nodiscard]] arrow::Result> build_result_table( + const Query &query, QueryState &query_state, + const std::vector> &post_where, + QueryResult &result) const; }; } // namespace tundradb diff --git a/include/query/execution.hpp b/include/query/execution.hpp index ef432e0..f26d363 100644 --- a/include/query/execution.hpp +++ b/include/query/execution.hpp @@ -604,6 +604,37 @@ struct QueryState { std::string ToString() const; }; +/** + * @brief Results of an expand traverse hop operation in a graph traversal + * + * Holds the sets of source and target node IDs identified during a single + * traversal hop, categorized by their connection statuses. + */ +struct ExpandTraverseHopResult { + /** + * @brief Source node IDs with matched connections discovered during traversal + * + * Represents the set of source node identifiers that have at least one valid + * connection to matching target nodes, as determined by traversal criteria. + */ + llvm::DenseSet matched_source_ids; + /** + * @brief Target node IDs that were successfully matched during traversal + * + * Represents the set of target node identifiers that satisfy the traversal + * criteria and have a valid connection from at least one source node. + */ + llvm::DenseSet matched_target_ids; + + /** + * @brief Source node IDs without any accepted edges + * + * Represents the set of node IDs that were part of the traversal but did not + * have any connecting edges or target nodes meeting the required criteria. + */ + llvm::DenseSet unmatched_source_ids; +}; + /** * @brief Recursively collects all paths from a node in a connection graph * (debug). @@ -705,6 +736,39 @@ arrow::Result> inline_where( */ arrow::Status prepare_query(const Query& query, QueryState& query_state); +/** + * @brief Executes one graph-pattern hop for a @c TRAVERSE clause. 
+ * + * Walks the current source frontier in @p query_state (IDs under + * @c traverse.source().value()), loads each node's outgoing edges of type + * @c traverse.edge_type(), and keeps edges whose target resolves to a node in + * @p target_schema. Optional pruning applies when @p query_state already + * holds IDs for the target alias: targets not in that set are skipped. + * + * For each surviving edge, @p node_filters are evaluated on the target node and + * @p edge_filters on the edge (typically inlined WHERE expressions). Matches + * append a @c Connection to @p query_state (per-source connection lists and + * @c incoming on the target id) for later row materialization. + * + * The three output sets support @c JoinStrategy: sources with at least one + * matching edge, distinct matched target ids, and sources with no matching + * edge. Callers should clear these sets before calling if they are reused. + * + * @param traverse Parsed TRAVERSE (source/target aliases, edge type, optional + * edge alias). + * @param target_schema Resolved concrete schema name for the target endpoint. + * @param query_state Execution state; must have @c node_manager, @c edge_store, + * populated @c ids for the source alias, and (when used) target ids. + * @param node_filters WHERE expressions applied to each candidate target node. + * @param edge_filters WHERE expressions applied to each candidate edge. 
+ * @return @c ExpandTraverseHopResult + */ +arrow::Result expand_traverse_hop( + const Traverse& traverse, const std::string& target_schema, + QueryState& query_state, + const std::vector>& node_filters, + const std::vector>& edge_filters); + } // namespace tundradb #endif // QUERY_EXECUTION_HPP diff --git a/src/main/database.cpp b/src/main/database.cpp index 70ffa9b..f441c51 100644 --- a/src/main/database.cpp +++ b/src/main/database.cpp @@ -621,7 +621,22 @@ arrow::Result> Database::query( QueryState query_state(this->schema_registry_); auto result = std::make_shared(); - // Initialize temporal context if AS OF clause is present + ARROW_RETURN_NOT_OK(init_query_state(query, query_state)); + ARROW_RETURN_NOT_OK(inline_from_where(query, query_state, *result)); + ARROW_ASSIGN_OR_RAISE(const auto post_where, + execute_clauses(query, query_state, *result)); + ARROW_ASSIGN_OR_RAISE( + auto output_table, + build_result_table(query, query_state, post_where, *result)); + result->set_table(std::move(output_table)); + return result; +} + +// --------------------------------------------------------------------------- +// Database::init_query_state +// --------------------------------------------------------------------------- +arrow::Status Database::init_query_state(const Query& query, + QueryState& query_state) const { if (query.temporal_snapshot().has_value()) { query_state.temporal_context = std::make_unique(query.temporal_snapshot().value()); @@ -632,451 +647,311 @@ arrow::Result> Database::query( } } - // Pre-size hash maps to avoid expensive resizing during execution query_state.reserve_capacity(query); + query_state.node_manager = this->node_manager_; + query_state.edge_store = this->edge_store_; IF_DEBUG_ENABLED { - log_debug("Executing query starting from schema '{}'", - query.from().toString()); + log_debug("processing 'from' {}", query.from().toString()); } - query_state.node_manager = this->node_manager_; - query_state.edge_store = this->edge_store_; + 
query_state.from = query.from(); + query_state.from.set_tag(compute_tag(query_state.from)); + ARROW_ASSIGN_OR_RAISE(auto source_schema, + query_state.register_schema(query.from())); + if (!this->schema_registry_->exists(source_schema)) { + return arrow::Status::KeyError("schema doesn't exit: {}", source_schema); + } + ARROW_ASSIGN_OR_RAISE( + auto source_table, + this->get_table(source_schema, query_state.temporal_context.get())); + ARROW_RETURN_NOT_OK(query_state.update_table(source_table, query.from())); + ARROW_RETURN_NOT_OK( + query_state.compute_fully_qualified_names(query.from()).status()); + + return prepare_query(query, query_state); +} - { - IF_DEBUG_ENABLED { - log_debug("processing 'from' {}", query.from().toString()); - } - query_state.from = query.from(); - query_state.from.set_tag(compute_tag(query_state.from)); - ARROW_ASSIGN_OR_RAISE(auto source_schema, - query_state.register_schema(query.from())); - if (!this->schema_registry_->exists(source_schema)) { - log_error("schema '{}' doesn't exist", source_schema); - return arrow::Status::KeyError("schema doesn't exit: {}", source_schema); - } - ARROW_ASSIGN_OR_RAISE( - auto source_table, - this->get_table(source_schema, query_state.temporal_context.get())); - ARROW_RETURN_NOT_OK(query_state.update_table(source_table, query.from())); - if (auto res = query_state.compute_fully_qualified_names(query.from()); - !res.ok()) { - return res.status(); +// --------------------------------------------------------------------------- +// Database::inline_from_where +// --------------------------------------------------------------------------- +arrow::Status Database::inline_from_where(const Query& query, + QueryState& query_state, + QueryResult& result) const { + auto where_exps = + get_where_to_inline(query.from().value(), 0, query.clauses()); + result.mutable_execution_stats().num_where_clauses_inlined += + where_exps.size(); + return inline_where(query.from(), query_state.tables[query.from().value()], + 
query_state, where_exps) + .status(); +} + +// --------------------------------------------------------------------------- +// Database::execute_clauses +// --------------------------------------------------------------------------- +arrow::Result>> +Database::execute_clauses(const Query& query, QueryState& query_state, + QueryResult& result) const { + std::vector> post_where; + for (size_t i = 0; i < query.clauses().size(); ++i) { + auto clause = query.clauses()[i]; + switch (clause->type()) { + case Clause::Type::WHERE: + ARROW_RETURN_NOT_OK( + apply_where_filter(std::dynamic_pointer_cast(clause), + query_state, post_where)); + break; + case Clause::Type::TRAVERSE: + ARROW_RETURN_NOT_OK( + execute_traverse(std::static_pointer_cast(clause), + query_state, query, i, result)); + break; + default: + return arrow::Status::NotImplemented( + "Database::query unsupported clause"); } } + return post_where; +} + +// --------------------------------------------------------------------------- +// Database::build_result_table +// --------------------------------------------------------------------------- +arrow::Result> Database::build_result_table( + const Query& query, QueryState& query_state, + const std::vector>& post_where, + QueryResult& result) const { + ARROW_ASSIGN_OR_RAISE(auto output_schema, + build_denormalized_schema(query_state)); + IF_DEBUG_ENABLED { log_debug("output_schema={}", output_schema->ToString()); } + + ARROW_ASSIGN_OR_RAISE(auto rows, + populate_rows(query.execution_config(), query_state, + query_state.traversals, output_schema)); + ARROW_ASSIGN_OR_RAISE(auto table, + create_table_from_rows(rows, output_schema)); - // PHASE: Query Preparation - Populate aliases, traversals, tags, and resolve - // field references - { + for (const auto& expr : post_where) { + result.mutable_execution_stats().num_where_clauses_post_processed++; + IF_DEBUG_ENABLED { log_debug("post process where: {}", expr->toString()); } + ARROW_ASSIGN_OR_RAISE(table, filter(table, 
*expr, false)); + } + return apply_select(query.select(), table); +} + +// --------------------------------------------------------------------------- +// Database::apply_where_filter +// --------------------------------------------------------------------------- +arrow::Status Database::apply_where_filter( + const std::shared_ptr& where, QueryState& query_state, + std::vector>& post_where) const { + if (where->inlined()) { IF_DEBUG_ENABLED { - log_debug( - "Preparing query: populating aliases, traversals, and resolving " - "field references"); + log_debug("where '{}' is inlined, skip", where->toString()); } - auto preparation_result = prepare_query(query, query_state); - if (!preparation_result.ok()) { - log_error("Failed to prepare query: {}", preparation_result.ToString()); - return preparation_result; - } - IF_DEBUG_ENABLED { log_debug("Query preparation completed successfully"); } + return arrow::Status::OK(); } - - { - auto where_exps = - get_where_to_inline(query.from().value(), 0, query.clauses()); - result->mutable_execution_stats().num_where_clauses_inlined += - where_exps.size(); - auto res = - inline_where(query.from(), query_state.tables[query.from().value()], - query_state, where_exps); - if (!res.ok()) { - return res.status(); + auto variables = where->get_all_variables(); + if (variables.empty()) { + return arrow::Status::Invalid( + "where clause field must have variable " + "., actual={}", + where->toString()); + } + if (variables.size() != 1) { + IF_DEBUG_ENABLED { + log_debug("Add compound WHERE expression: '{}' to post process", + where->toString()); } + post_where.emplace_back(where); + return arrow::Status::OK(); } IF_DEBUG_ENABLED { - log_debug("Processing {} query clauses", query.clauses().size()); + log_debug("Processing WHERE clause: '{}'", where->toString()); } - // Precompute 16-bit alias-based tags for all SchemaRefs - // Also precompute fully-qualified field names per alias used in the query - std::vector> post_where; - for (auto i = 
0; i < query.clauses().size(); ++i) { - switch (auto clause = query.clauses()[i]; clause->type()) { - case Clause::Type::WHERE: { - auto where = std::dynamic_pointer_cast(clause); - if (where->inlined()) { - IF_DEBUG_ENABLED { - log_debug("where '{}' is inlined, skip", where->toString()); - } - continue; - } - auto variables = where->get_all_variables(); - if (variables.empty()) { - return arrow::Status::Invalid( - "where clause field must have variable " - "., actual={}", - where->toString()); - } - if (variables.size() == 1) { - IF_DEBUG_ENABLED { - log_debug("Processing WHERE clause: '{}'", where->toString()); - } - - std::unordered_map> new_front_ids; - std::string variable = *variables.begin(); - if (!query_state.tables.contains(variable)) { - if (!query_state.aliases().contains(variable)) { - return arrow::Status::Invalid("Unknown variable '{}'", variable); - } - // Alias is valid but not materialized as a table at this point - // (e.g. edge alias). Defer to post-processing/inlined traversal. 
- post_where.emplace_back(where); - continue; - } - auto table = query_state.tables.at(variable); - arrow::Result> filtered_table_result = - filter(table, *where, true); - if (!filtered_table_result.ok() && where->requires_row_eval()) { - ARROW_ASSIGN_OR_RAISE( - const auto resolved_schema, - query_state.resolve_schema(SchemaRef::parse(variable))); - - llvm::DenseSet keep_ids; - for (const auto id : query_state.ids()[variable]) { - auto node_res = node_manager_->get_node(resolved_schema, id); - if (!node_res.ok()) continue; - ARROW_ASSIGN_OR_RAISE(const bool matches, - where->matches(node_res.ValueOrDie())); - if (matches) { - keep_ids.insert(id); - } - } - - auto id_column = table->GetColumnByName("id"); - if (!id_column) { - return arrow::Status::Invalid( - "Could not find 'id' column for variable '", variable, "'"); - } + std::string variable = *variables.begin(); + if (!query_state.tables.contains(variable)) { + if (!query_state.aliases().contains(variable)) { + return arrow::Status::Invalid("Unknown variable '{}'", variable); + } + post_where.emplace_back(where); + return arrow::Status::OK(); + } + auto table = query_state.tables.at(variable); + arrow::Result> filtered_table_result = + filter(table, *where, true); + if (!filtered_table_result.ok() && where->requires_row_eval()) { + ARROW_ASSIGN_OR_RAISE( + const auto resolved_schema, + query_state.resolve_schema(SchemaRef::parse(variable))); + + llvm::DenseSet keep_ids; + for (const auto id : query_state.ids()[variable]) { + auto node_res = node_manager_->get_node(resolved_schema, id); + if (!node_res.ok()) continue; + ARROW_ASSIGN_OR_RAISE(const bool matches, + where->matches(node_res.ValueOrDie())); + if (matches) { + keep_ids.insert(id); + } + } - arrow::BooleanBuilder mask_builder; - for (int ci = 0; ci < id_column->num_chunks(); ++ci) { - auto ids = std::static_pointer_cast( - id_column->chunk(ci)); - for (int64_t irow = 0; irow < ids->length(); ++irow) { - if (ids->IsNull(irow)) { - 
ARROW_RETURN_NOT_OK(mask_builder.Append(false)); - } else { - ARROW_RETURN_NOT_OK( - mask_builder.Append(keep_ids.contains(ids->Value(irow)))); - } - } - } + auto id_column = table->GetColumnByName("id"); + if (!id_column) { + return arrow::Status::Invalid("Could not find 'id' column for variable '", + variable, "'"); + } - std::shared_ptr mask_array; - ARROW_RETURN_NOT_OK(mask_builder.Finish(&mask_array)); - ARROW_ASSIGN_OR_RAISE( - auto filtered_datum, - arrow::compute::Filter(arrow::Datum(table), - arrow::Datum(mask_array))); - filtered_table_result = filtered_datum.table(); - } - if (!filtered_table_result.ok()) { - log_error("Failed to process where: '{}', error: {}", - where->toString(), - filtered_table_result.status().ToString()); - return filtered_table_result.status(); - } - ARROW_RETURN_NOT_OK(query_state.update_table( - filtered_table_result.ValueOrDie(), SchemaRef::parse(variable))); + arrow::BooleanBuilder mask_builder; + for (int ci = 0; ci < id_column->num_chunks(); ++ci) { + auto ids = + std::static_pointer_cast(id_column->chunk(ci)); + for (int64_t irow = 0; irow < ids->length(); ++irow) { + if (ids->IsNull(irow)) { + ARROW_RETURN_NOT_OK(mask_builder.Append(false)); } else { - IF_DEBUG_ENABLED { - log_debug("Add compound WHERE expression: '{}' to post process", - where->toString()); - } - post_where.emplace_back(where); + ARROW_RETURN_NOT_OK( + mask_builder.Append(keep_ids.contains(ids->Value(irow)))); } - break; } - case Clause::Type::TRAVERSE: { - auto traverse = std::static_pointer_cast(clause); - // Tags and schemas are already set during preparation phase - - // Get resolved schemas using const resolve_schema (read-only) - ARROW_ASSIGN_OR_RAISE(const auto source_schema, - query_state.resolve_schema(traverse->source())); - ARROW_ASSIGN_OR_RAISE(const auto target_schema, - query_state.resolve_schema(traverse->target())); - // Fully-qualified field names should also be precomputed during - // preparation - ARROW_RETURN_NOT_OK( - 
query_state.compute_fully_qualified_names(traverse->source())); - ARROW_RETURN_NOT_OK( - query_state.compute_fully_qualified_names(traverse->target())); - if (traverse->edge_alias().has_value()) { - ARROW_RETURN_NOT_OK(query_state.compute_fully_qualified_names( - SchemaRef::parse(traverse->edge_alias().value()))); - } + } - std::vector> where_clauses; - std::vector> edge_where_clauses; - if (query.inline_where()) { - where_clauses = get_where_to_inline(traverse->target().value(), i + 1, - query.clauses()); - } - if (traverse->edge_alias().has_value()) { - edge_where_clauses = get_where_to_inline( - traverse->edge_alias().value(), i + 1, query.clauses()); - } - for (const auto& wc : where_clauses) wc->set_inlined(true); - for (const auto& wc : edge_where_clauses) wc->set_inlined(true); - result->mutable_execution_stats().num_where_clauses_inlined += - where_clauses.size() + edge_where_clauses.size(); - // Traversal already added to query_state.traversals during preparation - IF_DEBUG_ENABLED { - log_debug("Processing TRAVERSE {}-({})->{}", - traverse->source().toString(), traverse->edge_type(), - traverse->target().toString()); - } - auto source = traverse->source(); - if (!query_state.tables.contains(source.value())) { - IF_DEBUG_ENABLED { - log_debug("Source table '{}' not found. 
Loading", - traverse->source().toString()); - } - ARROW_ASSIGN_OR_RAISE( - auto source_table, - this->get_table(source_schema, - query_state.temporal_context.get())); - ARROW_RETURN_NOT_OK( - query_state.update_table(source_table, traverse->source())); - } + std::shared_ptr mask_array; + ARROW_RETURN_NOT_OK(mask_builder.Finish(&mask_array)); + ARROW_ASSIGN_OR_RAISE( + auto filtered_datum, + arrow::compute::Filter(arrow::Datum(table), arrow::Datum(mask_array))); + filtered_table_result = filtered_datum.table(); + } + if (!filtered_table_result.ok()) { + log_error("Failed to process where: '{}', error: {}", where->toString(), + filtered_table_result.status().ToString()); + return filtered_table_result.status(); + } + ARROW_RETURN_NOT_OK(query_state.update_table( + filtered_table_result.ValueOrDie(), SchemaRef::parse(variable))); + return arrow::Status::OK(); +} - IF_DEBUG_ENABLED { - log_debug("Traversing from {} source nodes", - query_state.ids()[source.value()].size()); - } - llvm::DenseSet matched_source_ids; - llvm::DenseSet matched_target_ids; - llvm::DenseSet unmatched_source_ids; - for (auto source_id : query_state.ids()[source.value()]) { - auto outgoing_edges = - edge_store_->get_outgoing_edges(source_id, traverse->edge_type()) - .ValueOrDie(); // todo check result - IF_DEBUG_ENABLED { - log_debug("Node {} has {} outgoing edges of type '{}'", source_id, - outgoing_edges.size(), traverse->edge_type()); - } +// --------------------------------------------------------------------------- +// Database::execute_traverse +// --------------------------------------------------------------------------- +arrow::Status Database::execute_traverse( + const std::shared_ptr& traverse, QueryState& query_state, + const Query& query, size_t clause_index, QueryResult& result) const { + ARROW_ASSIGN_OR_RAISE(const auto source_schema, + query_state.resolve_schema(traverse->source())); + ARROW_ASSIGN_OR_RAISE(const auto target_schema, + 
query_state.resolve_schema(traverse->target())); + ARROW_RETURN_NOT_OK( + query_state.compute_fully_qualified_names(traverse->source())); + ARROW_RETURN_NOT_OK( + query_state.compute_fully_qualified_names(traverse->target())); + if (traverse->edge_alias().has_value()) { + ARROW_RETURN_NOT_OK(query_state.compute_fully_qualified_names( + SchemaRef::parse(traverse->edge_alias().value()))); + } - bool source_had_match = false; - for (const auto& edge : outgoing_edges) { - auto target_id = edge->get_target_id(); - if (query_state.ids().contains(traverse->target().value()) && - !query_state.ids() - .at(traverse->target().value()) - .contains(target_id)) { - continue; - } - auto node_result = - node_manager_->get_node(target_schema, target_id); - if (node_result.ok()) { - if (const auto target_node = node_result.ValueOrDie(); - target_node->schema_name == target_schema) { - // Then apply all WHERE clauses with AND logic - bool passes_all_filters = true; - // Multiple conditions - could optimize by creating a - // temporary table and using Arrow expressions For now, use - // the existing approach but this could be optimized - for (const auto& where_clause : where_clauses) { - auto node_where = - apply_where_to_node(where_clause, target_node); - if (!node_where.ok()) { - return node_where.status(); - } - if (!node_where.ValueOrDie()) { - passes_all_filters = false; - break; - } - } - if (passes_all_filters) { - for (const auto& where_clause : edge_where_clauses) { - auto edge_where = apply_where_to_edge(where_clause, edge); - if (!edge_where.ok()) { - return edge_where.status(); - } - if (!edge_where.ValueOrDie()) { - passes_all_filters = false; - break; - } - } - } - if (passes_all_filters) { - IF_DEBUG_ENABLED { - log_debug("found edge {}:{} -[{}{}]-> {}:{}", - source.value(), source_id, - traverse->edge_alias().has_value() - ? 
traverse->edge_alias().value() + ":" - : "", - traverse->edge_type(), traverse->target().value(), - target_node->id); - } - // record match immediately to avoid extra containers/copies - if (!source_had_match) { - matched_source_ids.insert(source_id); - source_had_match = true; - } - matched_target_ids.insert(target_node->id); - // Use connection pool to avoid allocation - auto& conn = query_state.connection_pool().get(); - conn.source = traverse->source(); - conn.source_id = source_id; - conn.edge_id = edge->get_id(); - conn.edge_alias = traverse->edge_alias(); - conn.edge_type = traverse->edge_type(); - conn.label = ""; - conn.target = traverse->target(); - conn.target_id = target_node->id; - - query_state - .connections()[traverse->source().value()][source_id] - .push_back(conn); - query_state.incoming()[target_node->id].push_back(conn); - } - } - } else { - log_warn("Failed to get node {}:{}, error: {}", - traverse->target().value(), target_id, - node_result.status().ToString()); - } - } - if (!source_had_match) { - IF_DEBUG_ENABLED { - log_debug("no edge found from {}:{}", source.value(), source_id); - } - unmatched_source_ids.insert(source_id); - } - } - IF_DEBUG_ENABLED { - log_debug("found {} neighbors for {}", matched_target_ids.size(), - traverse->target().toString()); - } + std::vector> where_clauses; + std::vector> edge_where_clauses; + if (query.inline_where()) { + where_clauses = get_where_to_inline(traverse->target().value(), + clause_index + 1, query.clauses()); + } + if (traverse->edge_alias().has_value()) { + edge_where_clauses = get_where_to_inline(traverse->edge_alias().value(), + clause_index + 1, query.clauses()); + } + for (const auto& wc : where_clauses) wc->set_inlined(true); + for (const auto& wc : edge_where_clauses) wc->set_inlined(true); + result.mutable_execution_stats().num_where_clauses_inlined += + where_clauses.size() + edge_where_clauses.size(); - // For RIGHT/FULL joins we need all target IDs from the table - llvm::DenseSet 
all_target_ids; - if (traverse->traverse_type() == TraverseType::Right || - traverse->traverse_type() == TraverseType::Full) { - all_target_ids = - get_ids_from_table( - get_table(target_schema, query_state.temporal_context.get()) - .ValueOrDie()) - .ValueOrDie(); - } + IF_DEBUG_ENABLED { + log_debug("Processing TRAVERSE {}-({})->{}", traverse->source().toString(), + traverse->edge_type(), traverse->target().toString()); + } + auto source = traverse->source(); + if (!query_state.tables.contains(source.value())) { + IF_DEBUG_ENABLED { + log_debug("Source table '{}' not found. Loading", + traverse->source().toString()); + } + ARROW_ASSIGN_OR_RAISE( + auto source_table, + this->get_table(source_schema, query_state.temporal_context.get())); + ARROW_RETURN_NOT_OK( + query_state.update_table(source_table, traverse->source())); + } - const bool is_self_join = source_schema == target_schema; - auto strategy = JoinStrategyFactory::create(traverse->traverse_type(), - is_self_join); + ARROW_ASSIGN_OR_RAISE( + auto hop_result, + expand_traverse_hop(*traverse, target_schema, query_state, where_clauses, + edge_where_clauses)); + + llvm::DenseSet all_target_ids; + if (traverse->traverse_type() == TraverseType::Right || + traverse->traverse_type() == TraverseType::Full) { + all_target_ids = + get_ids_from_table( + get_table(target_schema, query_state.temporal_context.get()) + .ValueOrDie()) + .ValueOrDie(); + } - IF_DEBUG_ENABLED { - log_debug("Using {} join strategy (self_join={})", strategy->name(), - is_self_join); - } + const bool is_self_join = source_schema == target_schema; + auto strategy = + JoinStrategyFactory::create(traverse->traverse_type(), is_self_join); - JoinInput join_input{ - .source_ids = query_state.ids()[source.value()], - .all_target_ids = all_target_ids, - .matched_source_ids = matched_source_ids, - .matched_target_ids = matched_target_ids, - .existing_target_ids = query_state.get_ids(traverse->target()), - .unmatched_source_ids = unmatched_source_ids, - 
.is_self_join = is_self_join, - }; + IF_DEBUG_ENABLED { + log_debug("Using {} join strategy (self_join={})", strategy->name(), + is_self_join); + } - auto join_output = strategy->compute(join_input); + JoinInput join_input{ + .source_ids = query_state.ids()[source.value()], + .all_target_ids = all_target_ids, + .matched_source_ids = hop_result.matched_source_ids, + .matched_target_ids = hop_result.matched_target_ids, + .existing_target_ids = query_state.get_ids(traverse->target()), + .unmatched_source_ids = hop_result.unmatched_source_ids, + .is_self_join = is_self_join, + }; - // Apply target IDs - query_state.ids()[traverse->target().value()] = join_output.target_ids; + auto join_output = strategy->compute(join_input); - // Apply source pruning (INNER join removes unmatched sources) - if (join_output.rebuild_source_table) { - for (auto id : join_output.source_ids_to_remove) { - IF_DEBUG_ENABLED { - log_debug("remove unmatched node={}:{}", source.value(), id); - } - query_state.remove_node(id, source); - } - auto table_result = - filter_table_by_id(query_state.tables[source.value()], - query_state.ids()[source.value()]); - if (!table_result.ok()) { - return table_result.status(); - } - query_state.tables[source.value()] = table_result.ValueOrDie(); - } + query_state.ids()[traverse->target().value()] = join_output.target_ids; - std::vector> neighbors; - for (auto id : query_state.ids()[traverse->target().value()]) { - if (auto node_res = node_manager_->get_node(target_schema, id); - node_res.ok()) { - neighbors.push_back(node_res.ValueOrDie()); - } - } - auto target_table_schema = - schema_registry_->get(target_schema).ValueOrDie(); - auto table_result = - create_table_from_nodes(target_table_schema, neighbors); - if (!table_result.ok()) { - log_error("Failed to create table from neighbors: {}", - table_result.status().ToString()); - return table_result.status(); - } - ARROW_RETURN_NOT_OK(query_state.update_table(table_result.ValueOrDie(), - traverse->target())); - 
break; + if (join_output.rebuild_source_table) { + for (auto id : join_output.source_ids_to_remove) { + IF_DEBUG_ENABLED { + log_debug("remove unmatched node={}:{}", source.value(), id); } - default: - log_error("Unsupported clause type: {}", - static_cast(clause->type())); - return arrow::Status::NotImplemented( - "Database::query unsupported clause"); + query_state.remove_node(id, source); } + auto table_result = filter_table_by_id(query_state.tables[source.value()], + query_state.ids()[source.value()]); + if (!table_result.ok()) { + return table_result.status(); + } + query_state.tables[source.value()] = table_result.ValueOrDie(); } - IF_DEBUG_ENABLED { - log_debug("Query processing complete, building result"); - log_debug("Query state: {}", query_state.ToString()); - } - - auto output_schema_res = build_denormalized_schema(query_state); - if (!output_schema_res.ok()) { - return output_schema_res.status(); - } - const auto output_schema = output_schema_res.ValueOrDie(); - IF_DEBUG_ENABLED { log_debug("output_schema={}", output_schema->ToString()); } - - auto row_res = populate_rows(query.execution_config(), query_state, - query_state.traversals, output_schema); - if (!row_res.ok()) { - return row_res.status(); - } - auto rows = row_res.ValueOrDie(); - auto output_table_res = create_table_from_rows(rows, output_schema); - if (!output_table_res.ok()) { - log_error("Failed to create table from rows: {}", - output_table_res.status().ToString()); - return output_table_res.status(); - } - auto output_table = output_table_res.ValueOrDie(); - for (const auto& expr : post_where) { - result->mutable_execution_stats().num_where_clauses_post_processed++; - IF_DEBUG_ENABLED { log_debug("post process where: {}", expr->toString()); } - auto filtered = filter(output_table, *expr, false); - if (!filtered.ok()) { - log_error("Post-process WHERE failed: {}", filtered.status().ToString()); - return filtered.status(); + std::vector> neighbors; + for (auto id : 
query_state.ids()[traverse->target().value()]) { + if (auto node_res = node_manager_->get_node(target_schema, id); + node_res.ok()) { + neighbors.push_back(node_res.ValueOrDie()); } - output_table = filtered.ValueOrDie(); } - result->set_table(apply_select(query.select(), output_table)); - return result; + auto target_table_schema = schema_registry_->get(target_schema).ValueOrDie(); + ARROW_ASSIGN_OR_RAISE(auto target_table, create_table_from_nodes( + target_table_schema, neighbors)); + ARROW_RETURN_NOT_OK( + query_state.update_table(target_table, traverse->target())); + return arrow::Status::OK(); } // --------------------------------------------------------------------------- diff --git a/src/query/execution.cpp b/src/query/execution.cpp index c735a1a..9a636bc 100644 --- a/src/query/execution.cpp +++ b/src/query/execution.cpp @@ -786,4 +786,100 @@ arrow::Status prepare_query(const Query& query, QueryState& query_state) { return arrow::Status::OK(); } +// See declaration in execution.hpp for behavior and parameters. 
+// Expands one TRAVERSE hop over the ids currently stored under the source
+// alias in query_state.
+//
+// For each source id the outgoing edges of traverse.edge_type() are read from
+// query_state.edge_store. An edge is accepted when
+//   - its target id is in the id set already stored for the target alias (the
+//     check is skipped when no set is stored for that alias),
+//   - query_state.node_manager returns the target node for target_schema and
+//     the node's schema_name equals target_schema, and
+//   - every node_filters entry accepts the node and every edge_filters entry
+//     accepts the edge.
+// A node lookup failure is logged with log_warn and the edge is skipped.
+//
+// Each accepted edge appends a connection taken from
+// query_state.connection_pool() to
+// query_state.connections()[source alias][source id] and to
+// query_state.incoming()[target id].
+//
+// Returns the matched source ids, the matched target ids and the source ids
+// without any accepted edge. A failed edge lookup or a failed filter
+// evaluation is returned as the error of the result.
+arrow::Result expand_traverse_hop(
+    const Traverse& traverse, const std::string& target_schema,
+    QueryState& query_state,
+    const std::vector>& node_filters,
+    const std::vector>& edge_filters) {
+  // Filled in place and returned directly, so the id sets are never copied.
+  ExpandTraverseHopResult hop;
+  const auto& source_alias = traverse.source().value();
+  for (auto source_id : query_state.ids()[source_alias]) {
+    // Hand an edge-store failure back to the caller; ValueOrDie() would abort
+    // the whole process on a non-OK result.
+    ARROW_ASSIGN_OR_RAISE(
+        const auto outgoing_edges,
+        query_state.edge_store->get_outgoing_edges(source_id,
+                                                   traverse.edge_type()));
+    IF_DEBUG_ENABLED {
+      log_debug("Node {} has {} outgoing edges of type '{}'", source_id,
+                outgoing_edges.size(), traverse.edge_type());
+    }
+
+    bool source_had_match = false;
+    for (const auto& edge : outgoing_edges) {
+      auto target_id = edge->get_target_id();
+      // Prune against target ids recorded earlier for this alias, if any.
+      if (query_state.ids().contains(traverse.target().value()) &&
+          !query_state.ids()
+               .at(traverse.target().value())
+               .contains(target_id)) {
+        continue;
+      }
+      auto node_result =
+          query_state.node_manager->get_node(target_schema, target_id);
+      if (!node_result.ok()) {
+        log_warn("Failed to get node {}:{}, error: {}",
+                 traverse.target().value(), target_id,
+                 node_result.status().ToString());
+        continue;
+      }
+      const auto target_node = node_result.ValueOrDie();
+      if (target_node->schema_name != target_schema) continue;
+
+      // Inlined WHERE filters are ANDed: node filters first, then edge
+      // filters. Evaluation stops at the first filter that rejects.
+      bool passes_filters = true;
+      for (const auto& wc : node_filters) {
+        ARROW_ASSIGN_OR_RAISE(const bool ok,
+                              apply_where_to_node(wc, target_node));
+        if (!ok) {
+          passes_filters = false;
+          break;
+        }
+      }
+      if (passes_filters) {
+        for (const auto& wc : edge_filters) {
+          ARROW_ASSIGN_OR_RAISE(const bool ok, apply_where_to_edge(wc, edge));
+          if (!ok) {
+            passes_filters = false;
+            break;
+          }
+        }
+      }
+      if (!passes_filters) continue;
+
+      IF_DEBUG_ENABLED {
+        log_debug("found edge {}:{} -[{}{}]-> {}:{}", source_alias, source_id,
+                  traverse.edge_alias().has_value()
+                      ? traverse.edge_alias().value() + ":"
+                      : "",
+                  traverse.edge_type(), traverse.target().value(),
+                  target_node->id);
+      }
+      // A source is recorded once, on its first accepted edge.
+      if (!source_had_match) {
+        hop.matched_source_ids.insert(source_id);
+        source_had_match = true;
+      }
+      hop.matched_target_ids.insert(target_node->id);
+
+      auto& conn = query_state.connection_pool().get();
+      conn.source = traverse.source();
+      conn.source_id = source_id;
+      conn.edge_id = edge->get_id();
+      conn.edge_alias = traverse.edge_alias();
+      conn.edge_type = traverse.edge_type();
+      conn.label = "";
+      conn.target = traverse.target();
+      conn.target_id = target_node->id;
+      query_state.connections()[source_alias][source_id].push_back(conn);
+      query_state.incoming()[target_node->id].push_back(conn);
+    }
+    if (!source_had_match) {
+      IF_DEBUG_ENABLED {
+        log_debug("no edge found from {}:{}", source_alias, source_id);
+      }
+      hop.unmatched_source_ids.insert(source_id);
+    }
+  }
+  IF_DEBUG_ENABLED {
+    log_debug("found {} neighbors for {}", hop.matched_target_ids.size(),
+              traverse.target().toString());
+  }
+  return hop;
+}
+
 } // namespace tundradb