From fcc2058a64e20286272f433cdf10d1c42d258765 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 20 Mar 2026 16:19:09 -0400 Subject: [PATCH] fix: add single atom motif edge case to the validate function --- tests/models/test_mace.py | 8 ++--- torch_sim/models/interface.py | 57 ++++++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 322f3d12..6a8e3f31 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -64,6 +64,10 @@ def ts_mace_model() -> MaceModel: dtype=DTYPE, ) +test_mace_model_outputs = make_validate_model_outputs_test( + model_fixture_name="ts_mace_model", device=DEVICE, dtype=DTYPE +) + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_mace_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: @@ -101,10 +105,6 @@ def ts_mace_off_model() -> MaceModel: dtype=DTYPE, ) -test_mace_off_model_outputs = make_validate_model_outputs_test( - model_fixture_name="ts_mace_model", dtype=DTYPE -) - @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_mace_off_dtype_working( diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 63681c8e..0d074153 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -236,8 +236,9 @@ def validate_model_outputs( # noqa: C901, PLR0915 validate_model_outputs(model, device=torch.device("cuda"), dtype=torch.float64) Notes: - This validator creates small test systems (silicon and iron) for validation. - It tests both single and multi-batch processing capabilities. + This validator creates small test systems (diamond silicon, HCP magnesium, + and primitive BCC iron) for validation. It tests both single and + multi-batch processing capabilities. """ from ase.build import bulk @@ -260,9 +261,9 @@ def validate_model_outputs( # noqa: C901, PLR0915 force_computed = False si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True).repeat([3, 1, 1]) - - sim_state = ts.io.atoms_to_state([si_atoms, fe_atoms], device, dtype) + mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21).repeat([3, 2, 1]) + fe_atoms = bulk("Fe", "bcc", a=2.87) + sim_state = ts.io.atoms_to_state([si_atoms, mg_atoms, fe_atoms], device, dtype) og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() @@ -299,12 +300,12 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError("stress not in model output") # assert model output shapes are correct - if model_output["energy"].shape != (2,): - raise ValueError(f"{model_output['energy'].shape=} != (2,)") - if force_computed and model_output["forces"].shape != (20, 3): - raise ValueError(f"{model_output['forces'].shape=} != (20, 3)") - if stress_computed and model_output["stress"].shape != (2, 3, 3): - raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)") + if model_output["energy"].shape != (3,): + raise ValueError(f"{model_output['energy'].shape=} != (3,)") + if force_computed and model_output["forces"].shape != (21, 3): + raise ValueError(f"{model_output['forces'].shape=} != (21, 3)") + if stress_computed and model_output["stress"].shape != (3, 3, 3): + raise ValueError(f"{model_output['stress'].shape=} != (3, 3, 3)") si_state = ts.io.atoms_to_state([si_atoms], device, dtype) @@ -328,23 +329,45 @@ def validate_model_outputs( # noqa: C901, PLR0915 if stress_computed and si_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)") + mg_state = ts.io.atoms_to_state([mg_atoms], device, dtype) + mg_model_output = model.forward(mg_state) + if not torch.allclose( + mg_model_output["energy"], model_output["energy"][1], atol=1e-3 + ): + raise ValueError(f"{mg_model_output['energy']=} != {model_output['energy'][1]=}") + mg_n = mg_state.n_atoms + mg_slice = slice(si_state.n_atoms, si_state.n_atoms + mg_n) + if not torch.allclose( + forces := mg_model_output["forces"], + expected_forces := model_output["forces"][mg_slice], + atol=1e-3, + ): + raise ValueError(f"{forces=} != {expected_forces=}") + + # Test single Mg system output shapes (12 atoms) + if mg_model_output["energy"].shape != (1,): + raise ValueError(f"{mg_model_output['energy'].shape=} != (1,)") + if force_computed and mg_model_output["forces"].shape != (12, 3): + raise ValueError(f"{mg_model_output['forces'].shape=} != (12, 3)") + if stress_computed and mg_model_output["stress"].shape != (1, 3, 3): + raise ValueError(f"{mg_model_output['stress'].shape=} != (1, 3, 3)") + fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) fe_model_output = model.forward(fe_state) if not torch.allclose( - fe_model_output["energy"], model_output["energy"][1], atol=1e-3 + fe_model_output["energy"], model_output["energy"][2], atol=1e-3 ): - raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}") + raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][2]=}") if not torch.allclose( forces := fe_model_output["forces"], - expected_forces := model_output["forces"][si_state.n_atoms :], + expected_forces := model_output["forces"][si_state.n_atoms + mg_n :], atol=1e-3, ): raise ValueError(f"{forces=} != {expected_forces=}") - # Test single Fe system output shapes (12 atoms) if fe_model_output["energy"].shape != (1,): raise ValueError(f"{fe_model_output['energy'].shape=} != (1,)") - if force_computed and fe_model_output["forces"].shape != (12, 3): - raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)") + if force_computed and fe_model_output["forces"].shape != (1, 3): + raise ValueError(f"{fe_model_output['forces'].shape=} != (1, 3)") if stress_computed and fe_model_output["stress"].shape != (1, 3, 3): raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)")