From fb1ea2ad9199c5d3e7d87497e889d41b25c4ba98 Mon Sep 17 00:00:00 2001 From: sz <1366808715@qq.com> Date: Thu, 12 Mar 2026 15:22:46 +0800 Subject: [PATCH] Fix ASE backend duplicate state updates within same timestep --- pysages/backends/ase.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/pysages/backends/ase.py b/pysages/backends/ase.py index 45f5e328..9c125f0e 100644 --- a/pysages/backends/ase.py +++ b/pysages/backends/ase.py @@ -41,6 +41,7 @@ def __init__(self, context, method_bundle, callback: Callable): self.callback = callback self.snapshot = initial_snapshot self.state = initialize() + self._last_update_timestep = -1 self.update = method_update sig = signature(atoms.calc.calculate).parameters @@ -67,23 +68,29 @@ def __getattr__(self, name): def biased_forces(self): return view(copy(self._biased_forces, ToCPU())) - def calculate(self, atoms=None, **kwargs): - properties = kwargs.get("properties", self._default_properties) - system_changes = kwargs.get("system_changes", self._default_changes) + def calculate(self, atoms=None, properties=None, system_changes=None): + properties = self._default_properties if properties is None else properties + system_changes = ( + self._default_changes if system_changes is None else system_changes + ) self._calculator.calculate(atoms, properties, system_changes) - def get_forces(self, atoms=None): + def get_forces(self, atoms=None): # type: ignore[override] forces = self._get_forces(atoms) - self.snapshot = take_snapshot(self._context, forces) - self.state = self.update(self.snapshot, self.state) - new_forces = self.snapshot.forces + timestep = self._context.get_number_of_steps() + + if timestep != self._last_update_timestep: + self.snapshot = take_snapshot(self._context, forces) + self.state = self.update(self.snapshot, self.state) + self._last_update_timestep = timestep + if self.callback: + self.callback(self.snapshot, self.state, timestep) + + new_forces = forces if self.state.bias is not None: new_forces += self.state.bias - if self.callback: - timestep = self._context.get_number_of_steps() - self.callback(self.snapshot, self.state, timestep) self._biased_forces = new_forces - return self.biased_forces + return self.biased_forces # type: ignore[return-value] def restore(self, prev_snapshot): atoms = self.atoms