From 7d649d471de31ffb6fb5808a9296ed313f65a2b9 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 31 Mar 2026 14:38:10 -0700 Subject: [PATCH 1/2] default to fixed symmetry relaxation in elastic, throw warning to user --- src/atomate2/forcefields/flows/elastic.py | 37 +++++++++++++++++------ tests/forcefields/flows/test_elastic.py | 6 ++-- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/atomate2/forcefields/flows/elastic.py b/src/atomate2/forcefields/flows/elastic.py index 1a1f71ba09..5e7ef1a1b1 100644 --- a/src/atomate2/forcefields/flows/elastic.py +++ b/src/atomate2/forcefields/flows/elastic.py @@ -21,6 +21,7 @@ _DEFAULT_RELAX_KWARGS: dict[str, Any] = { "force_field_name": "CHGNet", "relax_kwargs": {"fmax": 0.00001}, + "fix_symmetry": True, } @@ -106,6 +107,7 @@ def from_force_field_name( cls, force_field_name: str | MLFF | dict, calculator_kwargs: dict | None = None, + relax_initial_structure: bool = True, **kwargs, ) -> Self: """ @@ -117,6 +119,9 @@ def from_force_field_name( The name of the force field. calculator_kwargs : dict or None (default) calculator_kwargs to pass to `ForceFieldRelaxMaker`. + relax_initial_structure : bool = True (default) + Whether to relax the structure before computing + the elastic tensor. **kwargs Additional kwargs to pass to ElasticMaker. @@ -124,6 +129,14 @@ def from_force_field_name( ------- ElasticMaker """ + warnings.warn( + "Fixed symmetry relaxations are automatically enabled " + "to improve elastic tensor stability. To disable this " + "specify ForceFieldRelaxMaker objects explicitly. ", + category=UserWarning, + stacklevel=2, + ) + if (mlff_kwargs := kwargs.pop("mlff_kwargs", None)) is not None: warnings.warn( "`mlff_kwargs` has been marked for deprecation. " @@ -148,18 +161,22 @@ def from_force_field_name( "force_field_name": force_field_name, "calculator_kwargs": calculator_kwargs or {}, } - bulk_relax_maker = ForceFieldRelaxMaker( - relax_cell=True, + + elastic_relax_maker = ForceFieldRelaxMaker( + relax_cell=False, **default_kwargs, ) - kwargs.update( - bulk_relax_maker=bulk_relax_maker, - elastic_relax_maker=ForceFieldRelaxMaker( - relax_cell=False, - **default_kwargs, - ), - ) + return cls( - name=f"{bulk_relax_maker.mlff.name} elastic", + name=f"{elastic_relax_maker.mlff.name} elastic", **kwargs, + bulk_relax_maker=( + ForceFieldRelaxMaker( + relax_cell=True, + **default_kwargs, + ) + if relax_initial_structure + else None + ), + elastic_relax_maker=elastic_relax_maker, ) diff --git a/tests/forcefields/flows/test_elastic.py b/tests/forcefields/flows/test_elastic.py index 706001ffa1..907007ca38 100644 --- a/tests/forcefields/flows/test_elastic.py +++ b/tests/forcefields/flows/test_elastic.py @@ -16,7 +16,7 @@ def test_elastic_wf_with_mace( si_prim = SpacegroupAnalyzer(si_structure).get_primitive_standard_structure() model_path = f"{test_dir}/forcefields/mace/MACE.model" common_kwds = { - "force_field_name": "MACE", + "force_field_name": "MACE-MP-0", "calculator_kwargs": {"model": model_path, "default_dtype": "float64"}, "relax_kwargs": {"fmax": 0.00001}, } @@ -29,7 +29,7 @@ def test_elastic_wf_with_mace( ValueError, match="You have specified both `calculator_kwargs` and" ): ElasticMaker.from_force_field_name( - force_field_name="MACE", + force_field_name="MACE-MP-0", mlff_kwargs=common_kwds, calculator_kwargs=common_kwds, ) @@ -38,7 +38,7 @@ def test_elastic_wf_with_mace( UserWarning, match="`mlff_kwargs` has been marked for deprecation." ): maker = ElasticMaker.from_force_field_name( - force_field_name="MACE", + force_field_name="MACE-MP-0", mlff_kwargs=common_kwds, ) assert all( From 1e5262c3492ce2d5cdb171d9ddb7cf2315bbeb9c Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 31 Mar 2026 17:35:55 -0700 Subject: [PATCH 2/2] reintroduce strict pymatgen dep --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6c27f0458f..9def308d82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,7 @@ strict = [ "atomate2[cclib, phonons, lobster, openmm, mp, defects, ase, ase-ext]", "numpy<3.0", "numba>=0.60.0", # needed to get numpy >2,<3 installed + "pymatgen==2026.3.23", ] [project.scripts]