From a5fd7d13538088d54a8d48f72d028ffc06eb7c1d Mon Sep 17 00:00:00 2001 From: Evgeny Moerman Date: Tue, 17 Jun 2025 13:08:33 +0200 Subject: [PATCH 1/7] Fix order in jax-md step to avoid bug for 1 step md --- pysages/backends/jax-md.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysages/backends/jax-md.py b/pysages/backends/jax-md.py index 9b2529fb..b8ef884e 100644 --- a/pysages/backends/jax-md.py +++ b/pysages/backends/jax-md.py @@ -89,9 +89,9 @@ def build_runner(context, sampler, jit_compile=True): step_fn = context.step_fn def _step(sampling_context_state, snapshot, sampler_state): + sampling_context_state = step_fn(sampling_context_state) # jax_md simulation step context_state = sampling_context_state.state snapshot = update_snapshot(snapshot, context_state) - sampling_context_state = step_fn(sampling_context_state) # jax_md simulation step sampler_state = sampler.update(snapshot, sampler_state) # pysages update if sampler_state.bias is not None: # bias the simulation context_state = sampling_context_state.state From 6b4b6fe0ee31eaf5f807d5538b3a320e65e470af Mon Sep 17 00:00:00 2001 From: Evgeny Moerman Date: Tue, 17 Jun 2025 15:19:31 +0200 Subject: [PATCH 2/7] Add optional chain data to snapshot --- pysages/backends/jax-md.py | 20 +++++++++++++++++++- pysages/backends/snapshot.py | 8 ++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/pysages/backends/jax-md.py b/pysages/backends/jax-md.py index b8ef884e..c10e94ee 100644 --- a/pysages/backends/jax-md.py +++ b/pysages/backends/jax-md.py @@ -53,7 +53,25 @@ def update_snapshot(snapshot, state): positions = state.position vel_mass = (state.velocity, masses) forces = state.force - return snapshot._replace(positions=positions, vel_mass=vel_mass, forces=forces) + if state.chain: + chain_positions = state.chain.position + chain_momenta = state.chain.momentum + chain_mass = state.chain.mass + chain_ekin = state.chain.kinetic_energy + chain_tau = state.chain.tau + chain_dof = state.chain.degrees_of_freedom + return snapshot._replace( + positions=positions, + vel_mass=vel_mass, + forces=forces, + chain_positions=chain_positions, + chain_momenta=chain_momenta, + chain_mass=chain_mass, + chain_ekin=chain_ekin, + chain_tau=chain_tau, + chain_dof=chain_dof) + else: + return snapshot._replace(positions=positions, vel_mass=vel_mass, forces=forces) def build_snapshot_methods(context, sampling_method): diff --git a/pysages/backends/snapshot.py b/pysages/backends/snapshot.py index c88c5c8a..5f082844 100644 --- a/pysages/backends/snapshot.py +++ b/pysages/backends/snapshot.py @@ -36,6 +36,14 @@ class Snapshot(NamedTuple): box: Box dt: Union[JaxArray, float] + #Optional thermostat parameters + chain_positions : Optional[JaxArray] = None + chain_momenta : Optional[JaxArray] = None + chain_mass : Optional[Union[JaxArray, float]] = None + chain_ekin : Optional[Union[JaxArray, float]] = None + chain_tau : Optional[float] = None + chain_dof : Optional[int] = None + def __repr__(self): return "PySAGES " + type(self).__name__ From 676657c72d9273f34c824429799d72c73488d149 Mon Sep 17 00:00:00 2001 From: Evgeny Moerman Date: Tue, 17 Jun 2025 17:06:08 +0200 Subject: [PATCH 3/7] Simply Snapshot chain information into dict --- pysages/backends/jax-md.py | 20 +++++--------------- pysages/backends/snapshot.py | 9 ++------- pysages/typing.py | 1 + pysages/utils/core.py | 3 +++ 4 files changed, 11 insertions(+), 22 deletions(-) diff --git a/pysages/backends/jax-md.py b/pysages/backends/jax-md.py index c10e94ee..373113ca 100644 --- a/pysages/backends/jax-md.py +++ b/pysages/backends/jax-md.py @@ -54,22 +54,12 @@ def update_snapshot(snapshot, state): vel_mass = (state.velocity, masses) forces = state.force if state.chain: - chain_positions = state.chain.position - chain_momenta = state.chain.momentum - chain_mass = state.chain.mass - chain_ekin = state.chain.kinetic_energy - chain_tau = state.chain.tau - chain_dof = state.chain.degrees_of_freedom + chain_data = vars(state.chain) return snapshot._replace( - positions=positions, - vel_mass=vel_mass, + positions=positions, + vel_mass=vel_mass, forces=forces, - chain_positions=chain_positions, - chain_momenta=chain_momenta, - chain_mass=chain_mass, - chain_ekin=chain_ekin, - chain_tau=chain_tau, - chain_dof=chain_dof) + chain_data=chain_data) else: return snapshot._replace(positions=positions, vel_mass=vel_mass, forces=forces) @@ -113,7 +103,7 @@ def _step(sampling_context_state, snapshot, sampler_state): sampler_state = sampler.update(snapshot, sampler_state) # pysages update if sampler_state.bias is not None: # bias the simulation context_state = sampling_context_state.state - biased_forces = context_state.force + sampler_state.bias + biased_forces = context_state.force #+ sampler_state.bias context_state = dataclasses.replace(context_state, force=biased_forces) sampling_context_state = sampling_context_state._replace(state=context_state) return sampling_context_state, snapshot, sampler_state diff --git a/pysages/backends/snapshot.py b/pysages/backends/snapshot.py index 5f082844..cf17dc82 100644 --- a/pysages/backends/snapshot.py +++ b/pysages/backends/snapshot.py @@ -4,7 +4,7 @@ from jax import jit from jax import numpy as np -from pysages.typing import Callable, JaxArray, NamedTuple, Optional, Tuple, Union +from pysages.typing import Callable, JaxArray, NamedTuple, Optional, Tuple, Union, Dict, Any from pysages.utils import copy, dispatch AbstractBox = NamedTuple("AbstractBox", [("H", JaxArray), ("origin", JaxArray)]) @@ -37,12 +37,7 @@ class Snapshot(NamedTuple): dt: Union[JaxArray, float] #Optional thermostat parameters - chain_positions : Optional[JaxArray] = None - chain_momenta : Optional[JaxArray] = None - chain_mass : Optional[Union[JaxArray, float]] = None - chain_ekin : Optional[Union[JaxArray, float]] = None - chain_tau : Optional[float] = None - chain_dof : Optional[int] = None + chain_data : Optional[dict[str,Any]] = None def __repr__(self): return "PySAGES " + type(self).__name__ diff --git a/pysages/typing.py b/pysages/typing.py index 8a4664ed..728c5568 100644 --- a/pysages/typing.py +++ b/pysages/typing.py @@ -41,6 +41,7 @@ Sequence = _typing.Sequence Tuple = _typing.Tuple Union = _typing.Union +Dict = _typing.Dict # Union aliases Scalar = Union[None, bool, int, float] diff --git a/pysages/utils/core.py b/pysages/utils/core.py index 20fe0f3e..14096770 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -27,6 +27,9 @@ def copy(x: Scalar): def copy(t: tuple, *args): # noqa: F811 # pylint: disable=C0116,E0102 return tuple(copy(x, *args) for x in t) # pylint: disable=E1120 +@dispatch +def copy(x: dict): # noqa: F811 # pylint: disable=C0116,E0102 + return x.copy() @dispatch def copy(x: JaxArray): # noqa: F811 # pylint: disable=C0116,E0102 From e17d328fa3da822d4e3b50a8f4738464b12232ec Mon Sep 17 00:00:00 2001 From: Evgeny Moerman Date: Thu, 19 Jun 2025 16:35:38 +0200 Subject: [PATCH 4/7] Add back addition of bias to force (previously switched off for debugging) --- pysages/backends/jax-md.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysages/backends/jax-md.py b/pysages/backends/jax-md.py index 373113ca..54db2f2c 100644 --- a/pysages/backends/jax-md.py +++ b/pysages/backends/jax-md.py @@ -103,7 +103,7 @@ def _step(sampling_context_state, snapshot, sampler_state): sampler_state = sampler.update(snapshot, sampler_state) # pysages update if sampler_state.bias is not None: # bias the simulation context_state = sampling_context_state.state - biased_forces = context_state.force #+ sampler_state.bias + biased_forces = context_state.force + sampler_state.bias context_state = dataclasses.replace(context_state, force=biased_forces) sampling_context_state = sampling_context_state._replace(state=context_state) return sampling_context_state, snapshot, sampler_state From c9d8203e6b6d64af79412cd920ee03f593f3a754 Mon Sep 17 00:00:00 2001 From: Evgeny Moerman Date: Wed, 25 Jun 2025 22:38:28 +0200 Subject: [PATCH 5/7] Accelerate jax-md run fn via jax fori_loop + disable re-compilation of step fn upon restart --- pysages/backends/jax-md.py | 73 +++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/pysages/backends/jax-md.py b/pysages/backends/jax-md.py index 54db2f2c..0592a8a9 100644 --- a/pysages/backends/jax-md.py +++ b/pysages/backends/jax-md.py @@ -15,7 +15,7 @@ ) from pysages.typing import Callable, NamedTuple from pysages.utils import check_device_array, copy - +import jax.lax class Sampler: def __init__(self, method_bundle, context_state, callback: Callable): @@ -43,9 +43,14 @@ def take_snapshot(state, box, dt): vel_mass = (velocities, masses) origin = tuple(0.0 for _ in range(dims)) + if state.chain: + chain_data = vars(state.chain) + else: + chain_data = None + check_device_array(positions) # currently, we only support `DeviceArray`s - return Snapshot(positions, vel_mass, forces, ids, None, Box(box, origin), dt) + return Snapshot(positions, vel_mass, forces, ids, None, Box(box, origin), dt, chain_data=chain_data) def update_snapshot(snapshot, state): @@ -92,39 +97,65 @@ def dimensionality(): return helpers +jax_fn_container = {'is_defined': False, 'run_fn': None} def build_runner(context, sampler, jit_compile=True): step_fn = context.step_fn - def _step(sampling_context_state, snapshot, sampler_state): - sampling_context_state = step_fn(sampling_context_state) # jax_md simulation step - context_state = sampling_context_state.state - snapshot = update_snapshot(snapshot, context_state) - sampler_state = sampler.update(snapshot, sampler_state) # pysages update - if sampler_state.bias is not None: # bias the simulation + if not jax_fn_container['is_defined']: + jax_fn_container['is_defined'] = True + + def _step(sampling_context_state, snapshot, sampler_state): + sampling_context_state = step_fn(sampling_context_state) # jax_md simulation step context_state = sampling_context_state.state - biased_forces = context_state.force + sampler_state.bias - context_state = dataclasses.replace(context_state, force=biased_forces) - sampling_context_state = sampling_context_state._replace(state=context_state) - return sampling_context_state, snapshot, sampler_state + snapshot = update_snapshot(snapshot, context_state) + sampler_state = sampler.update(snapshot, sampler_state) # pysages update + if sampler_state.bias is not None: # bias the simulation + context_state = sampling_context_state.state + biased_forces = context_state.force + sampler_state.bias + context_state = dataclasses.replace(context_state, force=biased_forces) + sampling_context_state = sampling_context_state._replace(state=context_state) + return sampling_context_state, snapshot, sampler_state + + step = jit(_step) if jit_compile else _step + - step = jit(_step) if jit_compile else _step + + def _run_body(i, input_states_and_snapshots): + context_state, snapshot, sampler_state = input_states_and_snapshots + return tuple(step(context_state, snapshot, sampler_state)) + + run_body = jit(_run_body) if jit_compile else _run_body + + + jax_fn_container['run_fn'] = run_body def run(timesteps): # TODO: Allow to optionally batch timesteps with `lax.fori_loop` - for i in range(timesteps): - context_state, snapshot, state = step( - sampler.context_state, sampler.snapshot, sampler.state + + sampler.context_state, sampler.snapshot, sampler.state = jax.block_until_ready( + jax.lax.fori_loop(0, timesteps, jax_fn_container['run_fn'], (sampler.context_state, sampler.snapshot, sampler.state)) ) - sampler.context_state = context_state - sampler.snapshot = snapshot - sampler.state = state - if sampler.callback: - sampler.callback(sampler.snapshot, sampler.state, i) + + #def run(timesteps): + # # TODO: Allow to optionally batch timesteps with `lax.fori_loop` + # for i in range(timesteps): + # context_state, snapshot, state = step( + # sampler.context_state, sampler.snapshot, sampler.state + # ) + # sampler.context_state = context_state + # sampler.snapshot = snapshot + # sampler.state = state + # if sampler.callback: + # sampler.callback(sampler.snapshot, sampler.state, i) + + #return run return run + + class View(NamedTuple): synchronize: Callable From 4289e48d6c3e597e59607f0c4b4a3098acb8b96d Mon Sep 17 00:00:00 2001 From: Evgeny Moerman Date: Wed, 25 Jun 2025 23:04:53 +0200 Subject: [PATCH 6/7] Add callback option to _run_body --- pysages/backends/jax-md.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pysages/backends/jax-md.py b/pysages/backends/jax-md.py index 0592a8a9..dbb6e313 100644 --- a/pysages/backends/jax-md.py +++ b/pysages/backends/jax-md.py @@ -123,7 +123,11 @@ def _step(sampling_context_state, snapshot, sampler_state): def _run_body(i, input_states_and_snapshots): context_state, snapshot, sampler_state = input_states_and_snapshots - return tuple(step(context_state, snapshot, sampler_state)) + context_state, snapshot, sampler_state = step(context_state, snapshot, sampler_state) + + if sampler.callback: + sampler.callback(snapshot, sampler_state, i) + return (context_state, snapshot, sampler_state) run_body = jit(_run_body) if jit_compile else _run_body From bcacf6007ff27f7e8a6a6eabc465c85b84a1850f Mon Sep 17 00:00:00 2001 From: Evgeny Moerman Date: Wed, 25 Jun 2025 23:51:19 +0200 Subject: [PATCH 7/7] Make fori_loop version optional --- pysages/backends/jax-md.py | 44 +++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/pysages/backends/jax-md.py b/pysages/backends/jax-md.py index dbb6e313..dfbdf7d6 100644 --- a/pysages/backends/jax-md.py +++ b/pysages/backends/jax-md.py @@ -132,26 +132,30 @@ def _run_body(i, input_states_and_snapshots): run_body = jit(_run_body) if jit_compile else _run_body - jax_fn_container['run_fn'] = run_body - - def run(timesteps): - # TODO: Allow to optionally batch timesteps with `lax.fori_loop` - - sampler.context_state, sampler.snapshot, sampler.state = jax.block_until_ready( - jax.lax.fori_loop(0, timesteps, jax_fn_container['run_fn'], (sampler.context_state, sampler.snapshot, sampler.state)) - ) - - #def run(timesteps): - # # TODO: Allow to optionally batch timesteps with `lax.fori_loop` - # for i in range(timesteps): - # context_state, snapshot, state = step( - # sampler.context_state, sampler.snapshot, sampler.state - # ) - # sampler.context_state = context_state - # sampler.snapshot = snapshot - # sampler.state = state - # if sampler.callback: - # sampler.callback(sampler.snapshot, sampler.state, i) + if jit_compile: + jax_fn_container['run_fn'] = run_body + else: + jax_fn_container['run_fn'] = step + + if jit_compile: + def run(timesteps): + # TODO: Allow to optionally batch timesteps with `lax.fori_loop` + + sampler.context_state, sampler.snapshot, sampler.state = jax.block_until_ready( + jax.lax.fori_loop(0, timesteps, jax_fn_container['run_fn'], (sampler.context_state, sampler.snapshot, sampler.state)) + ) + else: + def run(timesteps): + # TODO: Allow to optionally batch timesteps with `lax.fori_loop` + for i in range(timesteps): + context_state, snapshot, state = jax_fn_container['run_fn']( + sampler.context_state, sampler.snapshot, sampler.state + ) + sampler.context_state = context_state + sampler.snapshot = snapshot + sampler.state = state + if sampler.callback: + sampler.callback(sampler.snapshot, sampler.state, i) #return run