Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 102 additions & 11 deletions geomfum/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __call__(self, fmap_matrix, basis_a, basis_b):
return p2p



class SoftmaxNeighborFinder(BaseNeighborFinder, nn.Module):
"""Softmax neighbor finder.

Expand Down Expand Up @@ -190,11 +191,17 @@ def forward(self, X, Y):
neigs : array-like, shape=[n_points_x, n_neighbors]
Indices of the nearest neighbors in Y for each point in X.
"""
P = self.softmax_matrix(X, Y)
# Get the indices of the top-k (self.n_neighbors) highest values for each row
indices = torch.topk(P, self.n_neighbors, dim=-1)[1]
similarity = torch.mm(X, Y.T)

if self.n_neighbors == 1:
# For single neighbor, use argmax (faster than topk)
indices = torch.argmax(similarity / self.tau, dim=-1, keepdim=True)
else:
# Use topk but with optimizations
scaled_similarity = similarity / self.tau
indices = torch.topk(scaled_similarity, self.n_neighbors, dim=-1, sorted=False)[1]

return indices

def softmax_matrix(self, X, Y):
"""Compute the permutation matrix P as a softmax of the similarity.

Expand All @@ -210,15 +217,99 @@ def softmax_matrix(self, X, Y):
P : array-like, shape=[n_points_x, n_points_y]
Permutation matrix, where each row sums to 1.
"""
similarity = torch.mm( X, Y.T)
similarity = torch.mm(X, Y.T)

P = torch.softmax(similarity / self.tau, dim=-1)

return P

P = torch.exp(
similarity / self.tau
- torch.logsumexp(similarity / self.tau, dim=-1, keepdim=True)
)
class GPUEuclideanNeighborFinder(BaseNeighborFinder, nn.Module):
"""GPU-based Euclidean neighbor finder.

return P
Finds exact nearest neighbors using Euclidean distance on GPU.
Uses brute-force distance computation for exact results.

Parameters
----------
n_neighbors : int
Number of neighbors.
"""

def __init__(self, n_neighbors=1):
BaseNeighborFinder.__init__(self, n_neighbors=n_neighbors)
nn.Module.__init__(self)

def __call__(self, X, Y):
"""Return indices of the points in `X` nearest to the points in `Y`.

Parameters
----------
X : array-like, shape=[n_points_x, n_features]
Reference points.
Y : array-like, shape=[n_points_y, n_features]
Query points.

Returns
-------
neigs : array-like, shape=[n_points_x, n_neighbors]
Indices of the nearest neighbors in Y for each point in X.
"""
return self.forward(X, Y)

def forward(self, X, Y):
"""Find k nearest neighbors using exact Euclidean distance.

Parameters
----------
X : array-like, shape=[n_points_x, n_features]
Reference points.
Y : array-like, shape=[n_points_y, n_features]
Query points.

Returns
-------
neigs : array-like, shape=[n_points_x, n_neighbors]
Indices of the nearest neighbors in Y for each point in X.
"""
# Compute squared Euclidean distances
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x,y>
X_norm_sq = torch.sum(X**2, dim=1, keepdim=True) # [n_points_x, 1]
Y_norm_sq = torch.sum(Y**2, dim=1, keepdim=False) # [n_points_y]
similarity = torch.mm(X, Y.T) # [n_points_x, n_points_y]

distances = X_norm_sq + Y_norm_sq - 2 * similarity

if self.n_neighbors == 1:
# For single neighbor, use argmin (faster than topk)
indices = torch.argmin(distances, dim=-1, keepdim=True)
else:
# Use topk with smallest=True for minimum distances
indices = torch.topk(distances, self.n_neighbors, dim=-1,
largest=False, sorted=False)[1]

return indices

def distance_matrix(self, X, Y):
"""Compute full distance matrix between X and Y.

Parameters
----------
X : array-like, shape=[n_points_x, n_features]
Reference points.
Y : array-like, shape=[n_points_y, n_features]
Query points.

Returns
-------
distances : array-like, shape=[n_points_x, n_points_y]
Euclidean distance matrix.
"""
X_norm_sq = torch.sum(X**2, dim=1, keepdim=True)
Y_norm_sq = torch.sum(Y**2, dim=1, keepdim=False)
similarity = torch.mm(X, Y.T)

distances = X_norm_sq + Y_norm_sq - 2 * similarity
return torch.sqrt(torch.clamp(distances, min=0)) # Clamp for numerical stability

class SinkhornP2pFromFmConverter(P2pFromFmConverter):
"""Pointwise map from functional map using Sinkhorn filters.
Expand Down Expand Up @@ -451,5 +542,5 @@ def __call__(self, nam, basis_a, basis_b):
emb1 = nam(xgs.to_torch(basis_a.full_vecs[:, :k2]).to(nam.device).double())
emb2 = xgs.to_torch(basis_b.full_vecs[:, :k1]).to(nam.device).double()

p2p = self.neighbor_finder(emb2.detach().cpu(), emb1.detach().cpu()).flatten()
p2p = self.neighbor_finder(emb2.detach(), emb1.detach()).flatten()
return p2p
Loading