Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 73 additions & 26 deletions pysages/backends/jax-md.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from pysages.typing import Callable, NamedTuple
from pysages.utils import check_device_array, copy

import jax.lax

class Sampler:
def __init__(self, method_bundle, context_state, callback: Callable):
Expand Down Expand Up @@ -43,17 +43,30 @@ def take_snapshot(state, box, dt):
vel_mass = (velocities, masses)
origin = tuple(0.0 for _ in range(dims))

if state.chain:
chain_data = vars(state.chain)
else:
chain_data = None

check_device_array(positions) # currently, we only support `DeviceArray`s

return Snapshot(positions, vel_mass, forces, ids, None, Box(box, origin), dt)
return Snapshot(positions, vel_mass, forces, ids, None, Box(box, origin), dt, chain_data=chain_data)


def update_snapshot(snapshot, state):
_, masses = snapshot.vel_mass
positions = state.position
vel_mass = (state.velocity, masses)
forces = state.force
return snapshot._replace(positions=positions, vel_mass=vel_mass, forces=forces)
if state.chain:
chain_data = vars(state.chain)
return snapshot._replace(
positions=positions,
vel_mass=vel_mass,
forces=forces,
chain_data=chain_data)
else:
return snapshot._replace(positions=positions, vel_mass=vel_mass, forces=forces)


def build_snapshot_methods(context, sampling_method):
Expand Down Expand Up @@ -84,39 +97,73 @@ def dimensionality():

return helpers

jax_fn_container = {'is_defined': False, 'run_fn': None}

def build_runner(context, sampler, jit_compile=True):
step_fn = context.step_fn

def _step(sampling_context_state, snapshot, sampler_state):
context_state = sampling_context_state.state
snapshot = update_snapshot(snapshot, context_state)
sampling_context_state = step_fn(sampling_context_state) # jax_md simulation step
sampler_state = sampler.update(snapshot, sampler_state) # pysages update
if sampler_state.bias is not None: # bias the simulation
if not jax_fn_container['is_defined']:
jax_fn_container['is_defined'] = True

def _step(sampling_context_state, snapshot, sampler_state):
sampling_context_state = step_fn(sampling_context_state) # jax_md simulation step
context_state = sampling_context_state.state
biased_forces = context_state.force + sampler_state.bias
context_state = dataclasses.replace(context_state, force=biased_forces)
sampling_context_state = sampling_context_state._replace(state=context_state)
return sampling_context_state, snapshot, sampler_state

step = jit(_step) if jit_compile else _step

def run(timesteps):
# TODO: Allow to optionally batch timesteps with `lax.fori_loop`
for i in range(timesteps):
context_state, snapshot, state = step(
sampler.context_state, sampler.snapshot, sampler.state
)
sampler.context_state = context_state
sampler.snapshot = snapshot
sampler.state = state
snapshot = update_snapshot(snapshot, context_state)
sampler_state = sampler.update(snapshot, sampler_state) # pysages update
if sampler_state.bias is not None: # bias the simulation
context_state = sampling_context_state.state
biased_forces = context_state.force + sampler_state.bias
context_state = dataclasses.replace(context_state, force=biased_forces)
sampling_context_state = sampling_context_state._replace(state=context_state)
return sampling_context_state, snapshot, sampler_state

step = jit(_step) if jit_compile else _step



def _run_body(i, input_states_and_snapshots):
context_state, snapshot, sampler_state = input_states_and_snapshots
context_state, snapshot, sampler_state = step(context_state, snapshot, sampler_state)

if sampler.callback:
sampler.callback(sampler.snapshot, sampler.state, i)
sampler.callback(snapshot, sampler_state, i)
return (context_state, snapshot, sampler_state)

run_body = jit(_run_body) if jit_compile else _run_body


if jit_compile:
jax_fn_container['run_fn'] = run_body
else:
jax_fn_container['run_fn'] = step

if jit_compile:
def run(timesteps):
# TODO: Allow to optionally batch timesteps with `lax.fori_loop`

sampler.context_state, sampler.snapshot, sampler.state = jax.block_until_ready(
jax.lax.fori_loop(0, timesteps, jax_fn_container['run_fn'], (sampler.context_state, sampler.snapshot, sampler.state))
)
else:
def run(timesteps):
# TODO: Allow to optionally batch timesteps with `lax.fori_loop`
for i in range(timesteps):
context_state, snapshot, state = jax_fn_container['run_fn'](
sampler.context_state, sampler.snapshot, sampler.state
)
sampler.context_state = context_state
sampler.snapshot = snapshot
sampler.state = state
if sampler.callback:
sampler.callback(sampler.snapshot, sampler.state, i)

#return run

return run




class View(NamedTuple):
synchronize: Callable

Expand Down
5 changes: 4 additions & 1 deletion pysages/backends/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import jit
from jax import numpy as np

from pysages.typing import Callable, JaxArray, NamedTuple, Optional, Tuple, Union
from pysages.typing import Callable, JaxArray, NamedTuple, Optional, Tuple, Union, Dict, Any
from pysages.utils import copy, dispatch

AbstractBox = NamedTuple("AbstractBox", [("H", JaxArray), ("origin", JaxArray)])
Expand Down Expand Up @@ -36,6 +36,9 @@ class Snapshot(NamedTuple):
box: Box
dt: Union[JaxArray, float]

#Optional thermostat parameters
chain_data : Optional[dict[str,Any]] = None

def __repr__(self):
return "PySAGES " + type(self).__name__

Expand Down
1 change: 1 addition & 0 deletions pysages/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
Sequence = _typing.Sequence
Tuple = _typing.Tuple
Union = _typing.Union
Dict = _typing.Dict

# Union aliases
Scalar = Union[None, bool, int, float]
Expand Down
3 changes: 3 additions & 0 deletions pysages/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def copy(x: Scalar):
def copy(t: tuple, *args): # noqa: F811 # pylint: disable=C0116,E0102
return tuple(copy(x, *args) for x in t) # pylint: disable=E1120

@dispatch
def copy(x: dict): # noqa: F811 # pylint: disable=C0116,E0102
return x.copy()

@dispatch
def copy(x: JaxArray): # noqa: F811 # pylint: disable=C0116,E0102
Expand Down