From 04b50857ca6a128f86f60d2b7d04fdccda4b27b8 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Thu, 27 Mar 2025 13:37:16 -0400 Subject: [PATCH] fix: use simple function pointer to avoid potential PyCapsule change in pybind11 Signed-off-by: Henry Schreiner --- include/bh_python/transform.hpp | 33 ++------------------------- src/boost_histogram/axis/transform.py | 7 +++++- src/register_transforms.cpp | 8 +++---- 3 files changed, 12 insertions(+), 36 deletions(-) diff --git a/include/bh_python/transform.hpp b/include/bh_python/transform.hpp index 17f8e483..2d05a4be 100644 --- a/include/bh_python/transform.hpp +++ b/include/bh_python/transform.hpp @@ -58,39 +58,10 @@ struct func_transform { return std::make_tuple(ptr, src); } - // If we made it to this point, we probably have a C++ pybind object or an - // invalid object. The following is based on the std::function conversion in - // pybind11/functional.hpp - if(!py::isinstance(src)) - throw py::type_error("Only ctypes double(double) and C++ functions allowed " - "(must be function)"); - - auto func = py::reinterpret_borrow(src); - - if(auto cfunc = func.cpp_function()) { - auto c = py::reinterpret_borrow( - PyCFunction_GET_SELF(cfunc.ptr())); - - auto rec = c.get_pointer(); - - if(rec && rec->is_stateless - && py::detail::same_type( - typeid(raw_t*), - *reinterpret_cast(rec->data[1]))) { - struct capture { - raw_t* f; - }; - return std::make_tuple((reinterpret_cast(&rec->data))->f, - src); - } - - // Note that each error is slightly different just to help with debugging - throw py::type_error("Only ctypes double(double) and C++ functions allowed " - "(must be stateless)"); - } + // If we made it to this point, we probably have an invalid object. throw py::type_error("Only ctypes double(double) and C++ functions allowed " - "(must be cpp function)"); + "(must be a stateless cpp function)"); } func_transform(py::object f, py::object i, py::object c, py::str n) diff --git a/src/boost_histogram/axis/transform.py b/src/boost_histogram/axis/transform.py index 20dca369..8afd396a 100644 --- a/src/boost_histogram/axis/transform.py +++ b/src/boost_histogram/axis/transform.py @@ -1,10 +1,12 @@ from __future__ import annotations import copy +import ctypes from typing import Any, ClassVar, TypeVar import boost_histogram +from .. import _core from .._core import axis as ca from .._utils import register @@ -12,6 +14,8 @@ __all__ = ["AxisTransform", "Function", "Pow", "log", "sqrt"] +LIB = ctypes.CDLL(_core.__file__) + def __dir__() -> list[str]: return __all__ @@ -150,7 +154,8 @@ def _produce(self, bins: int, start: float, stop: float) -> Any: def _internal_conversion(name: str) -> Any: - return getattr(ca.transform, name) + ftype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double) + return ctypes.cast(getattr(LIB, name), ftype) sqrt = Function("_sqrt_fn", "_sq_fn", convert=_internal_conversion, name="sqrt") diff --git a/src/register_transforms.cpp b/src/register_transforms.cpp index 7abe758b..a7a37d84 100644 --- a/src/register_transforms.cpp +++ b/src/register_transforms.cpp @@ -32,10 +32,10 @@ py::class_ register_transform(py::module& mod, Args&&... args) { } extern "C" { -double _log_fn(double v) { return std::log(v); } -double _exp_fn(double v) { return std::exp(v); } -double _sqrt_fn(double v) { return std::sqrt(v); } -double _sq_fn(double v) { return v * v; } +PYBIND11_EXPORT double _log_fn(double v) { return std::log(v); } +PYBIND11_EXPORT double _exp_fn(double v) { return std::exp(v); } +PYBIND11_EXPORT double _sqrt_fn(double v) { return std::sqrt(v); } +PYBIND11_EXPORT double _sq_fn(double v) { return v * v; } } void register_transforms(py::module& mod) {