Skip to content

Commit ead4266

Browse files
committed
add arrow csr dst col name param
1 parent e401748 commit ead4266

5 files changed

Lines changed: 74 additions & 7 deletions

File tree

src_cpp/include/py_connection.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class PyConnection {
5757
py::object arrowTable);
5858
std::unique_ptr<PyQueryResult> createArrowRelTable(const std::string& tableName,
5959
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName,
60-
const std::string& layout, py::object indptrTable);
60+
const std::string& layout, py::object indptrTable, const std::string& dstColName = "to");
6161
std::unique_ptr<PyQueryResult> dropArrowTable(const std::string& tableName);
6262

6363
static Value transformPythonValue(const py::handle& val);

src_cpp/py_connection.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ void PyConnection::initialize(py::handle& m) {
5555
py::arg("arrow_table"))
5656
.def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"),
5757
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"),
58-
py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none())
58+
py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none(),
59+
py::arg("dst_col_name") = "to")
5960
.def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name"));
6061
PyDateTime_IMPORT;
6162
}
@@ -1070,7 +1071,7 @@ std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string&
10701071

10711072
std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::string& tableName,
10721073
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName,
1073-
const std::string& layout, py::object indptrTable) {
1074+
const std::string& layout, py::object indptrTable, const std::string& dstColName) {
10741075
auto& stateRef = refState();
10751076
py::gil_scoped_acquire acquire;
10761077

@@ -1097,7 +1098,7 @@ std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::stri
10971098
keepAlive.append(exportedIndptr.keepAlive);
10981099
result = ArrowTableSupport::createRelTableFromArrowCSR(stateRef.ref(), tableName,
10991100
srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays),
1100-
std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays));
1101+
std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays), dstColName);
11011102
} else {
11021103
throw RuntimeException("Arrow relationship table layout must be FLAT or CSR");
11031104
}

src_py/_lbug_capi.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def _setup_signatures() -> None:
343343
ctypes.POINTER(_ArrowSchema),
344344
ctypes.POINTER(_ArrowArray),
345345
ctypes.c_uint64,
346+
ctypes.c_char_p,
346347
ctypes.POINTER(_LbugQueryResult),
347348
]
348349
_LIB.lbug_connection_create_arrow_rel_table_csr.restype = ctypes.c_int
@@ -2341,6 +2342,7 @@ def create_arrow_rel_table(
23412342
dst_table_name: str,
23422343
layout: Any = "FLAT",
23432344
indptr_dataframe: Any | None = None,
2345+
dst_col_name: str = "to",
23442346
) -> QueryResult:
23452347
layout_value = getattr(layout, "value", layout)
23462348
layout_value = str(layout_value).upper()
@@ -2385,6 +2387,7 @@ def create_arrow_rel_table(
23852387
ctypes.byref(indptr_schema),
23862388
indptr_arrays,
23872389
len(indptr_arrays),
2390+
dst_col_name.encode("utf-8"),
23882391
ctypes.byref(result),
23892392
)
23902393
if state != _LBUG_SUCCESS and not result._query_result:

src_py/connection.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,7 @@ def create_arrow_rel_table(
887887
dst_table_name: str,
888888
layout: ArrowRelTableLayout | str = ArrowRelTableLayout.FLAT,
889889
indptr_dataframe: Any | None = None,
890+
dst_col_name: str = "to",
890891
) -> QueryResult:
891892
"""
892893
Create an Arrow memory-backed relationship table from a DataFrame.
@@ -908,13 +909,17 @@ def create_arrow_rel_table(
908909
layout : ArrowRelTableLayout | str
909910
Relationship layout. FLAT expects ``dataframe`` to contain ``from``
910911
and ``to`` endpoint columns. CSR expects ``dataframe`` to contain a
911-
``to`` destination offset column plus properties, and
912-
``indptr_dataframe`` to contain source offsets.
912+
destination offset column (named by ``dst_col_name``) plus
913+
properties, and ``indptr_dataframe`` to contain source offsets.
913914
914915
indptr_dataframe : Any | None
915916
A pandas DataFrame, polars DataFrame, or PyArrow table containing
916917
CSR source offsets. Required when ``layout`` is CSR.
917918
919+
dst_col_name : str
920+
Name of the destination offset column in the CSR indices table.
921+
Defaults to ``"to"``. Only used when ``layout`` is CSR.
922+
918923
Returns
919924
-------
920925
QueryResult
@@ -936,6 +941,7 @@ def create_arrow_rel_table(
936941
dst_table_name,
937942
layout_value,
938943
indptr_dataframe,
944+
dst_col_name,
939945
)
940946
except NotImplementedError:
941947
py_connection = self._get_pybind_connection()
@@ -949,6 +955,7 @@ def create_arrow_rel_table(
949955
dst_table_name,
950956
layout_value,
951957
indptr_dataframe,
958+
dst_col_name,
952959
)
953960
if not query_result_internal.isSuccess():
954961
raise RuntimeError(query_result_internal.getErrorMessage())

test/test_arrow_memory_backed_table.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,63 @@ def test_arrow_memory_backed_csr_arrow_rel_table(conn_db_empty: ConnDB) -> None:
401401
conn.drop_arrow_table("arrow_csr_people")
402402

403403

404-
def test_arrow_memory_backed_native_node_and_arrow_rel_table(
404+
def test_arrow_memory_backed_csr_rel_table_custom_dst_col(
405+
conn_db_empty: ConnDB,
406+
) -> None:
407+
"""Test Arrow CSR relationship table with a custom destination column name."""
408+
conn, _ = conn_db_empty
409+
410+
import ladybug as lb
411+
412+
pa = pytest.importorskip("pyarrow")
413+
414+
people = pa.Table.from_arrays(
415+
[pa.array([1, 2, 3], type=pa.int64())],
416+
names=["id"],
417+
)
418+
conn.create_arrow_table("csr_custom_dst_people", people)
419+
420+
# Use "destination" instead of the default "to"
421+
indices = pa.Table.from_arrays(
422+
[
423+
pa.array([1, 2, 2], type=pa.uint64()),
424+
pa.array([10, 20, 30], type=pa.int64()),
425+
],
426+
names=["destination", "weight"],
427+
)
428+
indptr = pa.Table.from_arrays(
429+
[pa.array([0, 2, 3, 3], type=pa.uint64())],
430+
names=["indptr"],
431+
)
432+
conn.create_arrow_rel_table(
433+
"csr_custom_dst_knows",
434+
indices,
435+
"csr_custom_dst_people",
436+
"csr_custom_dst_people",
437+
layout=lb.ArrowRelTableLayout.CSR,
438+
indptr_dataframe=indptr,
439+
dst_col_name="destination",
440+
)
441+
442+
result = conn.execute(
443+
"MATCH (a:csr_custom_dst_people)-[r:csr_custom_dst_knows]->(b:csr_custom_dst_people) "
444+
"RETURN a.id, b.id, r.weight ORDER BY a.id, b.id"
445+
)
446+
rows = []
447+
while result.has_next():
448+
rows.append(result.get_next())
449+
450+
assert rows == [
451+
[1, 2, 10],
452+
[1, 3, 20],
453+
[2, 3, 30],
454+
]
455+
456+
conn.drop_arrow_table("csr_custom_dst_knows")
457+
conn.drop_arrow_table("csr_custom_dst_people")
458+
459+
460+
def test_arrow_memory_backed_rel_table_over_native_node_tables(
405461
conn_db_empty: ConnDB,
406462
) -> None:
407463
"""Test an Arrow memory-backed relationship over native node tables."""

0 commit comments

Comments
 (0)