diff --git a/pysages/backends/jax-md.py b/pysages/backends/jax-md.py index 9b2529fb..dfbdf7d6 100644 --- a/pysages/backends/jax-md.py +++ b/pysages/backends/jax-md.py @@ -15,7 +15,7 @@ ) from pysages.typing import Callable, NamedTuple from pysages.utils import check_device_array, copy - +import jax.lax class Sampler: def __init__(self, method_bundle, context_state, callback: Callable): @@ -43,9 +43,14 @@ def take_snapshot(state, box, dt): vel_mass = (velocities, masses) origin = tuple(0.0 for _ in range(dims)) + if state.chain: + chain_data = vars(state.chain) + else: + chain_data = None + check_device_array(positions) # currently, we only support `DeviceArray`s - return Snapshot(positions, vel_mass, forces, ids, None, Box(box, origin), dt) + return Snapshot(positions, vel_mass, forces, ids, None, Box(box, origin), dt, chain_data=chain_data) def update_snapshot(snapshot, state): @@ -53,7 +58,15 @@ def update_snapshot(snapshot, state): positions = state.position vel_mass = (state.velocity, masses) forces = state.force - return snapshot._replace(positions=positions, vel_mass=vel_mass, forces=forces) + if state.chain: + chain_data = vars(state.chain) + return snapshot._replace( + positions=positions, + vel_mass=vel_mass, + forces=forces, + chain_data=chain_data) + else: + return snapshot._replace(positions=positions, vel_mass=vel_mass, forces=forces) def build_snapshot_methods(context, sampling_method): @@ -84,39 +97,73 @@ def dimensionality(): return helpers +jax_fn_container = {'is_defined': False, 'run_fn': None} def build_runner(context, sampler, jit_compile=True): step_fn = context.step_fn - def _step(sampling_context_state, snapshot, sampler_state): - context_state = sampling_context_state.state - snapshot = update_snapshot(snapshot, context_state) - sampling_context_state = step_fn(sampling_context_state) # jax_md simulation step - sampler_state = sampler.update(snapshot, sampler_state) # pysages update - if sampler_state.bias is not None: # bias the simulation + if not jax_fn_container['is_defined']: + jax_fn_container['is_defined'] = True + + def _step(sampling_context_state, snapshot, sampler_state): + sampling_context_state = step_fn(sampling_context_state) # jax_md simulation step context_state = sampling_context_state.state - biased_forces = context_state.force + sampler_state.bias - context_state = dataclasses.replace(context_state, force=biased_forces) - sampling_context_state = sampling_context_state._replace(state=context_state) - return sampling_context_state, snapshot, sampler_state - - step = jit(_step) if jit_compile else _step - - def run(timesteps): - # TODO: Allow to optionally batch timesteps with `lax.fori_loop` - for i in range(timesteps): - context_state, snapshot, state = step( - sampler.context_state, sampler.snapshot, sampler.state - ) - sampler.context_state = context_state - sampler.snapshot = snapshot - sampler.state = state + snapshot = update_snapshot(snapshot, context_state) + sampler_state = sampler.update(snapshot, sampler_state) # pysages update + if sampler_state.bias is not None: # bias the simulation + context_state = sampling_context_state.state + biased_forces = context_state.force + sampler_state.bias + context_state = dataclasses.replace(context_state, force=biased_forces) + sampling_context_state = sampling_context_state._replace(state=context_state) + return sampling_context_state, snapshot, sampler_state + + step = jit(_step) if jit_compile else _step + + + + def _run_body(i, input_states_and_snapshots): + context_state, snapshot, sampler_state = input_states_and_snapshots + context_state, snapshot, sampler_state = step(context_state, snapshot, sampler_state) + if sampler.callback: - sampler.callback(sampler.snapshot, sampler.state, i) + sampler.callback(snapshot, sampler_state, i) + return (context_state, snapshot, sampler_state) + + run_body = jit(_run_body) if jit_compile else _run_body + + + if jit_compile: + jax_fn_container['run_fn'] = run_body + else: + jax_fn_container['run_fn'] = step + + if jit_compile: + def run(timesteps): + # TODO: Allow to optionally batch timesteps with `lax.fori_loop` + + sampler.context_state, sampler.snapshot, sampler.state = jax.block_until_ready( + jax.lax.fori_loop(0, timesteps, jax_fn_container['run_fn'], (sampler.context_state, sampler.snapshot, sampler.state)) + ) + else: + def run(timesteps): + # TODO: Allow to optionally batch timesteps with `lax.fori_loop` + for i in range(timesteps): + context_state, snapshot, state = jax_fn_container['run_fn']( + sampler.context_state, sampler.snapshot, sampler.state + ) + sampler.context_state = context_state + sampler.snapshot = snapshot + sampler.state = state + if sampler.callback: + sampler.callback(sampler.snapshot, sampler.state, i) + + #return run return run + + class View(NamedTuple): synchronize: Callable diff --git a/pysages/backends/snapshot.py b/pysages/backends/snapshot.py index c88c5c8a..cf17dc82 100644 --- a/pysages/backends/snapshot.py +++ b/pysages/backends/snapshot.py @@ -4,7 +4,7 @@ from jax import jit from jax import numpy as np -from pysages.typing import Callable, JaxArray, NamedTuple, Optional, Tuple, Union +from pysages.typing import Callable, JaxArray, NamedTuple, Optional, Tuple, Union, Dict, Any from pysages.utils import copy, dispatch AbstractBox = NamedTuple("AbstractBox", [("H", JaxArray), ("origin", JaxArray)]) @@ -36,6 +36,9 @@ class Snapshot(NamedTuple): box: Box dt: Union[JaxArray, float] + #Optional thermostat parameters + chain_data : Optional[dict[str,Any]] = None + def __repr__(self): return "PySAGES " + type(self).__name__ diff --git a/pysages/typing.py b/pysages/typing.py index 8a4664ed..728c5568 100644 --- a/pysages/typing.py +++ b/pysages/typing.py @@ -41,6 +41,7 @@ Sequence = _typing.Sequence Tuple = _typing.Tuple Union = _typing.Union +Dict = _typing.Dict # Union aliases Scalar = Union[None, bool, int, float] diff --git a/pysages/utils/core.py b/pysages/utils/core.py index 20fe0f3e..14096770 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -27,6 +27,9 @@ def copy(x: Scalar): def copy(t: tuple, *args): # noqa: F811 # pylint: disable=C0116,E0102 return tuple(copy(x, *args) for x in t) # pylint: disable=E1120 +@dispatch +def copy(x: dict): # noqa: F811 # pylint: disable=C0116,E0102 + return x.copy() @dispatch def copy(x: JaxArray): # noqa: F811 # pylint: disable=C0116,E0102