Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,15 @@ static LogicalType pyNumpyArrayLogicalType(const py::array& arr) {
return type;
}

static bool hasNumpyTypeModule(const py::handle& val) {
auto module = py::str(val.get_type().attr("__module__")).cast<std::string>();
return module == "numpy" || module.starts_with("numpy.");
}

static bool isNumpyArray(const py::handle& val) {
return hasNumpyTypeModule(val) && py::isinstance<py::array>(val);
}

static LogicalType pyLogicalType(const py::handle& val) {
auto datetime_datetime = importCache->datetime.datetime();
auto time_delta = importCache->datetime.timedelta();
Expand Down Expand Up @@ -514,7 +523,7 @@ static LogicalType pyLogicalType(const py::handle& val) {
childValueType = std::move(resultValue);
}
return LogicalType::MAP(std::move(childKeyType), std::move(childValueType));
} else if (py::isinstance<py::array>(val)) {
} else if (isNumpyArray(val)) {
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
} else if (py::isinstance<py::list>(val)) {
py::list lst = py::reinterpret_borrow<py::list>(val);
Expand Down Expand Up @@ -620,7 +629,7 @@ static LogicalType pyLogicalTypeFromParameter(const py::handle& val) {
structFields.emplace_back(std::move(keyName), std::move(keyType));
}
return LogicalType::STRUCT(std::move(structFields));
} else if (py::isinstance<py::array>(val)) {
} else if (isNumpyArray(val)) {
return pyNumpyArrayLogicalType(py::reinterpret_borrow<py::array>(val));
} else if (py::isinstance<py::list>(val)) {
py::list lst = py::reinterpret_borrow<py::list>(val);
Expand Down Expand Up @@ -852,7 +861,7 @@ Value PyConnection::transformPythonValueAs(const py::handle& val, const LogicalT
return Value{uuidToAppend};
}
case LogicalTypeID::LIST: {
if (py::isinstance<py::array>(val)) {
if (isNumpyArray(val)) {
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
}
py::list lst = py::reinterpret_borrow<py::list>(val);
Expand Down Expand Up @@ -910,7 +919,7 @@ Value PyConnection::transformPythonValueFromParameterAs(const py::handle& val,
auto jsonStr = pythonObjectToJsonString(val);
return Value::createValue<std::string>(jsonStr);
}
if (py::isinstance<py::array>(val)) {
if (isNumpyArray(val)) {
return transformNumpyArrayAs(py::reinterpret_borrow<py::array>(val), type);
}
py::list lst = py::reinterpret_borrow<py::list>(val);
Expand Down
9 changes: 9 additions & 0 deletions src_py/_lbug_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,12 +971,21 @@ def _numpy_module() -> Any | None:
return np


def _has_numpy_type_module(value: Any) -> bool:
module = type(value).__module__
return module == "numpy" or module.startswith("numpy.")


def _is_numpy_scalar(value: Any) -> bool:
if not _has_numpy_type_module(value):
return False
np = _numpy_module()
return bool(np is not None and isinstance(value, np.generic))


def _is_numpy_array(value: Any) -> bool:
if not _has_numpy_type_module(value):
return False
np = _numpy_module()
return bool(np is not None and isinstance(value, np.ndarray))

Expand Down
Loading