diff --git a/src/schnetpack/transform/neighborlist.py b/src/schnetpack/transform/neighborlist.py index 90ec4ff0b..68afe8975 100644 --- a/src/schnetpack/transform/neighborlist.py +++ b/src/schnetpack/transform/neighborlist.py @@ -230,7 +230,7 @@ def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff): class PymatgenNeighborList(NeighborListTransform): """ - Calculate neighbor list using pymatgen. + Calculate neighbor list using pymatgen. Automatically casts Z and positions to np.float64. """ def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff): @@ -241,6 +241,9 @@ def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff): device = positions.device dtype = positions.dtype + cell_np = cell_np.astype(np.float64, copy=False) + pos_np = pos_np.astype(np.float64, copy=False) + idx_i, idx_j, offsets, distances = find_points_in_spheres( pos_np, pos_np, diff --git a/tests/data/test_transforms.py b/tests/data/test_transforms.py index 23b70c64f..d3f7e6bda 100644 --- a/tests/data/test_transforms.py +++ b/tests/data/test_transforms.py @@ -22,6 +22,15 @@ def neighbor_list(request): return neighbor_lists[request.param] +@pytest.fixture(params=[0, 1]) +def precision(request): + precisions = [ + torch.float64, + torch.float32, + ] + return precisions[request.param] + + class TestNeighborLists: """ Test for different neighbor lists defined in neighbor_list using the Argon environment fixtures (periodic and @@ -29,8 +38,12 @@ class TestNeighborLists: """ - def test_neighbor_list(self, neighbor_list, environment): + def test_neighbor_list(self, neighbor_list, environment, precision): cutoff, props, neighbors_ref = environment + + if precision == torch.float32: + _ = CastTo32()(props) + neighbor_list = neighbor_list(cutoff) neighbors = neighbor_list(props) R = props[structure.R] @@ -44,6 +57,10 @@ def test_neighbor_list(self, neighbor_list, environment): neighbors_ref = self._sort_neighbors(neighbors_ref) for nbl, nbl_ref in zip(neighbors, neighbors_ref): + + if nbl_ref.dtype == torch.float64: + nbl_ref = nbl_ref.to(dtype=precision) + torch.testing.assert_close(nbl, nbl_ref) def _sort_neighbors(self, neighbors):