199 changes: 192 additions & 7 deletions orb_models/forcefield/forcefield_adapter.py
@@ -21,6 +21,7 @@
from orb_models.common.atoms import graph_featurization as graph_feat
from orb_models.common.atoms.abstract_atoms_adapter import AbstractAtomsAdapter
from orb_models.common.atoms.batch.graph_batch import AtomGraphs
from orb_models.common.torch_utils import get_device


class ForcefieldAtomsAdapter(AbstractAtomsAdapter):
@@ -115,12 +116,14 @@ def from_ase_atoms(
positions = feat_utils.map_to_pbc_cell(positions, cell, pbc)

max_num_neighbors = max_num_neighbors or self.max_num_neighbors
assert self.radius is not None, "radius must be set"
assert max_num_neighbors is not None, "max_num_neighbors must be set"
edge_index, edge_vectors, unit_shifts = graph_feat.compute_pbc_radius_graph(
positions=positions,
cell=cell,
pbc=pbc,
radius=self.radius, # type: ignore
max_number_neighbors=max_num_neighbors, # type: ignore
radius=self.radius,
max_number_neighbors=max_num_neighbors,
edge_method=edge_method,
half_supercell=half_supercell,
float_dtype=graph_construction_dtype,
@@ -167,15 +170,192 @@ def from_ase_atoms(
system_targets=deepcopy(graph_targets),
fix_atoms=feat_utils.ase_fix_atoms_to_tensor(atoms),
tags=feat_utils.get_ase_tags(atoms),
radius=self.radius, # type: ignore
radius=self.radius,
max_num_neighbors=torch.tensor([max_num_neighbors]),
system_id=(torch.LongTensor([system_id]) if system_id is not None else system_id),
).to(device=device, dtype=output_dtype)
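Reviewer note (illustrative, not part of the diff): replacing the `# type: ignore` comments with explicit asserts also narrows the Optional attributes for the type checker, so the downstream uses are verified rather than silenced. A minimal sketch of that pattern, with hypothetical names:

from typing import Optional

def narrowed_radius(radius: Optional[float]) -> float:
    assert radius is not None, "radius must be set"
    # After the assert, type checkers treat `radius` as float, so no ignore comment is needed.
    return 2.0 * radius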

@override
def from_ase_atoms_list(
self,
atoms: list[ase.Atoms],
*,
max_num_neighbors: int | None = None,
edge_method: graph_feat.EdgeCreationMethod | None = None,
wrap: bool = True,
device: torch.device | str | None = None,
output_dtype: torch.dtype | None = None,
graph_construction_dtype: torch.dtype | None = None,
) -> AtomGraphs:
"""Convert a list of ase.Atoms into a single batched AtomGraphs using parallel graph construction.

This method leverages the batched Alchemi-based graph construction for better performance
compared to processing atoms one by one.

Args:
atoms: List of ase.Atoms objects.
max_num_neighbors: Maximum number of neighbors each node can send messages to.
If None, will use self.max_num_neighbors.
edge_method: The method to use for graph edge construction. If None, defaults to knn_alchemi.
wrap: Whether to wrap atomic positions into the central unit cell.
device: The device to put the tensors on.
output_dtype: The dtype to use for all floating point tensors on the AtomGraphs.
graph_construction_dtype: The dtype to use for floating point tensors in graph construction.
"""
if len(atoms) == 0:
raise ValueError("atoms list must not be empty")

# Fall back to sequential processing when only a single Atoms object is given
if len(atoms) == 1:
return self.from_ase_atoms(
atoms[0],
edge_method=edge_method,
wrap=wrap,
device=device,
output_dtype=output_dtype,
graph_construction_dtype=graph_construction_dtype,
)

output_dtype = torch.get_default_dtype() if output_dtype is None else output_dtype
graph_construction_dtype = (
torch.get_default_dtype()
if graph_construction_dtype is None
else graph_construction_dtype
)

# Resolve device early so all tensors are on the same device
resolved_device = get_device(device)

for a in atoms:
self._validate_inputs(a, output_dtype)

# Extract per-system data
all_positions = []
all_cells = []
all_pbcs = []
all_atomic_numbers = []
all_atomic_numbers_embedding = []
all_fix_atoms = []
all_tags = []
n_atoms = []

for a in atoms:
positions_i = torch.from_numpy(a.positions)
cell_i = torch.from_numpy(a.cell.array)
pbc_i = torch.from_numpy(a.pbc)

all_positions.append(positions_i)
all_cells.append(cell_i)
all_pbcs.append(pbc_i)
all_atomic_numbers.append(torch.from_numpy(a.numbers).to(torch.long))
all_atomic_numbers_embedding.append(feat_utils.get_atom_embedding(a))
all_fix_atoms.append(feat_utils.ase_fix_atoms_to_tensor(a))
all_tags.append(feat_utils.get_ase_tags(a))
n_atoms.append(len(a))

# Build batched tensors and move to resolved device
positions = torch.cat(all_positions, dim=0).to(device=resolved_device)
cells = torch.stack(all_cells, dim=0).to(device=resolved_device)
pbcs = torch.stack(all_pbcs, dim=0).to(device=resolved_device)
n_node = torch.tensor(n_atoms, dtype=torch.long, device=resolved_device)
node_batch_index = torch.arange(
len(atoms), dtype=torch.int64, device=resolved_device
).repeat_interleave(n_node)

if wrap:
positions = feat_utils.batch_map_to_pbc_cell(
positions=positions, cell=cells, pbc=pbcs, n_node=n_node
)

max_num_neighbors = max_num_neighbors or self.max_num_neighbors
assert max_num_neighbors is not None, "max_num_neighbors must be set"
assert self.radius is not None, "radius must be set"
max_num_neighbors_tensor = torch.full_like(n_node, fill_value=max_num_neighbors)
(
edge_index,
edge_vectors,
unit_shifts,
batch_num_edges,
) = graph_feat.batch_compute_pbc_radius_graph(
positions=positions.contiguous(),
cells=cells,
pbcs=pbcs,
radius=torch.tensor([self.radius], device=resolved_device),
max_number_neighbors=max_num_neighbors_tensor,
n_node=n_node,
node_batch_index=node_batch_index,
edge_method=edge_method,
device=resolved_device,
)
senders = edge_index[0].long()
receivers = edge_index[1].long()

atomic_numbers = torch.cat(all_atomic_numbers, dim=0)
atomic_numbers_embedding = torch.cat(all_atomic_numbers_embedding, dim=0)

# Concatenate fix_atoms: None if no system has constraints
if any(f is not None for f in all_fix_atoms):
fix_atoms = torch.cat(
[
f if f is not None else torch.zeros(n, dtype=torch.bool)
for f, n in zip(all_fix_atoms, n_atoms, strict=True)
],
dim=0,
)
else:
fix_atoms = None

tags = torch.cat(all_tags, dim=0)

node_feats = {
"positions": positions,
"atomic_numbers": atomic_numbers,
"atomic_numbers_embedding": atomic_numbers_embedding,
}
edge_feats = {
"vectors": edge_vectors,
"unit_shifts": unit_shifts.to(dtype=output_dtype),
}
graph_feats: dict[str, torch.Tensor] = {
"cell": cells,
"pbc": pbcs,
}
# Collect charge and spin: all-or-nothing semantics
charge_spin_list = [_get_charge_and_spin(a) for a in atoms]
has_charge_spin = [bool(cs) for cs in charge_spin_list]
if any(has_charge_spin):
if not all(has_charge_spin):
raise ValueError("Either all atoms must have charge and spin, or none of them.")
graph_feats["total_charge"] = torch.cat(
[cs["total_charge"] for cs in charge_spin_list], dim=0
)
graph_feats["spin_multiplicity"] = torch.cat(
[cs["spin_multiplicity"] for cs in charge_spin_list], dim=0
)

return AtomGraphs(
senders=senders,
receivers=receivers,
n_node=n_node,
n_edge=batch_num_edges,
node_features=node_feats,
edge_features=edge_feats,
system_features=graph_feats,
node_targets={},
edge_targets={},
system_targets={},
system_id=None,
fix_atoms=fix_atoms,
tags=tags,
radius=self.radius,
max_num_neighbors=max_num_neighbors_tensor,
).to(device=resolved_device, dtype=output_dtype)
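Reviewer note, illustrative only: a usage sketch for the new batched entry point. The adapter constructor arguments below are assumptions (they are not shown in this diff); the point is that a list of ase.Atoms becomes one batched AtomGraphs in a single call instead of a Python loop over from_ase_atoms.

import ase.build
import torch

from orb_models.forcefield.forcefield_adapter import ForcefieldAtomsAdapter

adapter = ForcefieldAtomsAdapter(radius=6.0, max_num_neighbors=20)  # hypothetical constructor args
systems = [ase.build.bulk("Cu", cubic=True), ase.build.bulk("NaCl", "rocksalt", a=5.64)]
batch = adapter.from_ase_atoms_list(
    systems,
    device="cuda" if torch.cuda.is_available() else "cpu",
    output_dtype=torch.float32,
)
# One AtomGraphs now holds both systems: n_node gives per-system atom counts,
# and node/edge features are concatenated in system order.
print(batch.n_node)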

def from_torchsim_state(
self,
state: ts.SimState,
*,
max_num_neighbors: int | None = None,
edge_method: graph_feat.EdgeCreationMethod | None = None,
wrap: bool = True,
device: torch.device | str | None = None,
@@ -188,6 +368,8 @@ def from_torchsim_state(

Args:
state: SimState object containing atomic positions, cell, and atomic numbers.
max_num_neighbors: Maximum number of neighbors each node can send messages to.
If None, will use self.max_num_neighbors.
edge_method (EdgeCreationMethod, optional): The method to use for graph edge
construction. If None, the edge method is chosen automatically.
wrap: Whether to wrap atomic positions into the central unit cell.
Expand Down Expand Up @@ -226,7 +408,10 @@ def from_torchsim_state(
positions=positions, cell=cell, pbc=pbc, n_node=n_node
)

max_num_neighbors = torch.full_like(n_node, fill_value=self.max_num_neighbors) # type: ignore[arg-type]
max_num_neighbors = max_num_neighbors or self.max_num_neighbors
assert self.radius is not None, "radius must be set"
assert max_num_neighbors is not None, "max_num_neighbors must be set"
max_num_neighbors_tensor = torch.full_like(n_node, fill_value=max_num_neighbors)
(
edge_index,
edge_vectors,
@@ -237,7 +422,7 @@
cells=cell,
pbcs=pbc,
radius=torch.tensor([self.radius], device=device),
max_number_neighbors=max_num_neighbors,
max_number_neighbors=max_num_neighbors_tensor,
n_node=n_node,
node_batch_index=node_batch_index,
edge_method=edge_method,
Expand Down Expand Up @@ -279,8 +464,8 @@ def from_torchsim_state(
system_id=None,
fix_atoms=None,
tags=None,
radius=self.radius, # type: ignore
max_num_neighbors=max_num_neighbors, # type: ignore
radius=self.radius,
max_num_neighbors=max_num_neighbors_tensor,
).to(device=device, dtype=output_dtype)
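Reviewer note, illustrative only: both from_ase_atoms_list and from_torchsim_state now resolve the neighbor cap the same way, falling back from the per-call argument to the adapter default and then broadcasting it to one entry per system. A self-contained sketch of that pattern (the helper name is hypothetical):

import torch

def resolve_neighbor_cap(per_call: int | None, adapter_default: int | None, n_node: torch.Tensor) -> torch.Tensor:
    # Prefer the per-call value, otherwise fall back to the adapter default (mirrors the diff).
    cap = per_call or adapter_default
    assert cap is not None, "max_num_neighbors must be set"
    # One entry per system, on the same device and dtype as n_node.
    return torch.full_like(n_node, fill_value=cap)

print(resolve_neighbor_cap(None, 20, torch.tensor([8, 12])))  # tensor([20, 20])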

def is_compatible_with(self, other: AbstractAtomsAdapter):