Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions geomfum/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def from_registry(cls, *args, which=None, **kwargs):
return instantiator(*args, **kwargs)


class MeshWhichRegistryMixins:
"""Mixin for registry-based instantiation with mesh/point cloud distinction."""
class ShapeWhichRegistryMixins:
"""Mixin for registry-based instantiation with shape type distinction."""

def __init__(self, *args, **kwargs):
# TODO: has to be improved
Expand All @@ -259,13 +259,13 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def from_registry(cls, *args, mesh=True, which=None, **kwargs):
def from_registry(cls, *args, shape_type="mesh", which=None, **kwargs):
"""Create instance from registered implementation based on shape type.

Parameters
----------
mesh : bool
Whether a mesh or point cloud.
shape_type : str
Type of shape (e.g. ``"mesh"``, ``"pointcloud"``).
which : str
A registered implementation.

Expand All @@ -274,7 +274,7 @@ def from_registry(cls, *args, mesh=True, which=None, **kwargs):
obj : Obj
An instantiated object.
"""
instantiator = cls._Registry.get(mesh, which)
instantiator = cls._Registry.get(shape_type, which)
if instantiator is None:
obj = cls.__new__(cls)
obj.__init__(*args, **kwargs)
Expand All @@ -295,8 +295,8 @@ class _PointSetLaplacianFinderRegistry(Registry):

class LaplacianFinderRegistry(NestedRegistry):
Registries = {
True: _MeshLaplacianFinderRegistry,
False: _PointSetLaplacianFinderRegistry,
"mesh": _MeshLaplacianFinderRegistry,
"pointcloud": _PointSetLaplacianFinderRegistry,
}


Expand Down Expand Up @@ -374,8 +374,8 @@ class _PointSetHeatDistanceMetricRegistry(Registry):

class HeatDistanceMetricRegistry(NestedRegistry):
Registries = {
True: _MeshHeatDistanceMetricRegistry,
False: _PointSetHeatDistanceMetricRegistry,
"mesh": _MeshHeatDistanceMetricRegistry,
"pointcloud": _PointSetHeatDistanceMetricRegistry,
}


Expand Down
4 changes: 2 additions & 2 deletions geomfum/laplacian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import gsops.backend as gs

import geomfum.wrap as _wrap # noqa (for register)
from geomfum._registry import LaplacianFinderRegistry, MeshWhichRegistryMixins
from geomfum._registry import LaplacianFinderRegistry, ShapeWhichRegistryMixins
from geomfum.basis import LaplaceEigenBasis
from geomfum.numerics.eig import ScipyEigsh

Expand All @@ -31,7 +31,7 @@ def __call__(self, shape):
"""


class LaplacianFinder(MeshWhichRegistryMixins, BaseLaplacianFinder):
class LaplacianFinder(ShapeWhichRegistryMixins, BaseLaplacianFinder):
"""Algorithm to find the Laplacian."""

_Registry = LaplacianFinderRegistry
Expand Down
4 changes: 2 additions & 2 deletions geomfum/metric/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import gsops.backend as gs

from geomfum._registry import HeatDistanceMetricRegistry, MeshWhichRegistryMixins
from geomfum._registry import HeatDistanceMetricRegistry, ShapeWhichRegistryMixins


class Metric(abc.ABC):
Expand Down Expand Up @@ -135,7 +135,7 @@ def dist_matrix(self):
return self.dist_from_source(gs.arange(self._shape.n_vertices))[0]


class HeatDistanceMetric(MeshWhichRegistryMixins):
class HeatDistanceMetric(ShapeWhichRegistryMixins):
"""Geodesic distance approximation using the heat method.

References
Expand Down
2 changes: 1 addition & 1 deletion geomfum/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def find(self, laplacian_finder=None, recompute=False):

if laplacian_finder is None:
laplacian_finder = LaplacianFinder.from_registry(
mesh=self._shape.is_mesh, which="robust"
shape_type=self._shape.shape_type, which="robust"
)

self._stiffness_matrix, self._mass_matrix = laplacian_finder(self._shape)
Expand Down
8 changes: 4 additions & 4 deletions geomfum/shape/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ class Shape(abc.ABC):

Parameters
----------
is_mesh : bool
Whether the shape is a mesh (True) or point cloud (False).
shape_type : str
Type of shape (e.g. ``"mesh"``, ``"pointcloud"``).
"""

def __init__(self, is_mesh):
self.is_mesh = is_mesh
def __init__(self, shape_type):
self.shape_type = shape_type

self._basis = None
self.laplacian = Laplacian(self)
Expand Down
2 changes: 1 addition & 1 deletion geomfum/shape/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
vertices,
faces,
):
super().__init__(is_mesh=True)
super().__init__(shape_type="mesh")
self.vertices = gs.asarray(vertices)
self.faces = gs.asarray(faces)

Expand Down
2 changes: 1 addition & 1 deletion geomfum/shape/point_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class PointCloud(Shape):
"""

def __init__(self, vertices):
super().__init__(is_mesh=False)
super().__init__(shape_type="pointcloud")
self.vertices = gs.asarray(vertices)

self.n_neighbors = 30
Expand Down
12 changes: 6 additions & 6 deletions geomfum/wrap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,26 @@
from geomfum._utils import has_package

register_laplacian_finder(
True,
"mesh",
"pyfm",
"PyfmMeshLaplacianFinder",
requires="pyFM",
as_default=not has_package("robust_laplacian"),
)

register_laplacian_finder(
True,
"mesh",
"robust",
"RobustMeshLaplacianFinder",
requires="robust_laplacian",
as_default=has_package("robust_laplacian"),
)

register_laplacian_finder(True, "igl", "IglMeshLaplacianFinder", requires="igl")
register_laplacian_finder("mesh", "igl", "IglMeshLaplacianFinder", requires="igl")


register_laplacian_finder(
False, "robust", "RobustPointCloudLaplacianFinder", requires="robust_laplacian"
"pointcloud", "robust", "RobustPointCloudLaplacianFinder", requires="robust_laplacian"
)

register_heat_kernel_signature(
Expand Down Expand Up @@ -120,11 +120,11 @@


register_heat_distance_metric(
True, "pp3d", "Pp3dMeshHeatDistanceMetric", requires="potpourri3d", as_default=True
"mesh", "pp3d", "Pp3dMeshHeatDistanceMetric", requires="potpourri3d", as_default=True
)

register_heat_distance_metric(
False,
"pointcloud",
"pp3d",
"Pp3dPointSetHeatDistanceMetric",
requires="potpourri3d",
Expand Down
Loading