Skip to content

Commit 902efae

Browse files
committed
Avoid eager NumPy checks for prepared parameters
1 parent 0d9f55c commit 902efae

2 files changed

Lines changed: 22 additions & 4 deletions

File tree

src_cpp/py_connection.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,15 @@ static LogicalType pyNumpyArrayLogicalType(const py::array& arr) {
434434
return type;
435435
}
436436

437+
static bool hasNumpyTypeModule(const py::handle& val) {
438+
auto module = py::str(val.get_type().attr("__module__")).cast<std::string>();
439+
return module == "numpy" || module.starts_with("numpy.");
440+
}
441+
442+
static bool isNumpyArray(const py::handle& val) {
443+
return hasNumpyTypeModule(val) && py::isinstance<py::array>(val);
444+
}
445+
437446
static LogicalType pyLogicalType(const py::handle& val) {
438447
auto datetime_datetime = importCache->datetime.datetime();
439448
auto time_delta = importCache->datetime.timedelta();
@@ -514,7 +523,7 @@ static LogicalType pyLogicalType(const py::handle& val) {
514523
childValueType = std::move(resultValue);
515524
}
516525
return LogicalType::MAP(std::move(childKeyType), std::move(childValueType));
517-
} else if (py::isinstance<py::array>(val)) {
526+
} else if (isNumpyArray(val)) {
518527
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
519528
} else if (py::isinstance<py::list>(val)) {
520529
py::list lst = py::reinterpret_borrow<py::list>(val);
@@ -620,7 +629,7 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
620629
structFields.emplace_back(std::move(keyName), std::move(keyType));
621630
}
622631
return LogicalType::STRUCT(std::move(structFields));
623-
} else if (py::isinstance<py::array>(val)) {
632+
} else if (isNumpyArray(val)) {
624633
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
625634
} else if (py::isinstance<py::list>(val)) {
626635
py::list lst = py::reinterpret_borrow<py::list>(val);
@@ -852,7 +861,7 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
852861
return Value{uuidToAppend};
853862
}
854863
case LogicalTypeID::LIST: {
855-
if (py::isinstance<py::array>(val)) {
864+
if (isNumpyArray(val)) {
856865
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
857866
}
858867
py::list lst = py::reinterpret_borrow<py::list>(val);
@@ -910,7 +919,7 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
910919
auto jsonStr = pythonObjectToJsonString(val);
911920
return Value::createValue<std::string>(jsonStr);
912921
}
913-
if (py::isinstance<py::array>(val)) {
922+
if (isNumpyArray(val)) {
914923
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
915924
}
916925
py::list lst = py::reinterpret_borrow<py::list>(val);

src_py/_lbug_capi.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,12 +971,21 @@ def _numpy_module() -> Any | None:
971971
return np
972972

973973

974+
def _has_numpy_type_module(value: Any) -> bool:
975+
module = type(value).__module__
976+
return module == "numpy" or module.startswith("numpy.")
977+
978+
974979
def _is_numpy_scalar(value: Any) -> bool:
980+
if not _has_numpy_type_module(value):
981+
return False
975982
np = _numpy_module()
976983
return bool(np is not None and isinstance(value, np.generic))
977984

978985

979986
def _is_numpy_array(value: Any) -> bool:
987+
if not _has_numpy_type_module(value):
988+
return False
980989
np = _numpy_module()
981990
return bool(np is not None and isinstance(value, np.ndarray))
982991

0 commit comments

Comments
 (0)