diff --git a/src/c_api/connection.cpp b/src/c_api/connection.cpp index c335e8278..c7a808dd6 100644 --- a/src/c_api/connection.cpp +++ b/src/c_api/connection.cpp @@ -276,7 +276,7 @@ lbug_state lbug_connection_create_arrow_rel_table_csr(lbug_connection* connectio const char* table_name, const char* src_table_name, const char* dst_table_name, ArrowSchema* indices_schema, ArrowArray* indices_arrays, uint64_t num_indices_arrays, ArrowSchema* indptr_schema, ArrowArray* indptr_arrays, uint64_t num_indptr_arrays, - lbug_query_result* out_query_result) { + const char* dst_col_name, lbug_query_result* out_query_result) { if (connection == nullptr || connection->_connection == nullptr || table_name == nullptr || src_table_name == nullptr || dst_table_name == nullptr || indices_schema == nullptr || indices_arrays == nullptr || indptr_schema == nullptr || indptr_arrays == nullptr || @@ -289,7 +289,8 @@ lbug_state lbug_connection_create_arrow_rel_table_csr(lbug_connection* connectio *static_cast(connection->_connection), table_name, src_table_name, dst_table_name, takeArrowSchema(indices_schema), takeArrowArrays(indices_arrays, num_indices_arrays), takeArrowSchema(indptr_schema), - takeArrowArrays(indptr_arrays, num_indptr_arrays)); + takeArrowArrays(indptr_arrays, num_indptr_arrays), + dst_col_name != nullptr ? dst_col_name : "to"); auto state = setQueryResult(std::move(result.queryResult), out_query_result); if (state == LbugSuccess) { rememberArrowTableID(static_cast(connection->_connection), table_name, diff --git a/src/include/c_api/lbug.h b/src/include/c_api/lbug.h index dde510ab5..af186b21b 100644 --- a/src/include/c_api/lbug.h +++ b/src/include/c_api/lbug.h @@ -440,16 +440,18 @@ LBUG_C_API lbug_state lbug_connection_create_arrow_rel_table(lbug_connection* co /** * @brief Creates a CSR Arrow memory-backed relationship table from Arrow C Data Interface data. 
 * - * The indices Arrow table must contain a destination offset column named "to" and any relationship - * property columns. The indptr Arrow table must contain one offset column. Ownership of schemas and - * arrays is transferred to lbug on success or failure. The caller must not release them after this - * call. + * The indices Arrow table must contain a destination offset column and any relationship property + * columns. The indptr Arrow table must contain one offset column. Ownership of schemas and arrays + * is transferred to lbug on success or failure. The caller must not release them after this call. + * + * @param dst_col_name Name of the destination offset column (UINT64 node offsets) in the indices + * table. If NULL, defaults to "to". */ LBUG_C_API lbug_state lbug_connection_create_arrow_rel_table_csr(lbug_connection* connection, const char* table_name, const char* src_table_name, const char* dst_table_name, struct ArrowSchema* indices_schema, struct ArrowArray* indices_arrays, uint64_t num_indices_arrays, struct ArrowSchema* indptr_schema, - struct ArrowArray* indptr_arrays, uint64_t num_indptr_arrays, + struct ArrowArray* indptr_arrays, uint64_t num_indptr_arrays, const char* dst_col_name, lbug_query_result* out_query_result); /** * @brief Drops an Arrow memory-backed table. 
diff --git a/src/include/storage/table/arrow_rel_table.h b/src/include/storage/table/arrow_rel_table.h index 9e7c03d33..7f6d8c928 100644 --- a/src/include/storage/table/arrow_rel_table.h +++ b/src/include/storage/table/arrow_rel_table.h @@ -33,7 +33,8 @@ class ArrowRelTable final : public ColumnarRelTableBase { MemoryManager* memoryManager, const NodeTable* fromNodeTable, const NodeTable* toNodeTable, ArrowRelTableLayout layout, ArrowSchemaWrapper schema, std::vector arrays, ArrowSchemaWrapper indptrSchema, - std::vector indptrArrays, std::string arrowId); + std::vector indptrArrays, std::string arrowId, + std::string dstColumnName = "to"); ~ArrowRelTable(); void initScanState(transaction::Transaction* transaction, TableScanState& scanState, diff --git a/src/include/storage/table/arrow_table_support.h b/src/include/storage/table/arrow_table_support.h index 740215fa5..a35ade080 100644 --- a/src/include/storage/table/arrow_table_support.h +++ b/src/include/storage/table/arrow_table_support.h @@ -19,6 +19,7 @@ struct ArrowRelTableData { std::vector arrays; ArrowSchemaWrapper indptrSchema; std::vector indptrArrays; + std::string dstColumnName = "to"; }; // Result of creating an arrow table view diff --git a/src/storage/storage_manager.cpp b/src/storage/storage_manager.cpp index 14af10228..c31aa6a5f 100644 --- a/src/storage/storage_manager.cpp +++ b/src/storage/storage_manager.cpp @@ -200,7 +200,8 @@ void StorageManager::addRelTable(RelGroupCatalogEntry* entry, const RelTableCata tables[info.oid] = std::make_unique(entry, info.nodePair.srcTableID, info.nodePair.dstTableID, this, &memoryManager, fromNodeTable, toNodeTable, relData->layout, std::move(schemaCopy), std::move(arraysCopy), - std::move(indptrSchemaCopy), std::move(indptrArraysCopy), arrowId); + std::move(indptrSchemaCopy), std::move(indptrArraysCopy), arrowId, + relData->dstColumnName); } else { throw common::RuntimeException( "Unsupported storage option for rel table: " + entry->getStorage()); diff --git 
a/src/storage/table/arrow_rel_table.cpp b/src/storage/table/arrow_rel_table.cpp index e0f65d26c..327443eaf 100644 --- a/src/storage/table/arrow_rel_table.cpp +++ b/src/storage/table/arrow_rel_table.cpp @@ -64,7 +64,7 @@ ArrowRelTable::ArrowRelTable(catalog::RelGroupCatalogEntry* relGroupEntry, table const NodeTable* fromNodeTable, const NodeTable* toNodeTable, ArrowRelTableLayout layout, ArrowSchemaWrapper schema, std::vector arrays, ArrowSchemaWrapper indptrSchema, std::vector indptrArrays, - std::string arrowId) + std::string arrowId, std::string dstColumnName) : ColumnarRelTableBase{relGroupEntry, fromTableID, toTableID, storageManager, memoryManager}, fromNodeTable{fromNodeTable}, toNodeTable{toNodeTable}, layout{layout}, schema{std::move(schema)}, arrays{std::move(arrays)}, indptrSchema{std::move(indptrSchema)}, @@ -100,14 +100,15 @@ ArrowRelTable::ArrowRelTable(catalog::RelGroupCatalogEntry* relGroupEntry, table " must match destination node PK type " + dstPKType.toString()); } } else { - csrNbrColumnIdx = findColumnIdx(this->schema, "to"); + csrNbrColumnIdx = findColumnIdx(this->schema, dstColumnName); if (csrNbrColumnIdx < 0) { - throw RuntimeException("Arrow CSR relationship table requires a 'to' column"); + throw RuntimeException( + "Arrow CSR relationship table requires a '" + dstColumnName + "' column"); } auto nbrArrowType = ArrowConverter::fromArrowSchema(this->schema.children[csrNbrColumnIdx]); if (nbrArrowType.getLogicalTypeID() != LogicalTypeID::UINT64) { - throw RuntimeException("Arrow CSR 'to' column type " + nbrArrowType.toString() + - " must be UINT64 node offsets"); + throw RuntimeException("Arrow CSR '" + dstColumnName + "' column type " + + nbrArrowType.toString() + " must be UINT64 node offsets"); } if (!this->indptrSchema.format || this->indptrArrays.empty()) { throw RuntimeException("Arrow CSR relationship table requires an indptr Arrow table"); diff --git a/src/storage/table/arrow_table_support.cpp 
b/src/storage/table/arrow_table_support.cpp index 5a09ebf17..504c7a79e 100644 --- a/src/storage/table/arrow_table_support.cpp +++ b/src/storage/table/arrow_table_support.cpp @@ -233,6 +233,7 @@ ArrowTableCreationResult ArrowTableSupport::createRelTableFromArrowCSR(main::Con data.arrays = std::move(indicesArrays); data.indptrSchema = std::move(indptrSchema); data.indptrArrays = std::move(indptrArrays); + data.dstColumnName = dstColumnName; std::string arrowId = registerArrowRelData(std::move(data)); std::string statement = "CREATE REL TABLE " + tableName + " " + tableDef + diff --git a/test/api/arrow_csr_rel_table_test.cpp b/test/api/arrow_csr_rel_table_test.cpp index 3021f8e69..ea9bf58fa 100644 --- a/test/api/arrow_csr_rel_table_test.cpp +++ b/test/api/arrow_csr_rel_table_test.cpp @@ -327,6 +327,37 @@ TEST_F(ArrowCsrRelTableTest, CsrOverNativeNodeTableScans) { ASSERT_EQ(countResult->getNext()->getValue(0)->getValue(), 1); } +TEST_F(ArrowCsrRelTableTest, CustomDstColumnName) { + std::vector fwdIndices, fwdIndptr; + + // Build indices table with column named "dest" instead of "to" + std::vector dst = {1, 2, 2, 3}; + std::vector w = {10, 20, 30, 40}; + fwdIndices.push_back(createStructArray(4, {[&](ArrowArray* a) { createUint64Array(a, dst); }, + [&](ArrowArray* a) { createInt64Array(a, w); }})); + fwdIndptr.push_back(makeFwdIndptrArray()); + + ArrowSchemaWrapper idxSchema; + createStructSchema(&idxSchema, 2); + createSchema(idxSchema.children[0], "dest"); + createSchema(idxSchema.children[1], "weight"); + + auto result = ArrowTableSupport::createRelTableFromArrowCSR(*conn, "csr_knows", "csr_person", + "csr_person", std::move(idxSchema), std::move(fwdIndices), makeIndptrSchema(), + std::move(fwdIndptr), "dest"); + ASSERT_TRUE(result.queryResult->isSuccess()) << result.queryResult->getErrorMessage(); + + auto countResult = + conn->query("MATCH (:csr_person)-[:csr_knows]->(:csr_person) RETURN count(*)"); + ASSERT_TRUE(countResult->isSuccess()) << 
countResult->getErrorMessage(); + ASSERT_EQ(countResult->getNext()->getValue(0)->getValue(), 4); + + auto sumResult = + conn->query("MATCH (:csr_person)-[e:csr_knows]->(:csr_person) RETURN sum(e.weight)"); + ASSERT_TRUE(sumResult->isSuccess()) << sumResult->getErrorMessage(); + ASSERT_EQ(sumResult->getNext()->getValue(0)->getValue(), 100); +} + TEST_F(ArrowCsrRelTableTest, MultiBatchCsrIndices) { std::vector fwdIndices; fwdIndices.push_back(