Skip to content

Commit 083561a

Browse files
committed
Fix neighbor list bug for non-orthogonal cells
1 parent b32189f commit 083561a

2 files changed

Lines changed: 14 additions & 23 deletions

File tree

examples/simple_inference.cpp

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -188,27 +188,6 @@ int main(int argc, char **argv) {
188188
return 1;
189189
}
190190

191-
// Build neighbor list
192-
NeighborListOptions nl_opts;
193-
nl_opts.cutoff =
194-
5.0f; // PET-MAD uses 5.0 A cutoff (match PyTorch for comparison)
195-
nl_opts.full_list = true;
196-
197-
NeighborListBuilder builder(nl_opts);
198-
auto nlist = builder.build(system);
199-
200-
log::debug("Built neighbor list with {} pairs", nlist.num_pairs());
201-
202-
// Print first 5 edges
203-
for (int e = 0; e < std::min(5, nlist.num_pairs()); ++e) {
204-
log::trace(
205-
" {}: ({}->{}) shift=[{},{},{}] D=[{:.3f},{:.3f},{:.3f}] d={:.3f}", e,
206-
nlist.centers[e], nlist.neighbors[e], nlist.cell_shifts[e][0],
207-
nlist.cell_shifts[e][1], nlist.cell_shifts[e][2],
208-
nlist.edge_vectors[e][0], nlist.edge_vectors[e][1],
209-
nlist.edge_vectors[e][2], nlist.distances[e]);
210-
}
211-
212191
// Load model and run inference
213192
try {
214193
log::info("Loading model from {}", model_path);
@@ -234,6 +213,15 @@ int main(int argc, char **argv) {
234213
log::info("Overriding cutoff to: {:.2f} A", cutoff_override);
235214
}
236215

216+
// Log neighbor count using model's cutoff
217+
{
218+
NeighborListBuilder nl_builder(
219+
NeighborListOptions{pet_model.cutoff(), true, false});
220+
auto nlist = nl_builder.build(system);
221+
log::info("Neighbor pairs: {} (avg {:.1f} per atom)", nlist.num_pairs(),
222+
static_cast<double>(nlist.num_pairs()) / system.num_atoms());
223+
}
224+
237225
static constexpr std::array backend_names = {"auto", "cpu", "cuda", "hip",
238226
"metal", "vulkan", "sycl", "cann"};
239227
log::info("Backend preference: {}", backend_names[static_cast<size_t>(backend_pref)]);

src/core/neighbor_list.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,12 @@ struct CellGrid {
6767

6868
std::array<int, 3> position_to_bin(const Vec3 &pos) const {
6969
double frac[3];
70+
// Compute fractional coordinates: frac = pos @ inv_cell (i.e., inv_cell.T @ pos)
71+
// For the standard convention where rows of the cell matrix are lattice vectors,
72+
// fractional coords are frac[i] = sum_j pos[j] * inv_cell[j,i]
7073
for (int i = 0; i < 3; ++i) {
71-
frac[i] = inv_cell[i * 3 + 0] * pos.x + inv_cell[i * 3 + 1] * pos.y +
72-
inv_cell[i * 3 + 2] * pos.z;
74+
frac[i] = inv_cell[0 * 3 + i] * pos.x + inv_cell[1 * 3 + i] * pos.y +
75+
inv_cell[2 * 3 + i] * pos.z;
7376
}
7477
return {static_cast<int>(std::floor(frac[0] * n_bins[0])),
7578
static_cast<int>(std::floor(frac[1] * n_bins[1])),

0 commit comments

Comments
 (0)