Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/models/test_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def ts_mace_model() -> MaceModel:
dtype=DTYPE,
)

test_mace_model_outputs = make_validate_model_outputs_test(
model_fixture_name="ts_mace_model", device=DEVICE, dtype=DTYPE
)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_mace_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None:
Expand Down Expand Up @@ -101,10 +105,6 @@ def ts_mace_off_model() -> MaceModel:
dtype=DTYPE,
)

test_mace_off_model_outputs = make_validate_model_outputs_test(
model_fixture_name="ts_mace_model", dtype=DTYPE
)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_mace_off_dtype_working(
Expand Down
57 changes: 40 additions & 17 deletions torch_sim/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ def validate_model_outputs( # noqa: C901, PLR0915
validate_model_outputs(model, device=torch.device("cuda"), dtype=torch.float64)

Notes:
This validator creates small test systems (silicon and iron) for validation.
It tests both single and multi-batch processing capabilities.
This validator creates small test systems (diamond silicon, HCP magnesium,
and primitive BCC iron) for validation. It tests both single and
multi-batch processing capabilities.
"""
from ase.build import bulk

Expand All @@ -260,9 +261,9 @@ def validate_model_outputs( # noqa: C901, PLR0915
force_computed = False

si_atoms = bulk("Si", "diamond", a=5.43, cubic=True)
fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True).repeat([3, 1, 1])

sim_state = ts.io.atoms_to_state([si_atoms, fe_atoms], device, dtype)
mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21).repeat([3, 2, 1])
fe_atoms = bulk("Fe", "bcc", a=2.87)
sim_state = ts.io.atoms_to_state([si_atoms, mg_atoms, fe_atoms], device, dtype)

og_positions = sim_state.positions.clone()
og_cell = sim_state.cell.clone()
Expand Down Expand Up @@ -299,12 +300,12 @@ def validate_model_outputs( # noqa: C901, PLR0915
raise ValueError("stress not in model output")

# assert model output shapes are correct
if model_output["energy"].shape != (2,):
raise ValueError(f"{model_output['energy'].shape=} != (2,)")
if force_computed and model_output["forces"].shape != (20, 3):
raise ValueError(f"{model_output['forces'].shape=} != (20, 3)")
if stress_computed and model_output["stress"].shape != (2, 3, 3):
raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)")
if model_output["energy"].shape != (3,):
raise ValueError(f"{model_output['energy'].shape=} != (3,)")
if force_computed and model_output["forces"].shape != (21, 3):
raise ValueError(f"{model_output['forces'].shape=} != (21, 3)")
if stress_computed and model_output["stress"].shape != (3, 3, 3):
raise ValueError(f"{model_output['stress'].shape=} != (3, 3, 3)")

si_state = ts.io.atoms_to_state([si_atoms], device, dtype)

Expand All @@ -328,23 +329,45 @@ def validate_model_outputs( # noqa: C901, PLR0915
if stress_computed and si_model_output["stress"].shape != (1, 3, 3):
raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)")

mg_state = ts.io.atoms_to_state([mg_atoms], device, dtype)
mg_model_output = model.forward(mg_state)
if not torch.allclose(
mg_model_output["energy"], model_output["energy"][1], atol=1e-3
):
raise ValueError(f"{mg_model_output['energy']=} != {model_output['energy'][1]=}")
mg_n = mg_state.n_atoms
mg_slice = slice(si_state.n_atoms, si_state.n_atoms + mg_n)
if not torch.allclose(
forces := mg_model_output["forces"],
expected_forces := model_output["forces"][mg_slice],
atol=1e-3,
):
raise ValueError(f"{forces=} != {expected_forces=}")

# Test single Mg system output shapes (12 atoms)
if mg_model_output["energy"].shape != (1,):
raise ValueError(f"{mg_model_output['energy'].shape=} != (1,)")
if force_computed and mg_model_output["forces"].shape != (12, 3):
raise ValueError(f"{mg_model_output['forces'].shape=} != (12, 3)")
if stress_computed and mg_model_output["stress"].shape != (1, 3, 3):
raise ValueError(f"{mg_model_output['stress'].shape=} != (1, 3, 3)")

fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype)
fe_model_output = model.forward(fe_state)
if not torch.allclose(
fe_model_output["energy"], model_output["energy"][1], atol=1e-3
fe_model_output["energy"], model_output["energy"][2], atol=1e-3
):
raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}")
raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][2]=}")
if not torch.allclose(
forces := fe_model_output["forces"],
expected_forces := model_output["forces"][si_state.n_atoms :],
expected_forces := model_output["forces"][si_state.n_atoms + mg_n :],
atol=1e-3,
):
raise ValueError(f"{forces=} != {expected_forces=}")

# Test single Fe system output shapes (12 atoms)
if fe_model_output["energy"].shape != (1,):
raise ValueError(f"{fe_model_output['energy'].shape=} != (1,)")
if force_computed and fe_model_output["forces"].shape != (12, 3):
raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)")
if force_computed and fe_model_output["forces"].shape != (1, 3):
raise ValueError(f"{fe_model_output['forces'].shape=} != (1, 3)")
if stress_computed and fe_model_output["stress"].shape != (1, 3, 3):
raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)")
Loading