Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/c_api/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ lbug_state lbug_connection_create_arrow_rel_table_csr(lbug_connection* connectio
const char* table_name, const char* src_table_name, const char* dst_table_name,
ArrowSchema* indices_schema, ArrowArray* indices_arrays, uint64_t num_indices_arrays,
ArrowSchema* indptr_schema, ArrowArray* indptr_arrays, uint64_t num_indptr_arrays,
lbug_query_result* out_query_result) {
const char* dst_col_name, lbug_query_result* out_query_result) {
if (connection == nullptr || connection->_connection == nullptr || table_name == nullptr ||
src_table_name == nullptr || dst_table_name == nullptr || indices_schema == nullptr ||
indices_arrays == nullptr || indptr_schema == nullptr || indptr_arrays == nullptr ||
Expand All @@ -289,7 +289,8 @@ lbug_state lbug_connection_create_arrow_rel_table_csr(lbug_connection* connectio
*static_cast<Connection*>(connection->_connection), table_name, src_table_name,
dst_table_name, takeArrowSchema(indices_schema),
takeArrowArrays(indices_arrays, num_indices_arrays), takeArrowSchema(indptr_schema),
takeArrowArrays(indptr_arrays, num_indptr_arrays));
takeArrowArrays(indptr_arrays, num_indptr_arrays),
dst_col_name != nullptr ? dst_col_name : "to");
auto state = setQueryResult(std::move(result.queryResult), out_query_result);
if (state == LbugSuccess) {
rememberArrowTableID(static_cast<Connection*>(connection->_connection), table_name,
Expand Down
12 changes: 7 additions & 5 deletions src/include/c_api/lbug.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,16 +440,18 @@ LBUG_C_API lbug_state lbug_connection_create_arrow_rel_table(lbug_connection* co
/**
* @brief Creates a CSR Arrow memory-backed relationship table from Arrow C Data Interface data.
*
* The indices Arrow table must contain a destination offset column named "to" and any relationship
* property columns. The indptr Arrow table must contain one offset column. Ownership of schemas and
* arrays is transferred to lbug on success or failure. The caller must not release them after this
* call.
* The indices Arrow table must contain a destination offset column and any relationship property
* columns. The indptr Arrow table must contain one offset column. Ownership of schemas and arrays
* is transferred to lbug on success or failure. The caller must not release them after this call.
*
* @param dst_col_name Name of the destination offset column in the indices table. If NULL,
* defaults to "to".
*/
LBUG_C_API lbug_state lbug_connection_create_arrow_rel_table_csr(lbug_connection* connection,
const char* table_name, const char* src_table_name, const char* dst_table_name,
struct ArrowSchema* indices_schema, struct ArrowArray* indices_arrays,
uint64_t num_indices_arrays, struct ArrowSchema* indptr_schema,
struct ArrowArray* indptr_arrays, uint64_t num_indptr_arrays,
struct ArrowArray* indptr_arrays, uint64_t num_indptr_arrays, const char* dst_col_name,
lbug_query_result* out_query_result);
/**
* @brief Drops an Arrow memory-backed table.
Expand Down
3 changes: 2 additions & 1 deletion src/include/storage/table/arrow_rel_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class ArrowRelTable final : public ColumnarRelTableBase {
MemoryManager* memoryManager, const NodeTable* fromNodeTable, const NodeTable* toNodeTable,
ArrowRelTableLayout layout, ArrowSchemaWrapper schema,
std::vector<ArrowArrayWrapper> arrays, ArrowSchemaWrapper indptrSchema,
std::vector<ArrowArrayWrapper> indptrArrays, std::string arrowId);
std::vector<ArrowArrayWrapper> indptrArrays, std::string arrowId,
std::string dstColumnName = "to");
~ArrowRelTable();

void initScanState(transaction::Transaction* transaction, TableScanState& scanState,
Expand Down
1 change: 1 addition & 0 deletions src/include/storage/table/arrow_table_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct ArrowRelTableData {
std::vector<ArrowArrayWrapper> arrays;
ArrowSchemaWrapper indptrSchema;
std::vector<ArrowArrayWrapper> indptrArrays;
std::string dstColumnName = "to";
};

// Result of creating an arrow table view
Expand Down
3 changes: 2 additions & 1 deletion src/storage/storage_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ void StorageManager::addRelTable(RelGroupCatalogEntry* entry, const RelTableCata
tables[info.oid] = std::make_unique<ArrowRelTable>(entry, info.nodePair.srcTableID,
info.nodePair.dstTableID, this, &memoryManager, fromNodeTable, toNodeTable,
relData->layout, std::move(schemaCopy), std::move(arraysCopy),
std::move(indptrSchemaCopy), std::move(indptrArraysCopy), arrowId);
std::move(indptrSchemaCopy), std::move(indptrArraysCopy), arrowId,
relData->dstColumnName);
} else {
throw common::RuntimeException(
"Unsupported storage option for rel table: " + entry->getStorage());
Expand Down
11 changes: 6 additions & 5 deletions src/storage/table/arrow_rel_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ ArrowRelTable::ArrowRelTable(catalog::RelGroupCatalogEntry* relGroupEntry, table
const NodeTable* fromNodeTable, const NodeTable* toNodeTable, ArrowRelTableLayout layout,
ArrowSchemaWrapper schema, std::vector<ArrowArrayWrapper> arrays,
ArrowSchemaWrapper indptrSchema, std::vector<ArrowArrayWrapper> indptrArrays,
std::string arrowId)
std::string arrowId, std::string dstColumnName)
: ColumnarRelTableBase{relGroupEntry, fromTableID, toTableID, storageManager, memoryManager},
fromNodeTable{fromNodeTable}, toNodeTable{toNodeTable}, layout{layout},
schema{std::move(schema)}, arrays{std::move(arrays)}, indptrSchema{std::move(indptrSchema)},
Expand Down Expand Up @@ -100,14 +100,15 @@ ArrowRelTable::ArrowRelTable(catalog::RelGroupCatalogEntry* relGroupEntry, table
" must match destination node PK type " + dstPKType.toString());
}
} else {
csrNbrColumnIdx = findColumnIdx(this->schema, "to");
csrNbrColumnIdx = findColumnIdx(this->schema, dstColumnName);
if (csrNbrColumnIdx < 0) {
throw RuntimeException("Arrow CSR relationship table requires a 'to' column");
throw RuntimeException(
"Arrow CSR relationship table requires a '" + dstColumnName + "' column");
}
auto nbrArrowType = ArrowConverter::fromArrowSchema(this->schema.children[csrNbrColumnIdx]);
if (nbrArrowType.getLogicalTypeID() != LogicalTypeID::UINT64) {
throw RuntimeException("Arrow CSR 'to' column type " + nbrArrowType.toString() +
" must be UINT64 node offsets");
throw RuntimeException("Arrow CSR '" + dstColumnName + "' column type " +
nbrArrowType.toString() + " must be UINT64 node offsets");
}
if (!this->indptrSchema.format || this->indptrArrays.empty()) {
throw RuntimeException("Arrow CSR relationship table requires an indptr Arrow table");
Expand Down
1 change: 1 addition & 0 deletions src/storage/table/arrow_table_support.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ ArrowTableCreationResult ArrowTableSupport::createRelTableFromArrowCSR(main::Con
data.arrays = std::move(indicesArrays);
data.indptrSchema = std::move(indptrSchema);
data.indptrArrays = std::move(indptrArrays);
data.dstColumnName = dstColumnName;
std::string arrowId = registerArrowRelData(std::move(data));

std::string statement = "CREATE REL TABLE " + tableName + " " + tableDef +
Expand Down
31 changes: 31 additions & 0 deletions test/api/arrow_csr_rel_table_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,37 @@ TEST_F(ArrowCsrRelTableTest, CsrOverNativeNodeTableScans) {
ASSERT_EQ(countResult->getNext()->getValue(0)->getValue<int64_t>(), 1);
}

// Verifies that a CSR relationship table can be created when the destination
// offset column of the indices table carries a caller-chosen name ("dest")
// that is passed explicitly to createRelTableFromArrowCSR, and that both the
// edges and their "weight" property are visible to queries afterwards.
TEST_F(ArrowCsrRelTableTest, CustomDstColumnName) {
    std::vector<ArrowArrayWrapper> indicesBatches;
    std::vector<ArrowArrayWrapper> indptrBatches;

    // A single indices batch holding four edges: destination offsets plus a weight property.
    std::vector<uint64_t> dstOffsets = {1, 2, 2, 3};
    std::vector<int64_t> weights = {10, 20, 30, 40};
    indicesBatches.push_back(createStructArray(4,
        {[&](ArrowArray* out) { createUint64Array(out, dstOffsets); },
            [&](ArrowArray* out) { createInt64Array(out, weights); }}));
    indptrBatches.push_back(makeFwdIndptrArray());

    // Schema for the indices table; the first child is named "dest" rather than "to".
    ArrowSchemaWrapper indicesSchema;
    createStructSchema(&indicesSchema, 2);
    createSchema<uint64_t>(indicesSchema.children[0], "dest");
    createSchema<int64_t>(indicesSchema.children[1], "weight");

    // The final argument names the destination column that the table should look up.
    auto creation = ArrowTableSupport::createRelTableFromArrowCSR(*conn, "csr_knows",
        "csr_person", "csr_person", std::move(indicesSchema), std::move(indicesBatches),
        makeIndptrSchema(), std::move(indptrBatches), "dest");
    ASSERT_TRUE(creation.queryResult->isSuccess()) << creation.queryResult->getErrorMessage();

    // Every edge in the indices batch should be matched.
    auto edgeCount =
        conn->query("MATCH (:csr_person)-[:csr_knows]->(:csr_person) RETURN count(*)");
    ASSERT_TRUE(edgeCount->isSuccess()) << edgeCount->getErrorMessage();
    auto edgeCountRow = edgeCount->getNext();
    ASSERT_EQ(edgeCountRow->getValue(0)->getValue<int64_t>(), 4);

    // The property column must still be readable: 10 + 20 + 30 + 40.
    auto weightTotal =
        conn->query("MATCH (:csr_person)-[e:csr_knows]->(:csr_person) RETURN sum(e.weight)");
    ASSERT_TRUE(weightTotal->isSuccess()) << weightTotal->getErrorMessage();
    auto weightTotalRow = weightTotal->getNext();
    ASSERT_EQ(weightTotalRow->getValue(0)->getValue<common::int128_t>(), 100);
}

TEST_F(ArrowCsrRelTableTest, MultiBatchCsrIndices) {
std::vector<ArrowArrayWrapper> fwdIndices;
fwdIndices.push_back(
Expand Down
Loading