diff --git a/src/py.rs b/src/py.rs index 6de871d..1f646f0 100644 --- a/src/py.rs +++ b/src/py.rs @@ -6,13 +6,13 @@ //! Python exception types raised by `adaptive.learner.triangulation` in the //! same situations. -use numpy::PyArray2; +use numpy::{PyArray2, PyReadonlyArray1, PyReadonlyArray2}; use pyo3::exceptions::{ PyAssertionError, PyIndexError, PyNotImplementedError, PyRuntimeError, PyTypeError, PyValueError, PyZeroDivisionError, }; use pyo3::prelude::*; -use pyo3::types::{PyAny, PyDict, PyList, PyModule, PySet, PyTuple}; +use pyo3::types::{PyAny, PyList, PyModule, PySet, PyTuple}; use rustc_hash::FxHashSet; use crate::geometry::GeometryError; @@ -34,6 +34,10 @@ impl TriangulationError { } pub(crate) fn parse_point(obj: &Bound<'_, PyAny>) -> PyResult> { + // Bulk-copy f64 numpy arrays instead of iterating Python objects. + if let Ok(array) = obj.extract::>() { + return Ok(array.as_array().to_vec()); + } let Ok(iter) = obj.try_iter() else { return Err(PyTypeError::new_err("Expected an iterable of floats")); }; @@ -62,6 +66,15 @@ fn parse_points_impl( type_error_message: &str, require_sized: bool, ) -> PyResult>> { + // Bulk-copy f64 numpy arrays instead of iterating Python objects row by + // row; other dtypes and nested sequences take the generic path below. + if let Ok(array) = obj.extract::>() { + return Ok(array + .as_array() + .outer_iter() + .map(|row| row.to_vec()) + .collect()); + } if require_sized { ensure_sized(obj, type_error_message)?; } @@ -217,14 +230,6 @@ pub(crate) fn simplex_set_py<'a>( Ok(PySet::new(py, &tuples)?.into()) } -pub(crate) fn point_list_py(py: Python<'_>, points: &[Vec]) -> PyResult> { - let tuples: Vec> = points - .iter() - .map(|point| point_tuple(py, point).into()) - .collect(); - Ok(PyList::new(py, tuples)?.into()) -} - pub(crate) fn index_list_py(py: Python<'_>, indices: &[usize]) -> PyResult> { Ok(PyList::new(py, indices.iter().copied())?.into()) } @@ -315,6 +320,16 @@ impl PySimplicesProxy { vertex: Some(vertex), } } + + /// The current simplices as a real Python set, for the set-operator + /// protocol below. + fn snapshot_set(&self, py: Python<'_>) -> PyResult> { + let triangulation = self.triangulation.bind(py).borrow(); + match self.vertex { + Some(vertex) => simplex_set_py(py, triangulation.core.simplices_of(vertex)), + None => simplex_set_py(py, triangulation.core.simplices()), + } + } } /// Lazy, sequence-like view of the vertex coordinates (supports `__array__` @@ -406,6 +421,66 @@ impl PySimplicesProxy { None => triangulation.core.num_simplices(), } } + + // The reference exposes `simplices` as a real set, so callers combine it + // with sets freely (adaptive does `... - tri.simplices`). Delegate the + // binary set operators to a snapshot set, in both operand orders. + + fn __sub__(&self, py: Python<'_>, other: &Bound<'_, PyAny>) -> PyResult> { + Ok(self + .snapshot_set(py)? + .bind(py) + .call_method1("__sub__", (other,))? + .unbind()) + } + + fn __rsub__(&self, py: Python<'_>, other: &Bound<'_, PyAny>) -> PyResult> { + Ok(other + .call_method1("__sub__", (self.snapshot_set(py)?,))? + .unbind()) + } + + fn __and__(&self, py: Python<'_>, other: &Bound<'_, PyAny>) -> PyResult> { + Ok(self + .snapshot_set(py)? + .bind(py) + .call_method1("__and__", (other,))? + .unbind()) + } + + fn __rand__(&self, py: Python<'_>, other: &Bound<'_, PyAny>) -> PyResult> { + Ok(other + .call_method1("__and__", (self.snapshot_set(py)?,))? + .unbind()) + } + + fn __or__(&self, py: Python<'_>, other: &Bound<'_, PyAny>) -> PyResult> { + Ok(self + .snapshot_set(py)? + .bind(py) + .call_method1("__or__", (other,))? + .unbind()) + } + + fn __ror__(&self, py: Python<'_>, other: &Bound<'_, PyAny>) -> PyResult> { + Ok(other + .call_method1("__or__", (self.snapshot_set(py)?,))? + .unbind()) + } + + fn __xor__(&self, py: Python<'_>, other: &Bound<'_, PyAny>) -> PyResult> { + Ok(self + .snapshot_set(py)? + .bind(py) + .call_method1("__xor__", (other,))? + .unbind()) + } + + fn __rxor__(&self, py: Python<'_>, other: &Bound<'_, PyAny>) -> PyResult> { + Ok(other + .call_method1("__xor__", (self.snapshot_set(py)?,))? + .unbind()) + } } #[pymethods] @@ -466,22 +541,16 @@ impl PyVerticesProxy { dtype: Option<&Bound<'_, PyAny>>, copy: Option, ) -> PyResult> { + // The snapshot is always a freshly allocated array, so the `copy` + // argument needs no special handling. Build it directly instead of + // round-tripping through a list of Python tuples. + let _ = copy; let triangulation = self.triangulation.bind(py).borrow(); - let vertices = point_list_py(py, &triangulation.core.vertices)?; - let numpy = PyModule::import(py, "numpy")?; - let kwargs = PyDict::new(py); - if let Some(dtype) = dtype { - kwargs.set_item("dtype", dtype)?; + let array = PyArray2::from_vec2(py, &triangulation.core.vertices)?; + match dtype { + Some(dtype) if !dtype.is_none() => Ok(array.call_method1("astype", (dtype,))?.unbind()), + _ => Ok(array.into_any().unbind()), } - let array = if copy == Some(false) { - numpy.call_method("asarray", (vertices,), Some(&kwargs))? - } else { - if let Some(copy) = copy { - kwargs.set_item("copy", copy)?; - } - numpy.call_method("array", (vertices,), Some(&kwargs))? - }; - Ok(array.into()) } } @@ -609,20 +678,32 @@ impl PyTriangulation { self.core.dim } + /// `None` entries pass through as `None`, like the reference (which maps + /// every index through `get_vertex`); adaptive relies on this when it + /// feeds the result of `get_opposing_vertices` straight back in. #[pyo3(name = "get_vertices")] fn get_vertices_method( &self, py: Python<'_>, indices: &Bound<'_, PyAny>, ) -> PyResult> { - let indices = ordered_indices_from_py(indices, self.core.vertices.len())?; - point_list_py( - py, - &self - .core - .get_vertices(&indices) - .map_err(TriangulationError::into_pyerr)?, - ) + let Ok(iter) = indices.try_iter() else { + return Err(PyTypeError::new_err( + "Expected an iterable of vertex indices", + )); + }; + let mut items: Vec> = Vec::new(); + for item in iter { + let item = item?; + if item.is_none() { + items.push(py.None()); + continue; + } + let index = normalize_index(item.extract::()?, self.core.vertices.len()) + .map_err(TriangulationError::into_pyerr)?; + items.push(point_tuple(py, &self.core.vertices[index]).into()); + } + Ok(PyList::new(py, items)?.into()) } fn locate_point(&self, py: Python<'_>, point: &Bound<'_, PyAny>) -> PyResult> { @@ -824,6 +905,83 @@ impl PyTriangulation { Ok((simplex_set_py(py, &deleted)?, simplex_set_py(py, &added)?)) } + #[pyo3(signature = (index))] + fn get_vertex(&self, py: Python<'_>, index: Option) -> PyResult> { + match index { + None => Ok(py.None()), + Some(index) => { + let index = normalize_index(index, self.core.vertices.len()) + .map_err(TriangulationError::into_pyerr)?; + Ok(point_tuple(py, &self.core.vertices[index]).into()) + } + } + } + + fn get_neighbors_from_vertices( + &self, + py: Python<'_>, + simplex: &Bound<'_, PyAny>, + ) -> PyResult> { + let indices = ordered_indices_from_py(simplex, self.core.vertices.len())?; + let neighbors = self + .core + .neighbors_from_vertices(&indices) + .map_err(TriangulationError::into_pyerr)?; + simplex_set_py(py, &neighbors) + } + + fn get_simplices_attached_to_points( + &self, + py: Python<'_>, + indices: &Bound<'_, PyAny>, + ) -> PyResult> { + let indices = ordered_indices_from_py(indices, self.core.vertices.len())?; + let attached = self + .core + .simplices_attached_to_points(&indices) + .map_err(TriangulationError::into_pyerr)?; + simplex_set_py(py, &attached) + } + + fn get_opposing_vertices( + &self, + py: Python<'_>, + simplex: &Bound<'_, PyAny>, + ) -> PyResult> { + let simplex = ordered_indices_from_py(simplex, self.core.vertices.len())?; + let opposing = self + .core + .opposing_vertices(&simplex) + .map_err(TriangulationError::into_pyerr)?; + Ok(PyTuple::new(py, opposing)?.into()) + } + + /// Keep only the simplices sharing a whole face with `simplex`. A pure + /// set filter on the arguments, like the reference implementation; the + /// simplices are not required to be part of the triangulation. + fn get_face_sharing_neighbors( + &self, + py: Python<'_>, + neighbors: &Bound<'_, PyAny>, + simplex: &Bound<'_, PyAny>, + ) -> PyResult> { + let neighbors = simplex_set_from_py(neighbors, self.core.vertices.len())?; + let simplex = ordered_indices_from_py(simplex, self.core.vertices.len())?; + let simplex_set: FxHashSet = simplex.iter().copied().collect(); + let sharing: FxHashSet = neighbors + .into_iter() + .filter(|neighbor| { + let shared: FxHashSet = neighbor + .iter() + .copied() + .filter(|vertex| simplex_set.contains(vertex)) + .collect(); + shared.len() == self.core.dim + }) + .collect(); + simplex_set_py(py, &sharing) + } + fn vertex_invariant(&self, _vertex: usize) -> PyResult { Err(PyNotImplementedError::new_err("vertex_invariant")) } diff --git a/src/triangulation.rs b/src/triangulation.rs index 9f8e06e..c8e5c2b 100644 --- a/src/triangulation.rs +++ b/src/triangulation.rs @@ -822,6 +822,89 @@ impl Triangulation { .collect()) } + /// All simplices that share at least one vertex with the given vertex + /// set (the union of the vertices' simplex sets). + pub fn neighbors_from_vertices( + &self, + indices: &[usize], + ) -> Result, TriangulationError> { + self.validate_simplex_indices(indices)?; + let mut neighbors = FxHashSet::default(); + for &vertex in indices { + neighbors.extend(self.simplices_of(vertex).cloned()); + } + Ok(neighbors) + } + + /// All simplices sharing exactly `dim` vertices (a whole facet) with the + /// given vertex set. For a full simplex these are its facet neighbours; + /// the simplex itself (sharing dim+1 vertices) is excluded. + pub fn simplices_attached_to_points( + &self, + indices: &[usize], + ) -> Result, TriangulationError> { + self.validate_simplex_indices(indices)?; + let index_set: FxHashSet = indices.iter().copied().collect(); + let mut seen: FxHashSet = FxHashSet::default(); + let mut attached = FxHashSet::default(); + for &vertex in indices { + for &id in &self.vertex_to_ids[vertex] { + if !seen.insert(id) { + continue; + } + let simplex = self.simplex_by_id(id); + let shared = simplex + .iter() + .filter(|vertex| index_set.contains(vertex)) + .count(); + if shared == self.dim { + attached.insert(simplex.clone()); + } + } + } + Ok(attached) + } + + /// For each vertex of `simplex` (in the given order), the vertex of the + /// facet neighbour on the other side of the opposite facet, or `None` + /// when that facet lies on the hull. Errors when the simplex is not part + /// of the triangulation. + pub fn opposing_vertices( + &self, + simplex: &[usize], + ) -> Result>, TriangulationError> { + let mut sorted = simplex.to_vec(); + sorted.sort_unstable(); + self.validate_simplex_indices(&sorted)?; + let Some(&own_id) = self.ids.get(&sorted) else { + return Err(TriangulationError::Value( + "Provided simplex is not part of the triangulation".to_string(), + )); + }; + + simplex + .iter() + .map(|&vertex| { + let position = sorted + .iter() + .position(|&other| other == vertex) + .expect("vertex comes from the simplex itself"); + let facet = facet_excluding(&sorted, position); + let neighbour = self + .facets + .get(&facet) + .and_then(|incident| incident.iter().copied().find(|&id| id != own_id)); + Ok(neighbour.map(|id| { + self.simplex_by_id(id) + .iter() + .copied() + .find(|other| !sorted.contains(other)) + .expect("facet neighbour has exactly one vertex outside the simplex") + })) + }) + .collect() + } + fn simplex_points( &self, simplex: &[usize], diff --git a/tests/test_triangulation.py b/tests/test_triangulation.py index f27c58a..41ea2c6 100644 --- a/tests/test_triangulation.py +++ b/tests/test_triangulation.py @@ -881,3 +881,88 @@ def test_random_cross_validation_1d(): def test_public_bowyer_watson_is_exposed(): tri = rust_tri.Triangulation([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) assert hasattr(tri, "bowyer_watson") + + +@pytest.mark.parametrize("dim", [2, 3]) +def test_neighbor_query_methods_match_reference(dim): + rng = np.random.default_rng(31337 + dim) + coords = rng.random((10, dim)) + rust = rust_tri.Triangulation(coords[: dim + 1]) + reference = reference_module.Triangulation(coords[: dim + 1]) + for point in coords[dim + 1 :]: + rust.add_point(point) + reference.add_point(point) + + assert rust.get_vertex(None) is None + for index in (0, len(coords) - 1, -1): + assert_points_close(rust.get_vertex(index), reference.get_vertex(index)) + + for simplex in sorted(as_simplex_set(reference.simplices)): + assert as_simplex_set(rust.get_neighbors_from_vertices(simplex)) == as_simplex_set( + reference.get_neighbors_from_vertices(simplex) + ) + assert as_simplex_set(rust.get_simplices_attached_to_points(simplex)) == as_simplex_set( + reference.get_simplices_attached_to_points(simplex) + ) + assert rust.get_opposing_vertices(simplex) == reference.get_opposing_vertices(simplex) + neighbors = reference.get_neighbors_from_vertices(simplex) + assert as_simplex_set( + rust.get_face_sharing_neighbors(neighbors, simplex) + ) == as_simplex_set(reference.get_face_sharing_neighbors(neighbors, simplex)) + + +def test_get_opposing_vertices_rejects_unknown_simplex(): + rust = rust_tri.Triangulation([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + reference = reference_module.Triangulation([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + + assert_same_exception_type_name( + lambda: rust.get_opposing_vertices((0, 1, 2)), + lambda: reference.get_opposing_vertices((0, 1, 2)), + ) + + +def test_learnernd_with_neighbor_aware_loss_runs(): + # curvature_loss_function produces a loss with nth_neighbors=1, which + # exercises get_opposing_vertices / get_simplices_attached_to_points on + # every tell -- the part of the LearnerND surface that the default loss + # never touches. + from adaptive.learner import learnerND + from adaptive.learner.learnerND import LearnerND, curvature_loss_function + + original = learnerND.Triangulation + learnerND.Triangulation = rust_tri.Triangulation + try: + learner = LearnerND( + lambda xy: xy[0] * xy[1], + bounds=[(-1, 1), (-1, 1)], + loss_per_simplex=curvature_loss_function(), + ) + for _ in range(100): + points, _ = learner.ask(1) + for p in points: + learner.tell(p, p[0] * p[1]) + finally: + learnerND.Triangulation = original + assert learner.npoints >= 100 + + +def test_simplices_proxy_supports_set_operators(): + tri = rust_tri.Triangulation([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + simplices = as_simplex_set(tri.simplices) + one = {next(iter(simplices))} + + assert tri.simplices - one == simplices - one + assert one - tri.simplices == set() + assert tri.simplices & one == one + assert one & tri.simplices == one + assert tri.simplices | one == simplices + assert one | tri.simplices == simplices + assert tri.simplices ^ one == simplices - one + assert one ^ tri.simplices == simplices - one + + +def test_get_vertices_passes_none_through(): + tri = rust_tri.Triangulation([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + reference = reference_module.Triangulation([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) + + assert tri.get_vertices([0, None, 2]) == reference.get_vertices([0, None, 2])