diff --git a/src_cpp/py_connection.cpp b/src_cpp/py_connection.cpp index 660196c..7c7667a 100644 --- a/src_cpp/py_connection.cpp +++ b/src_cpp/py_connection.cpp @@ -434,6 +434,15 @@ static LogicalType pyNumpyArrayLogicalType(const py::array& arr) { return type; } +static bool hasNumpyTypeModule(const py::handle& val) { + auto module = py::str(val.get_type().attr("__module__")).cast(); + return module == "numpy" || module.starts_with("numpy."); +} + +static bool isNumpyArray(const py::handle& val) { + return hasNumpyTypeModule(val) && py::isinstance(val); +} + static LogicalType pyLogicalType(const py::handle& val) { auto datetime_datetime = importCache->datetime.datetime(); auto time_delta = importCache->datetime.timedelta(); @@ -514,7 +523,7 @@ static LogicalType pyLogicalType(const py::handle& val) { childValueType = std::move(resultValue); } return LogicalType::MAP(std::move(childKeyType), std::move(childValueType)); - } else if (py::isinstance(val)) { + } else if (isNumpyArray(val)) { return pyNumpyArrayLogicalType(py::reinterpret_borrow(val)); } else if (py::isinstance(val)) { py::list lst = py::reinterpret_borrow(val); @@ -620,7 +629,7 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) { structFields.emplace_back(std::move(keyName), std::move(keyType)); } return LogicalType::STRUCT(std::move(structFields)); - } else if (py::isinstance(val)) { + } else if (isNumpyArray(val)) { return pyNumpyArrayLogicalType(py::reinterpret_borrow(val)); } else if (py::isinstance(val)) { py::list lst = py::reinterpret_borrow(val); @@ -852,7 +861,7 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT return Value{uuidToAppend}; } case LogicalTypeID::LIST: { - if (py::isinstance(val)) { + if (isNumpyArray(val)) { return transformNumpyArrayAs(py::reinterpret_borrow(val), type); } py::list lst = py::reinterpret_borrow(val); @@ -910,7 +919,7 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val, auto jsonStr = pythonObjectToJsonString(val); return Value::createValue(jsonStr); } - if (py::isinstance(val)) { + if (isNumpyArray(val)) { return transformNumpyArrayAs(py::reinterpret_borrow(val), type); } py::list lst = py::reinterpret_borrow(val); diff --git a/src_py/_lbug_capi.py b/src_py/_lbug_capi.py index c20c4b3..7978be8 100644 --- a/src_py/_lbug_capi.py +++ b/src_py/_lbug_capi.py @@ -971,12 +971,21 @@ def _numpy_module() -> Any | None: return np +def _has_numpy_type_module(value: Any) -> bool: + module = type(value).__module__ + return module == "numpy" or module.startswith("numpy.") + + def _is_numpy_scalar(value: Any) -> bool: + if not _has_numpy_type_module(value): + return False np = _numpy_module() return bool(np is not None and isinstance(value, np.generic)) def _is_numpy_array(value: Any) -> bool: + if not _has_numpy_type_module(value): + return False np = _numpy_module() return bool(np is not None and isinstance(value, np.ndarray))