diff --git a/pysages/backends/lammps.py b/pysages/backends/lammps.py index b357a8f4..4659c29d 100644 --- a/pysages/backends/lammps.py +++ b/pysages/backends/lammps.py @@ -30,6 +30,8 @@ from pysages.typing import Callable, Optional from pysages.utils import copy, identity +from ctypes import c_double + kConversionFactor = {"real": 2390.0573615334906, "metal": 1.0364269e-4, "electron": 1.06657236} kDefaultLocation = dlext.kOnHost if not hasattr(ExecutionSpace, "kOnDevice") else dlext.kOnDevice @@ -116,6 +118,19 @@ def restore(self, prev_snapshot): """Replaces this sampler's snapshot with `prev_snapshot`.""" self._restore(self.snapshot, prev_snapshot) + positions = self.snapshot.positions[self.snapshot.ids].ravel() + x_ctypes = (len(positions) * c_double)(*positions) + self.context.scatter_atoms("x", 1, 3, x_ctypes) + + forces = self.snapshot.forces[self.snapshot.ids].ravel() + f_ctypes = (len(forces) * c_double)(*forces) + self.context.scatter_atoms("f", 1, 3, f_ctypes) + + velocities = self.snapshot.vel_mass[0][self.snapshot.ids].ravel() + v_ctypes = (len(velocities) * c_double)(*velocities) + self.context.scatter_atoms("v", 1, 3, v_ctypes) + + def take_snapshot(self): """Returns a copy of the current snapshot of the system.""" s = self._partial_snapshot(include_masses=True) diff --git a/pysages/colvars/ML_committor.py b/pysages/colvars/ML_committor.py new file mode 100644 index 00000000..7e596638 --- /dev/null +++ b/pysages/colvars/ML_committor.py @@ -0,0 +1,89 @@ +import jax +from jax import numpy as np +from functools import partial +from pysages.colvars.core import CollectiveVariable +from .train_committor_dist import CommittorNN_Dist_Lip, make_forward_eval, CommittorNN_PIV, CommittorNN_PIV_shiftsig +from pysages.typing import JaxArray, List, Sequence +from flax import serialization + +class Committor_CV_dist_lipschitz(CollectiveVariable): + def __init__(self, indices: List, params_path: str, tri_idx1: JaxArray, tri_idx2: JaxArray): + super().__init__(indices) + + model = CommittorNN_Dist_Lip(indices=np.arange(len(indices)), + tri_idx1=tri_idx1, + tri_idx2=tri_idx2, + h1=16, h2=16, h3=8, out_dim=1, sig_k=3.0) + + rng = jax.random.PRNGKey(0) + dummy_pos = np.zeros((1, len(indices), 3)) + params = model.init(rng, dummy_pos, training=False) + with open(params_path, "rb") as f: + params = serialization.from_bytes(params, f.read()) + params = jax.tree.map( + lambda x: x.astype(np.float64) if hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating) else x, + params, + ) + self.params = params + self.forward_eval = make_forward_eval(model) + + @property + def function(self): + def wrapped_forward(pos): + y = self.forward_eval(self.params, pos[None, :, :]) + return np.squeeze(y) + return wrapped_forward + + +class Committor_CV_PIV(CollectiveVariable): + def __init__(self, indices: List, params_path: str, blocks: Sequence): + super().__init__(indices) + + model = CommittorNN_PIV(indices=np.arange(len(indices)), blocks=blocks, h1=32, h2=16, h3=8, out_dim=1, sig_k=3.0) + + rng = jax.random.PRNGKey(0) + dummy_pos = np.zeros((1, len(indices), 3)) + params = model.init(rng, dummy_pos, training=False) + with open(params_path, "rb") as f: + params = serialization.from_bytes(params, f.read()) + params = jax.tree.map( + lambda x: x.astype(np.float64) if hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating) else x, + params, + ) + self.params = params + self.forward_eval = make_forward_eval(model) + + @property + def function(self): + def wrapped_forward(pos): + y = self.forward_eval(self.params, pos[None, :, :]) + return np.squeeze(y) + return wrapped_forward + +class Committor_CV_PIV_shiftsig(CollectiveVariable): + def __init__(self, indices: List, params_path: str, blocks: Sequence, h1=32, h2=16, h3=8, sig_k=3.0): + super().__init__(indices) + + model = CommittorNN_PIV_shiftsig(indices=np.arange(len(indices)), blocks=blocks, h1=h1, h2=h2, h3=h3, out_dim=1, sig_k=3.0) + + rng = jax.random.PRNGKey(0) + dummy_pos = np.zeros((1, len(indices), 3)) + params = model.init(rng, dummy_pos, training=False) + with open(params_path, "rb") as f: + params = serialization.from_bytes(params, f.read()) + params = jax.tree.map( + lambda x: x.astype(np.float64) if hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating) else x, + params, + ) + self.params = params + self.forward_eval = make_forward_eval(model) + + @property + def function(self): + def wrapped_forward(pos): + y = self.forward_eval(self.params, pos[None, :, :]) + return np.squeeze(y) + return wrapped_forward + +def cartesian(idx1, idx2): + return np.stack(np.broadcast_arrays(idx1[:, None], idx2[None, :]), axis=-1).reshape(-1, 2) diff --git a/pysages/colvars/train_committor_dist.py b/pysages/colvars/train_committor_dist.py new file mode 100644 index 00000000..5c29a179 --- /dev/null +++ b/pysages/colvars/train_committor_dist.py @@ -0,0 +1,417 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +JAX/Flax implementation of your committor NN with Lipschitz-normalized +linear layers, boundary + gradient + Lipschitz losses, and an Optax +training loop. This version fixes the earlier JIT issue by closing over +`loss_fn` instead of passing it as an argument to a jitted function. + +Usage (example): + # (optional) to avoid accidental GPU probing / driver issues + # export JAX_PLATFORMS=cpu + python run_committor_flax.py + +Notes: +- Keep this file name NOT equal to 'jax.py' to avoid shadowing the JAX lib. +- Labels: 0=A (reactant), 1=sim/intermediate, 2=B (product). +- Saves params to distance_flax.params and a histogram to fig_jax.png. +""" + +import os +# Respect external setting; uncomment the next line to force CPU: +# os.environ.setdefault("JAX_PLATFORMS", "cpu") +os.environ.setdefault("JAX_ENABLE_X64", "0") # float32 for speed + +from typing import Any, Dict, Tuple, Iterable, Sequence +import numpy as onp +import jax +import jax.numpy as np +from flax import linen as nn +from flax.training import train_state +import optax + +# ------------------------------ +# Layers +# ------------------------------ +class NormalizedLinear(nn.Module): + in_features: int + out_features: int + + @nn.compact + def __call__(self, x: np.ndarray) -> np.ndarray: + kernel = self.param( + "kernel", nn.initializers.lecun_normal(), (self.out_features, self.in_features) + ) + bias = self.param("bias", nn.initializers.zeros, (self.out_features,)) + ci = self.param("ci", nn.initializers.ones, ()) # trainable scalar + + absrowsum = np.sum(np.abs(kernel), axis=1, keepdims=True) + scale = np.minimum(1.0, nn.softplus(ci) / (absrowsum + 1e-12)) + w_norm = kernel * scale + return np.dot(x, w_norm.T) + bias + + +class CommittorNN_Dist_Lip(nn.Module): + indices: np.ndarray # (M,) + tri_idx1: np.ndarray # (K,) + tri_idx2: np.ndarray # (K,) + h1: int = 16 + h2: int = 16 + h3: int = 8 + out_dim: int = 1 + sig_k: float = 3.0 + + @nn.compact + def __call__(self, pos: np.ndarray, training: bool = False) -> np.ndarray: + # pos: (B, N, 3) + r = pos[:, self.indices, :] # (B, M, 3) + diffs = r[:, self.tri_idx1, :] - r[:, self.tri_idx2, :] # (B, K, 3) + dists = np.linalg.norm(diffs, ord=2, axis=-1) # (B, K) + + x = dists + x = NormalizedLinear(dists.shape[-1], self.h1)(x); x = np.tanh(x) + x = NormalizedLinear(self.h1, self.h2)(x); x = np.tanh(x) + x = NormalizedLinear(self.h2, self.h3)(x); x = np.tanh(x) + x = NormalizedLinear(self.h3, self.out_dim)(x) # (B,1) + x = np.squeeze(x, axis=-1) # (B,) + return jax.nn.sigmoid(self.sig_k * x) if training else x + + +def get_block_dist(pos, block): + return np.sort(np.linalg.norm(pos[:,block[:,0]] - pos[:,block[:,1]], axis=-1), axis=-1) + +def get_x(pos, blocks): + return np.concatenate([get_block_dist(pos, block) for block in blocks], axis=-1) + +class CommittorNN_PIV(nn.Module): + indices: np.ndarray # (M,) + blocks: Sequence + h1: int = 32 + h2: int = 16 + h3: int = 8 + out_dim: int = 1 + sig_k: float = 3.0 + + @nn.compact + def __call__(self, pos: np.ndarray, training: bool = False) -> np.ndarray: + # pos: (B, N, 3) + x = get_x(pos[:, self.indices, :], self.blocks) + x = NormalizedLinear(x.shape[-1], self.h1)(x); x = np.tanh(x) + x = NormalizedLinear(self.h1, self.h2)(x); x = np.tanh(x) + x = NormalizedLinear(self.h2, self.h3)(x); x = np.tanh(x) + x = NormalizedLinear(self.h3, self.out_dim)(x) # (B,1) + x = np.squeeze(x, axis=-1) # (B,) + return jax.nn.sigmoid(self.sig_k * x) if training else x + +class CommittorNN_PIV_shiftsig(nn.Module): + indices: np.ndarray # (M,) + blocks: Sequence + h1: int = 32 + h2: int = 16 + h3: int = 8 + out_dim: int = 1 + sig_k: float = 4.0 + + @nn.compact + def __call__(self, pos: np.ndarray, training: bool = False) -> np.ndarray: + # pos: (B, N, 3) + sig_shift1 = 1e-3*self.param("sig_shift1", nn.initializers.constant(0.0), ()) + sig_shift2 = 1e-3*self.param("sig_shift2", nn.initializers.constant(0.0), ()) + + x = get_x(pos[:, self.indices, :], self.blocks) + x = NormalizedLinear(x.shape[-1], self.h1)(x); x = np.tanh(x) + x = NormalizedLinear(self.h1, self.h2)(x); x = np.tanh(x) + x = NormalizedLinear(self.h2, self.h3)(x); x = np.tanh(x) + x = NormalizedLinear(self.h3, self.out_dim)(x) # (B,1) + x = np.squeeze(x, axis=-1) # (B,) + q = 0.5*jax.nn.sigmoid(self.sig_k * (x-sig_shift1) ) + 0.5*jax.nn.sigmoid(self.sig_k * (x-sig_shift2) ) + if training: + return q, x + else: + return x + +# ------------------------------ +# Losses +# ------------------------------ +@jax.jit +def lipschitz_loss_from_params(params: Dict[str, Any]) -> np.ndarray: + prod = 1.0 + + def walk(p): + nonlocal prod + if isinstance(p, dict): + if "ci" in p: + prod = prod * nn.softplus(p["ci"]) # multiply scalar + for v in p.values(): + walk(v) + + walk(params) + return prod + + +def make_loss_fn(model: nn.Module, + masses: np.ndarray, + boundary_weight: float = 100.0, + lipschitz_weight: float = 3e-3, + gradient_weight: float = 1.0): + # masses: (1, N, 1) broadcast to (B, N, 1) + def loss_fn(params, pos_batch: np.ndarray, labels: np.ndarray, weights: np.ndarray): + q = model.apply(params, pos_batch, training=True) # (B,) + + # boundary term (average over non-1 labels) + is_A = (labels == 0) + is_B = (labels == 2) + is_1 = (labels == 1) + num_1 = np.maximum(1, np.sum(is_1)) + num_not1 = np.maximum(1, pos_batch.shape[0] - np.sum(is_1)) + boundary = (np.sum((q**2) * is_A) + np.sum(((q - 1.0)**2) * is_B)) / num_not1 + + # gradient term (only on label==1) + def q_sum_over_batch_pos(pos): + return np.sum(model.apply(params, pos, training=True)) + grad_pos = jax.grad(q_sum_over_batch_pos)(pos_batch) # (B,N,3) + grad_sq = (grad_pos**2) / masses + grad_per_sample = np.sum(grad_sq, axis=(1, 2)) # (B,) + grad_loss = np.sum(np.where(is_1, weights * grad_per_sample, 0.0)) / num_1 + + # lipschitz product of softplus(ci) + lip = lipschitz_loss_from_params(params) + + total = 1e4 * (gradient_weight * grad_loss + boundary_weight * boundary + lipschitz_weight * lip) + return total, (1e4 * gradient_weight * grad_loss, + 1e4 * boundary_weight * boundary, + 1e4 * lipschitz_weight * lip) + + return loss_fn + + +# ------------------------------ +# Train state & steps +# ------------------------------ +class TrainState(train_state.TrainState): + pass + + +def create_train_state(rng, model, learning_rate: float, pos_shape: Tuple[int, int, int]): + params = model.init(rng, np.zeros(pos_shape, onp.float32), training=True) + tx = optax.chain( + optax.clip_by_global_norm(1.0), + optax.adamw(learning_rate=learning_rate, weight_decay=1e-5), + ) + return TrainState.create(apply_fn=model.apply, params=params, tx=tx) + + +def make_train_step(loss_fn): + @jax.jit + def train_step(state: TrainState, pos_b, labels_b, weights_b): + (loss, parts), grads = jax.value_and_grad(loss_fn, has_aux=True)( + state.params, pos_b, labels_b, weights_b + ) + state = state.apply_gradients(grads=grads) + return state, loss, parts + return train_step + + +def make_eval_step(loss_fn): + @jax.jit + def eval_step(params, pos_b, labels_b, weights_b): + loss, parts = loss_fn(params, pos_b, labels_b, weights_b) + return loss, parts + return eval_step + + +# ------------------------------ +# Data utils +# ------------------------------ +def make_batches(pos: onp.ndarray, labels: onp.ndarray, weights: onp.ndarray, + batch_size: int, shuffle: bool = True, + rng: onp.random.Generator | None = None) -> Iterable[Tuple[onp.ndarray, onp.ndarray, onp.ndarray]]: + N = pos.shape[0] + idx = onp.arange(N) + if shuffle: + (rng or onp.random.default_rng()).shuffle(idx) + n_full = N // batch_size # drop_last + idx = idx[: n_full * batch_size].reshape(n_full, batch_size) + for row in idx: + yield pos[row], labels[row], weights[row] + + +# ------------------------------ +# Training loop +# ------------------------------ + +def train(model: nn.Module, + masses: onp.ndarray, + train_data: Tuple[onp.ndarray, onp.ndarray, onp.ndarray], + val_data: Tuple[onp.ndarray, onp.ndarray, onp.ndarray], + batch_size: int = 1024, + num_epochs: int = 100, + lr: float = 1e-3, + seed: int | None = 61982): + rng = jax.random.PRNGKey(seed if seed is not None else 0) + pos_train, y_train, w_train = train_data + pos_val, y_val, w_val = val_data + + _, N, D = pos_train.shape + state = create_train_state(rng, model, lr, (batch_size, N, D)) + + loss_fn = make_loss_fn(model, np.asarray(masses), + boundary_weight=100.0, + lipschitz_weight=3e-3, + gradient_weight=1.0) + train_step = make_train_step(loss_fn) + eval_step = make_eval_step(loss_fn) + + for epoch in range(num_epochs): + tr_losses = []; tr_grad = []; tr_bound = []; tr_lip = [] + for xb, yb, wb in make_batches(pos_train, y_train, w_train, batch_size, shuffle=True): + xb = np.asarray(xb, dtype=onp.float32) + yb = np.asarray(yb, dtype=onp.int32) + wb = np.asarray(wb, dtype=onp.float32) + state, loss, parts = train_step(state, xb, yb, wb) + tr_losses.append(float(loss)); g,b,l = parts; tr_grad.append(float(g)); tr_bound.append(float(b)); tr_lip.append(float(l)) + + va_losses = []; va_grad = []; va_bound = []; va_lip = [] + for xb, yb, wb in make_batches(pos_val, y_val, w_val, batch_size, shuffle=False): + xb = np.asarray(xb, dtype=onp.float32) + yb = np.asarray(yb, dtype=onp.int32) + wb = np.asarray(wb, dtype=onp.float32) + loss, parts = eval_step(state.params, xb, yb, wb) + va_losses.append(float(loss)); g,b,l = parts; va_grad.append(float(g)); va_bound.append(float(b)); va_lip.append(float(l)) + + print(f"Epoch {epoch}") + if tr_losses: + print(" avg train loss:", onp.mean(tr_losses)) + print(" avg train grad loss:", onp.mean(tr_grad)) + print(" avg train bound loss:", onp.mean(tr_bound)) + print(" avg train lipschitz loss:", onp.mean(tr_lip)) + if va_losses: + print(" avg val loss:", onp.mean(va_losses)) + print(" avg val grad loss:", onp.mean(va_grad)) + print(" avg val bound loss:", onp.mean(va_bound)) + print(" avg val lipschitz loss:", onp.mean(va_lip)) + print() + + return state + + +# ------------------------------ +# Inference helper +# ------------------------------ +def make_forward_eval(model): + @jax.jit + def _forward_eval(params, pos_b): + return model.apply(params, pos_b, training=False) + return _forward_eval + + +# ------------------------------ +# Main: mirrors your IO (MDAnalysis/pysages) +# ------------------------------ +if __name__ == "__main__": + import MDAnalysis as mda + import pysages + from ase.io import read as ase_read + import matplotlib.pyplot as plt + + # ----- Load CV and weights ----- + xi = onp.loadtxt("/scratch/aq2212/CV.log")[:, 1][::10] + run_result = pysages.load("/scratch/aq2212/production_abf2_equil15/restart.pkl") + fe_result = pysages.analyze(run_result) + biases = fe_result['fes_fn'](xi) + biases /= (8.6173e-5 * 600) + biases = biases[:, 0] + weightsSim = onp.exp(-biases).astype(onp.float32) + + # ----- Extract downsampled sim frames ----- + u = mda.Universe("sampled.xyz") + posSim = [u.atoms.positions.copy() for _ in u.trajectory] + posSim = onp.asarray(posSim, dtype=onp.float32) + + # ----- Load endpoints ----- + def load_xyz(path): + uu = mda.Universe(path) + return onp.asarray([uu.atoms.positions.copy() for _ in uu.trajectory], dtype=onp.float32) + + posA = load_xyz("/scratch/projects/depablolab/acqin2/lipschitz/visnet/model_for_sim1/A.xyz") + posB = load_xyz("/scratch/projects/depablolab/acqin2/lipschitz/visnet/model_for_sim1/B.xyz") + posC = load_xyz("/scratch/projects/depablolab/acqin2/lipschitz/visnet/model_for_sim1/C.xyz") + + labelsA = onp.zeros((len(posA),), dtype=onp.int32) + labelsB = onp.full((len(posB),), 2, dtype=onp.int32) + labelsC = onp.full((len(posC),), 1, dtype=onp.int32) + labelsSim = onp.full((len(posSim),), 1, dtype=onp.int32) + + weightsA = onp.ones((len(posA),), dtype=onp.float32) + weightsB = onp.ones((len(posB),), dtype=onp.float32) + weightsC = onp.full((len(posC),), 1e-3, dtype=onp.float32) + + # masses shape (1, N, 1) + atoms0 = ase_read('/scratch/projects/depablolab/acqin2/lipschitz/visnet/model_for_sim1/A.xyz', index=0) + masses = onp.asarray(atoms0.get_masses(), dtype=onp.float32)[None, :, None] + + # ----- Train/val split (A+B+Sim) + pos_all = onp.concatenate([posA, posB, posSim], axis=0) + lab_all = onp.concatenate([labelsA, labelsB, labelsSim], axis=0) + w_all = onp.concatenate([weightsA, weightsB, weightsSim], axis=0) + + rng_np = onp.random.default_rng(61982) + idx = onp.arange(len(pos_all)); rng_np.shuffle(idx) + split = int(0.7 * len(idx)) + tr_idx, va_idx = idx[:split], idx[split:] + + train_data = (pos_all[tr_idx], lab_all[tr_idx], w_all[tr_idx]) + val_data = (pos_all[va_idx], lab_all[va_idx], w_all[va_idx]) + + # ----- Model config (indices / triplets) + indices = np.arange(50) + tri_idx1 = np.asarray([44, 47, 47, 6, 5, 47, 7, 5]) + tri_idx2 = np.asarray([46, 46, 49, 49, 44, 49, 46, 7]) + + model = CommittorNN_Dist_Lip(indices=indices, + tri_idx1=tri_idx1, + tri_idx2=tri_idx2, + h1=16, h2=16, h3=8, out_dim=1, sig_k=3.0) + + state = train(model, + masses=masses, + train_data=train_data, + val_data=val_data, + batch_size=10000, + num_epochs=200, # consider reducing for quick tests + lr=1e-3, + seed=61982) + + # ----- Save params ----- + import flax.serialization as serialization + param_bytes = serialization.to_bytes(state.params) + with open("distance_flax.params", "wb") as f: + f.write(param_bytes) + + forward_eval = make_forward_eval(model) + + # ----- Evaluate on A/B/C and plot ----- + def batched_preds(pos_arr, bs=2048): + outs = [] + for xb, _, _ in make_batches(pos_arr, + onp.zeros(len(pos_arr), onp.int32), + onp.ones(len(pos_arr), onp.float32), + bs, shuffle=False): + xb = np.asarray(xb, dtype=onp.float32) + y = forward_eval(state.params, xb) + outs.append(onp.asarray(y)) + return onp.concatenate(outs, axis=0) if outs else onp.zeros((0,), dtype=onp.float32) + + As = batched_preds(posA) + Bs = batched_preds(posB) + Cs = batched_preds(posC) + + import matplotlib + plt.figure() + plt.hist([As, Bs, Cs], bins=100, label=['Reactant', 'Product', 'Intermediate'], + alpha=0.3, histtype='stepfilled', density=True) + plt.legend(); plt.xlabel('Output value'); plt.ylabel('Probability Density') + plt.title('Histogram of outputs by class') + plt.savefig('fig_jax.png', dpi=200) + + print("Saved: distance_flax.params, fig_jax.png") diff --git a/pysages/methods/__init__.py b/pysages/methods/__init__.py index 875baa91..736bc65c 100644 --- a/pysages/methods/__init__.py +++ b/pysages/methods/__init__.py @@ -71,12 +71,17 @@ from .restraints import CVRestraints from .sirens import Sirens from .spectral_abf import SpectralABF +from .funnel_sabf import Funnel_SpectralABF from .spline_string import SplineString from .umbrella_integration import UmbrellaIntegration from .unbiased import Unbiased +from .funnel_static import Static +from .funnel_metadabf import Funnel_MetaDABF +from .funnel_unbiased import Unbias from .utils import ( HistogramLogger, MetaDLogger, + Funnel_Logger, ReplicasConfiguration, SerialExecutor, methods_dispatch, diff --git a/pysages/methods/analysis.py b/pysages/methods/analysis.py index 609a25fb..906bfcba 100644 --- a/pysages/methods/analysis.py +++ b/pysages/methods/analysis.py @@ -173,3 +173,195 @@ def average_forces(hist, Fsum): "fes_fn": first_or_all(fes_fns), "mesh": transpose(mesh).reshape(-1, d).squeeze(), } + + +@dispatch +def _funnelanalyze(result: Result, strategy: GradientLearning, topology): + """ + Computes the free energy from the result of an `FunnelABF`-based run. + Integrates the forces via a gradient learning strategy. + + Parameters + ---------- + + result: Result: + Result bundle containing method, final FunnelABF-like state, and callback. + + strategy: GradientLearning + + topology: Tuple[int] + Defines the architecture of the neural network + (number of nodes in each hidden layer). + + Returns + ------- + + dict: A dictionary with the following keys: + + histogram: JaxArray + Histogram for the states visited during the method. + + mean_force: JaxArray + Average force at each bin of the CV grid. + + free_energy: JaxArray + Free Energy at each bin of the CV grid. + + mesh: JaxArray + Grid used in the method. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the + CV domain defined by the grid. + + NOTE: + For multiple-replicas runs we return a list (one item per-replica) + for each attribute. + """ + + # The ForceNN based analysis occurs in two stages: + # + # 1. The data is smoothed and a first quick fitting is performed to obtain + # an approximate set of network parameters. + # 2. A second training pass is then performed over the raw data starting + # with the parameters from previous step. + + method = result.method + states = result.states + grid = method.grid + mesh = inputs = (compute_mesh(grid) + 1) * grid.size / 2 + grid.lower + + model = MLP(grid.shape.size, 1, topology, transform=partial(_scale, grid=grid)) + loss = GradientsSSE() + regularizer = L2Regularization(1e-4) + + # Stage 1 optimizer + pre_optimizer = LevenbergMarquardt(loss=loss, max_iters=250, reg=regularizer) + pre_fit = build_fitting_function(model, pre_optimizer) + + # Stage 2 optimizer + optimizer = LevenbergMarquardt(loss=loss, max_iters=1000, reg=regularizer) + fit = build_fitting_function(model, optimizer) + + @vmap + def smooth(data, conv_dtype=np.float32): + data_dtype = data.dtype + boundary = "wrap" if grid.is_periodic else "edge" + kernel = np.asarray(blackman_kernel(grid.shape.size, 7), dtype=conv_dtype) + data = np.asarray(data, dtype=conv_dtype) + return np.asarray(convolve(data.T, kernel, boundary=boundary), dtype=data_dtype).T + + @jit + def pre_train(nn, data): + params = pre_fit(nn.params, inputs, smooth(data)).params + return NNData(params, nn.mean, nn.std) + + @jit + def train(nn, data): + params = fit(nn.params, inputs, data).params + return NNData(params, nn.mean, nn.std) + + def build_fes_fn(state): + hist = np.expand_dims(state.hist, state.hist.ndim) + F = state.Fsum / np.maximum(hist, 1) + + # Scale the mean forces + s = np.abs(F).max() + F = F / s + + ps, layout = unpack(model.parameters) + nn = pre_train(NNData(ps, 0.0, s), F) + nn = train(nn, F) + + def fes_fn(x): + params = pack(nn.params, layout) + A = nn.std * model.apply(params, x) + nn.mean + return A.max() - A + + return jit(fes_fn) + + def average_forces(hist, Fsum): + shape = (*Fsum.shape[:-1], 1) + return Fsum / np.maximum(hist.reshape(shape), 1) + + def build_corr_fn(state): + hist = np.expand_dims(state.hist, state.hist.ndim) + F = (state.Fsum + state.Frestr) / np.maximum(hist, 1) + + # Scale the mean forces + s = np.abs(F).max() + F = F / s + + ps, layout = unpack(model.parameters) + nn = pre_train(NNData(ps, 0.0, s), F) + nn = train(nn, F) + + def fes_fn(x): + params = pack(nn.params, layout) + A = nn.std * model.apply(params, x) + nn.mean + return A.max() - A + + return jit(fes_fn) + + def build_restr_fn(state): + hist = np.expand_dims(state.hist, state.hist.ndim) + F = state.Frestr / np.maximum(hist, 1) + + # Scale the mean forces + s = np.abs(F).max() + F = F / s + + ps, layout = unpack(model.parameters) + nn = pre_train(NNData(ps, 0.0, s), F) + nn = train(nn, F) + + def fes_fn(x): + params = pack(nn.params, layout) + A = nn.std * model.apply(params, x) + nn.mean + return A.max() - A + + return jit(fes_fn) + + hists = [] + mean_forces = [] + free_energies = [] + fes_fns = [] + corr_forces = [] + corrected_energies = [] + corr_fns = [] + restr_forces = [] + restraint_energies = [] + restr_fns = [] + + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + d = mesh.shape[-1] + + for state in states: + fes_fn = build_fes_fn(state) + hists.append(transpose(state.hist)) + mean_forces.append(transpose(average_forces(state.hist, state.Fsum))) + free_energies.append(transpose(fes_fn(mesh))) + fes_fns.append(fes_fn) + corr_fn = build_corr_fn(state) + corr_forces.append(transpose(average_forces(state.hist, state.Fsum + state.Frestr))) + corrected_energies.append(transpose(corr_fn(mesh))) + corr_fns.append(corr_fn) + restr_fn = build_restr_fn(state) + restr_forces.append(transpose(average_forces(state.hist, state.Frestr))) + restraint_energies.append(transpose(restr_fn(mesh))) + restr_fns.append(restr_fn) + + return { + "histogram": first_or_all(hists), + "mean_force": first_or_all(mean_forces), + "free_energy": first_or_all(free_energies), + "fes_fn": first_or_all(fes_fns), + "mesh": transpose(mesh).reshape(-1, d).squeeze(), + "corrected_force": first_or_all(corr_forces), + "corrected_energy": first_or_all(corrected_energies), + "corr_fn": first_or_all(corr_fns), + "restraint_force": first_or_all(restr_forces), + "restraint_energy": first_or_all(restraint_energies), + "restr_fn": first_or_all(restr_fns), + } diff --git a/pysages/methods/funnel_function.py b/pysages/methods/funnel_function.py new file mode 100644 index 00000000..5924501b --- /dev/null +++ b/pysages/methods/funnel_function.py @@ -0,0 +1,44 @@ +# funnel functions +from functools import partial + +import jax.numpy as np +from jax import jit, grad +from jax.numpy import linalg + +def distance(r, cell_size): + diff = r[1:] - r[0] + #diff = diff - np.round(diff / cell_size) * cell_size + return np.linalg.norm(diff,axis=1) + +def coordnum_energy(r, cell_size, c_mins, k, idx = np.asarray([ [0,1],[0,2],[0,3],[1,2],[1,3],[2,3] ])): + total = 0 #all O somewhat close to C + c_min = c_mins[0] + dist = distance(r[:5],cell_size) + total += 0.5*(np.where(dist > c_min, dist - c_min, 0.0)**2).sum() + + c_min = c_mins[1] #Os are all close to each other + rO = r[1:5] + dists = np.linalg.norm( rO[idx[:,0]] - rO[idx[:,1]],axis=1 ) + total += 0.5*(np.where(dists > c_min, dists - c_min, 0.0)**2).sum() + return k * total + + +def intermediate_funnel(pos, ids, indexes, cell_size, c_mins, k): + r = pos[ids[indexes]] + return coordnum_energy(r, cell_size, c_mins, k) + +def log_funnel(): + return 0.0 + +def external_funnel(data, indexes, cell_size, c_mins, k): + pos = data.positions[:, :3] + ids = data.indices + bias = grad(intermediate_funnel)(pos, ids, indexes, cell_size, c_mins, k) + proj = log_funnel() + return bias, proj + +def get_funnel_force(indexes, cell_size, c_mins, k): + funnel_force = partial( + external_funnel, + indexes=indexes, cell_size=cell_size, c_mins=c_mins, k=k) + return jit(funnel_force) diff --git a/pysages/methods/funnel_metadabf.py b/pysages/methods/funnel_metadabf.py new file mode 100644 index 00000000..7e4d0cb9 --- /dev/null +++ b/pysages/methods/funnel_metadabf.py @@ -0,0 +1,477 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Implementation of Standard and Well-tempered MetaD-ABF +both with support for grids. +""" + +from jax import grad, jit +from jax import numpy as np +from jax import value_and_grad, vmap +from jax.lax import cond + +from pysages.approxfun import compute_mesh +from pysages.colvars import get_periods, wrap +from pysages.grids import build_indexer +from pysages.methods.analysis import GradientLearning, _funnelanalyze +from pysages.methods.core import GriddedSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.methods.utils import numpyfy_vals +from pysages.typing import JaxArray, NamedTuple, Optional +from pysages.utils import dispatch, gaussian, identity, linear_solver + + +class FMetaDABFState(NamedTuple): + """ + MetaDABF helper state + + Parameters + ---------- + + xi: JaxArray + Collective variable value in the last simulation step. + + bias: JaxArray + Array of Metadynamics bias forces for each particle in the simulation. + + heights: JaxArray + Height values for all accumulated Gaussians (zeros for not yet added Gaussians). + + centers: JaxArray + Centers of the accumulated Gaussians. + + sigmas: JaxArray + Widths of the accumulated Gaussians. + + grid_potential: Optional[JaxArray] + Array of Metadynamics bias potentials stored on a grid. + + grid_gradient: Optional[JaxArray] + Array of Metadynamics bias gradients evaluated on a grid. + + idx: int + Index of the next Gaussian to be deposited. + + idg: int + Index of the CV in the forces array. + + hist: JaxArray (grid.shape) + Histogram of visits to the bins in the collective variable grid. + + Fsum: JaxArray (grid.shape, CV shape) + Cumulative forces at each bin in the CV grid. + + force: JaxArray (grid.shape, CV shape) + Average force at each bin of the CV grid. + + Wp: JaxArray (CV shape) + Product of W matrix and momenta matrix for the current step. + + Wp_: JaxArray (CV shape) + Product of W matrix and momenta matrix for the previous step. + + ncalls: int + Counts the number of times `method.update` has been called. + + perp: JaxArray + Collective variable perpendicular to the Funnel_CV + """ + + xi: JaxArray + bias: JaxArray + heights: JaxArray + centers: JaxArray + sigmas: JaxArray + grid_potential: Optional[JaxArray] + grid_gradient: Optional[JaxArray] + idx: int + idg: int + Wp: JaxArray + Wp_: JaxArray + force: JaxArray + Fsum: JaxArray + Frestr: JaxArray + hist: JaxArray + ncalls: int + perp: JaxArray + + def __repr__(self): + return repr("PySAGES" + type(self).__name__) + + +class PartialMetadynamicsState(NamedTuple): + """ + Helper intermediate Metadynamics state + """ + + xi: JaxArray + heights: JaxArray + centers: JaxArray + sigmas: JaxArray + grid_potential: Optional[JaxArray] + grid_gradient: Optional[JaxArray] + idx: int + grid_idx: Optional[JaxArray] + + +class Funnel_MetaDABF(GriddedSamplingMethod): + """ + Implementation of Standard and Well-tempered Funnel MetaDABF as described in + arXiv preprint arXiv:2504.13575. + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, height, sigma, stride, ngaussians, deltaT=None, **kwargs): + """ + Parameters + ---------- + + cvs: + Set of user selected collective variable. + + height: + Initial height of the deposited Gaussians. + + sigma: + Initial standard deviation of the to-be-deposit Gaussians. + + stride: int + Bias potential deposition frequency. + + ngaussians: int + Total number of expected Gaussians (`timesteps // stride + 1`). + + deltaT: Optional[float] = None + Well-tempered Metadynamics :math:`\\Delta T` parameter + (if `None` standard Metadynamics is used). + + grid: Optional[Grid] = None + If provided, it will be used to accelerate the computation by + approximating the bias potential and its gradient over its centers. + + kB: Optional[float] + Boltzmann constant. Must be provided for well-tempered Metadynamics + simulations and should match the internal units of the backend. + + restraints: Optional[CVRestraints] = None + If provided, it will be used to restraint CV space inside the grid. + + external_force: + External restraint to be used for funnel calculations. + """ + + if deltaT is not None and "kB" not in kwargs: + raise KeyError( + "For well-tempered Metadynamics a keyword argument `kB` for " + "the value of the Boltzmann constant (that matches the " + "internal units of the backend) must be provided." + ) + + kwargs["grid"] = kwargs.get("grid", None) + kwargs["restraints"] = kwargs.get("restraints", None) + super().__init__(cvs, **kwargs) + + self.height = height + self.sigma = sigma + self.stride = stride + self.ngaussians = ngaussians # NOTE: infer from timesteps and stride + self.deltaT = deltaT + self.kB = kwargs.get("kB", None) + self.use_pinv = self.kwargs.get("use_pinv", False) + + def build(self, snapshot, helpers, *args, **kwargs): + self.external_force = self.kwargs.get("external_force", None) + return _metadynamics(self, snapshot, helpers) + + +def _metadynamics(method, snapshot, helpers): + # Initialization and update of biasing forces. Interface expected for methods. + cv = method.cv + stride = method.stride + dims = method.grid.shape.size + dt = snapshot.dt + ngaussians = method.ngaussians + get_grid_index = build_indexer(method.grid) + external_force = method.external_force + tsolve = linear_solver(method.use_pinv) + natoms = np.size(snapshot.positions, 0) + + deposit_gaussian = build_gaussian_accumulator(method) + evaluate_bias_grad = build_bias_grad_evaluator(method) + + def initialize(): + xi, _ = cv(helpers.query(snapshot)) + bias = np.zeros((natoms, helpers.dimensionality())) + perp = 0.0 + # NOTE: for restart; use hills file to initialize corresponding arrays. + heights = np.zeros(ngaussians, dtype=np.float64) + centers = np.zeros((ngaussians, xi.size), dtype=np.float64) + sigmas = np.array(method.sigma, dtype=np.float64, ndmin=2) + hist = np.zeros(method.grid.shape, dtype=np.uint32) + Fsum = np.zeros((*method.grid.shape, dims)) + force = np.zeros(dims) + Wp = np.zeros(dims) + Wp_ = np.zeros(dims) + restr = np.zeros(dims) + Frestr = np.zeros((*method.grid.shape, dims)) + # Arrays to store forces and bias potential on a grid. + if method.grid is None: + grid_potential = grid_gradient = None + else: + shape = method.grid.shape + grid_potential = np.zeros((*shape,), dtype=np.float64) if method.deltaT else None + grid_gradient = np.zeros((*shape, shape.size), dtype=np.float64) + + return FMetaDABFState( + xi, + bias, + heights, + centers, + sigmas, + grid_potential, + grid_gradient, + 0, + 0, + Wp, + Wp_, + force, + Fsum, + Frestr, + hist, + 0, + perp, + ) + + def update(state, data): + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + + # Deposit Gaussian depending on the stride + ncalls = state.ncalls + 1 + in_deposition_step = (ncalls > 1) & (ncalls % stride == 1) + partial_state = deposit_gaussian(xi, state, in_deposition_step) + + # Evaluate gradient of biasing potential (or generalized force) + force = evaluate_bias_grad(partial_state) + + # Calculate biasing forces + bias = -Jxi.T @ force.flatten() + eforce, perp = external_force(data) + bias = bias.reshape(state.bias.shape) - eforce.reshape(state.bias.shape) + p = data.momenta + Wp = tsolve(Jxi, p) + # Second order backward finite difference + dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt + # + I_xi = get_grid_index(xi) + hist = state.hist.at[I_xi].add(1) + Fsum = state.Fsum.at[I_xi].add(dWp_dt + state.force) + restr = tsolve(Jxi, eforce.reshape(p.shape)) + # + Frestr = state.Frestr.at[I_xi].add(restr) + + return FMetaDABFState( + xi, + bias, + *partial_state[1:-1], + I_xi, + Wp, + state.Wp, + force, + Fsum, + Frestr, + hist, + ncalls, + perp + ) + + return snapshot, initialize, generalize(update, helpers, jit_compile=True) + + +def build_gaussian_accumulator(method: Funnel_MetaDABF): + """ + Returns a function that given a `FMetaDABFState`, and the value of the CV, + stores the next Gaussian that is added to the biasing potential. + """ + periods = get_periods(method.cvs) + height_0 = method.height + deltaT = method.deltaT + grid = method.grid + kB = method.kB + + if deltaT is None: + next_height = jit(lambda *args: height_0) + else: # if well-tempered + if grid is None: + evaluate_potential = jit(lambda pstate: sum_of_gaussians(*pstate[:4], periods)) + else: + evaluate_potential = jit(lambda pstate: pstate.grid_potential[pstate.grid_idx]) + + def next_height(pstate): + V = evaluate_potential(pstate) + return height_0 * np.exp(-V / (deltaT * kB)) + + if grid is None: + get_grid_index = jit(lambda arg: None) + update_grids = jit(lambda *args: (None, None)) + should_deposit = jit(lambda pred, _: pred) + else: + grid_mesh = compute_mesh(grid) + get_grid_index = build_indexer(grid) + # Reshape so the dimensions are compatible + accum = jit(lambda total, val: total + val.reshape(total.shape)) + + if deltaT is None: + transform = grad + pack = jit(lambda x: (x,)) + # No need to accumulate values for the potential (V is None) + update = jit(lambda V, dV, vals: (V, accum(dV, vals))) + else: + transform = value_and_grad + pack = identity + update = jit(lambda V, dV, vals, grads: (accum(V, vals), accum(dV, grads))) + + def update_grids(pstate, height, xi, sigma): + # We use `sum_of_gaussians` since it already takes care of the wrapping + current_gaussian = jit(lambda x: sum_of_gaussians(x, height, xi, sigma, periods)) + # Evaluate gradient of bias (and bias potential for WT version) + grid_values = pack(vmap(transform(current_gaussian))(grid_mesh)) + return update(pstate.grid_potential, pstate.grid_gradient, *grid_values) + + def should_deposit(in_deposition_step, I_xi): + in_bounds = ~(np.any(np.array(I_xi) == grid.shape)) + return in_deposition_step & in_bounds + + def deposit_gaussian(pstate): + xi, idx = pstate.xi, pstate.idx + current_height = next_height(pstate) + heights = pstate.heights.at[idx].set(current_height) + centers = pstate.centers.at[idx].set(xi.flatten()) + sigmas = pstate.sigmas + grid_potential, grid_gradient = update_grids(pstate, current_height, xi, sigmas) + return PartialMetadynamicsState( + xi, heights, centers, sigmas, grid_potential, grid_gradient, idx + 1, pstate.grid_idx + ) + + def _deposit_gaussian(xi, state, in_deposition_step): + I_xi = get_grid_index(xi) + pstate = PartialMetadynamicsState(xi, *state[2:-9], I_xi) + predicate = should_deposit(in_deposition_step, I_xi) + return cond(predicate, deposit_gaussian, identity, pstate) + + return _deposit_gaussian + + +def build_bias_grad_evaluator(method: Funnel_MetaDABF): + """ + Returns a function that given the deposited Gaussians parameters, computes the + gradient of the biasing potential with respect to the CVs. + """ + grid = method.grid + restraints = method.restraints + if grid is None: + periods = get_periods(method.cvs) + evaluate_bias_grad = jit(lambda pstate: grad(sum_of_gaussians)(*pstate[:4], periods)) + else: + if restraints: + + def ob_force(pstate): # out-of-bounds force + lo, hi, kl, kh = restraints + xi, *_ = pstate + xi = pstate.xi.reshape(grid.shape.size) + force = apply_restraints(lo, hi, kl, kh, xi) + return force + + else: + + def ob_force(pstate): # out-of-bounds force + return np.zeros(grid.shape.size) + + def get_force(pstate): + return pstate.grid_gradient[pstate.grid_idx] + + def evaluate_bias_grad(pstate): + ob = np.any(np.array(pstate.grid_idx) == grid.shape) # out of bounds + return cond(ob, ob_force, get_force, pstate) + + return evaluate_bias_grad + + +# Helper function to evaluate bias potential -- may be moved to analysis part +def sum_of_gaussians(xi, heights, centers, sigmas, periods): + """ + Sum of n-dimensional Gaussians potential. + """ + delta_x = wrap(xi - centers, periods) + return gaussian(heights, sigmas, delta_x).sum() + + +@dispatch +def analyze(result: Result[Funnel_MetaDABF], **kwargs): + """ + Computes the free energy from the result of an `Funnel_MetaDABF` run. + Integrates the forces via a gradient learning strategy. + + Parameters + ---------- + + result: Result[Funnel_ABF]: + Result bundle containing method, final ABF state, and callback. + + topology: Optional[Tuple[int]] = (8, 8) + Defines the architecture of the neural network + (number of nodes in each hidden layer). + + Returns + ------- + + dict: + A dictionary with the following keys: + + histogram: JaxArray + Histogram for the states visited during the method. + + mean_force: JaxArray + Average force at each bin of the CV grid. + + free_energy: JaxArray + Free Energy at each bin of the CV grid. + + mesh: JaxArray + Grid used in the method. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the + CV domain defined by the grid. + + corrected_force: JaxArray + Average mean force without restraint at each bin of the CV grid. + + corrected_energy: JaxArray + Free Energy without restraint at each bin of the CV grid. + + corr_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy without restraints + in the CV domain defined by the grid. + + restraint_force: JaxArray + Average mean force of the restraints at each bin of the CV grid. + + restraint_energy: JaxArray + Free Energy of the restraints at each bin of the CV grid. + + restr_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy of the restraints + in the CV domain defined by the grid. + + + + NOTE: + For multiple-replicas runs we return a list (one item per-replica) + for each attribute. + """ + topology = kwargs.get("topology", (8, 8)) + _result = _funnelanalyze(result, GradientLearning(), topology) + return numpyfy_vals(_result) diff --git a/pysages/methods/funnel_sabf.py b/pysages/methods/funnel_sabf.py new file mode 100644 index 00000000..269664ae --- /dev/null +++ b/pysages/methods/funnel_sabf.py @@ -0,0 +1,440 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Funnel Spectral Adaptive Biasing Force Sampling Method. + +""" + +from jax import jit +from jax import numpy as np +from jax.lax import cond + +from pysages.approxfun import ( + Fun, + SpectralGradientFit, + build_evaluator, + build_fitter, + build_grad_evaluator, + compute_mesh, +) +from pysages.grids import Chebyshev, Grid, build_indexer, convert, grid_transposer +from pysages.methods.core import GriddedSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.methods.utils import numpyfy_vals +from pysages.typing import JaxArray, NamedTuple, Tuple +from pysages.utils import dispatch, first_or_all, linear_solver + + +class SpectralFABFState(NamedTuple): + """ + Funnel_SpectralABF internal state. + + Parameters + ---------- + + xi: JaxArray [cvs.shape] + Last collective variable recorded in the simulation. + + bias: JaxArray [natoms, 3] + Array with biasing forces for each particle. + + hist: JaxArray [grid.shape] + Histogram of visits to the bins in the collective variable grid. + + Fsum: JaxArray [grid.shape, cvs.shape] + The cumulative force recorded at each bin of the CV grid. + + force: JaxArray [grid.shape, cvs.shape] + Average force at each bin of the CV grid. + + Wp: JaxArray [cvs.shape] + Estimate of the product $W p$ where `p` is the matrix of momenta and + `W` the Moore-Penrose inverse of the Jacobian of the CVs. + + Wp_: JaxArray [cvs.shape] + The value of `Wp` for the previous integration step. + + fun: Fun + Object that holds the coefficients of the basis functions + approximation to the free energy. + + restr: JaxArray [grid.shape, cvs.shape] + Instantaneous restraint force at each bin of the CV grid. + + proj: JaxArray [cvs.shape] + Last collective variable from restraints recorded in the simulation. + + Frestr: JaxArray [grid.shape, cvs.shape] + The cumulative restraint force recorded at each bin of the CV grid. + + ncalls: int + Counts the number of times the method's update has been called. + """ + + xi: JaxArray + bias: JaxArray + hist: JaxArray + Fsum: JaxArray + force: JaxArray + Wp: JaxArray + Wp_: JaxArray + fun: Fun + restr: JaxArray + proj: JaxArray + Frestr: JaxArray + ncalls: int + + def __repr__(self): + return repr("PySAGES " + type(self).__name__) + + +class PartialSpectralFABFState(NamedTuple): + xi: JaxArray + hist: JaxArray + Fsum: JaxArray + ind: Tuple + fun: Fun + pred: bool + + +class Funnel_SpectralABF(GriddedSamplingMethod): + """ + Implementation of the Funnel_Spectral ABF method described in + + Parameters + ---------- + + cvs: Union[List, Tuple] + Set of user selected collective variable. + + grid: Grid + Specifies the collective variables domain and number of bins for discretizing + the CV space along each CV dimension. For non-periodic grids this will be + converted to a Chebyshev-distributed grid. + + N: Optional[int] = 500 + Threshold parameter before accounting for the full average + of the adaptive biasing force. + + fit_freq: Optional[int] = 100 + Fitting frequency. + + fit_threshold: Optional[int] = 500 + Number of time steps after which fitting starts to take place. + + restraints: Optional[CVRestraints] = None + If provided, indicate that harmonic restraints will be applied when any + collective variable lies outside the box from `restraints.lower` to + `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. + + ext_force: Optional[ext_force] = None + If provided, indicate the geometric restraints will be applied for any + collective variables. + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, grid, **kwargs): + super().__init__(cvs, grid, **kwargs) + self.N = np.asarray(self.kwargs.get("N", 500)) + self.fit_freq = self.kwargs.get("fit_freq", 100) + self.fit_threshold = self.kwargs.get("fit_threshold", 500) + self.fit_threshold_upper = self.kwargs.get("fit_threshold_upper", 1e7) + self.grid = self.grid if self.grid.is_periodic else convert(self.grid, Grid[Chebyshev]) + self.model = SpectralGradientFit(self.grid) + self.use_pinv = self.kwargs.get("use_pinv", False) + + def build(self, snapshot, helpers, *_args, **_kwargs): + """ + Returns the `initialize` and `update` functions for the sampling method. + """ + self.ext_force = self.kwargs.get("ext_force", None) + return _spectral_abf(self, snapshot, helpers) + + +def _spectral_abf(method, snapshot, helpers): + cv = method.cv + grid = method.grid + fit_freq = method.fit_freq + fit_threshold = method.fit_threshold + fit_threshold_upper = method.fit_threshold_upper + + dt = snapshot.dt + dims = grid.shape.size + natoms = np.size(snapshot.positions, 0) + + # Helper methods + tsolve = linear_solver(method.use_pinv) + get_grid_index = build_indexer(grid) + fit = build_fitter(method.model) + fit_forces = build_free_energy_fitter(method, fit) + estimate_force = build_force_estimator(method) + ext_force = method.ext_force + + query, dimensionality, to_force_units = helpers + + def initialize(): + xi, _ = cv(query(snapshot)) + bias = np.zeros((natoms, dimensionality())) + hist = np.zeros(grid.shape, dtype=np.uint32) + Fsum = np.zeros((*grid.shape, dims)) + force = np.zeros(dims) + Wp = np.zeros(dims) + Wp_ = np.zeros(dims) + fun = fit(Fsum) + restr = np.zeros(dims) + proj = 0.0 + Frestr = np.zeros((*grid.shape, dims)) + return SpectralFABFState(xi, bias, hist, Fsum, force, Wp, Wp_, fun, restr, proj, Frestr, 0) + + def update(state, data): + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + # Restraint force and logger + e_f, proj = ext_force(data) + # During the intial stage use ABF + I_xi = get_grid_index(xi) + ncalls = state.ncalls + 1 + in_fitting_regime = state.hist[I_xi] > fit_threshold + in_fitting_step = in_fitting_regime & (ncalls % fit_freq == 1) & (ncalls < fit_threshold_upper) + # Fit forces + fun = fit_forces(state, in_fitting_step) + # + p = data.momenta + Wp = tsolve(Jxi, p) + # Second order backward finite difference + dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt + # + hist = state.hist.at[I_xi].add(1) + Fsum = state.Fsum.at[I_xi].add(to_force_units(dWp_dt) + state.force) + # + force = estimate_force( + PartialSpectralFABFState(xi, hist, Fsum, I_xi, fun, in_fitting_regime) + ) + bias = np.reshape(-Jxi.T @ force, state.bias.shape) + np.reshape(-e_f, state.bias.shape) + # Restraint contribution to force + restr = tsolve(Jxi, e_f.reshape(p.shape)) + Frestr = state.Frestr.at[I_xi].add(restr) + return SpectralFABFState( + xi, bias, hist, Fsum, force, Wp, state.Wp, fun, restr, proj, Frestr, ncalls + ) + + return snapshot, initialize, generalize(update, helpers) + + +def build_free_energy_fitter(_method: Funnel_SpectralABF, fit): + """ + Returns a function that given a `SpectralFABFState` performs a least squares fit of the + generalized average forces for finding the coefficients of a basis functions expansion + of the free energy. + """ + + def _fit_forces(state): + shape = (*state.Fsum.shape[:-1], 1) + force = state.Fsum / np.maximum(state.hist.reshape(shape), 1) + return fit(force) + + def skip_fitting(state): + return state.fun + + def fit_forces(state, in_fitting_step): + return cond(in_fitting_step, _fit_forces, skip_fitting, state) + + return fit_forces + + +@dispatch +def build_force_estimator(method: Funnel_SpectralABF): + """ + Returns a function that given the coefficients of basis functions expansion and a CV + value, evaluates the function approximation to the gradient of the free energy. + """ + N = method.N + grid = method.grid + dims = grid.shape.size + model = method.model + get_grad = build_grad_evaluator(model) + + def average_force(state): + i = state.ind + return 0*state.Fsum[i] #state.Fsum[i] / np.maximum(state.hist[i], N) + + def interpolate_force(state): + return get_grad(state.fun, state.xi).reshape(grid.shape.size) + + def _estimate_force(state): + return cond(state.pred, interpolate_force, average_force, state) + + if method.restraints is None: + ob_force = jit(lambda state: np.zeros(dims)) + else: + lo, hi, kl, kh = method.restraints + + def ob_force(state): + xi = state.xi.reshape(grid.shape.size) + return apply_restraints(lo, hi, kl, kh, xi) + + def estimate_force(state): + ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition + return cond(ob, ob_force, _estimate_force, state) + + return estimate_force + + +@dispatch +def analyze(result: Result[Funnel_SpectralABF]): + """ + Parameters + ---------- + + result: Result[FunnelSpectralABF] + Result bundle containing the method, final states, and callbacks. + + dict: + A dictionary with the following keys: + + histogram: JaxArray + A histogram of the visits to each bin in the CV grid. + + mean_force: JaxArray + Generalized mean forces at each bin in the CV grid. + + free_energy: JaxArray + Free energy at each bin in the CV grid. + + mesh: JaxArray + These are the values of the CVs that are used as inputs for training. + + fun: Fun + Coefficients of the basis functions expansion approximating the free energy. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the CV domain defined + by the grid. + + fun: Fun + Coefficients of the basis functions expansion approximating the free energy. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the CV domain defined + by the grid. + + fun_corr: Fun + Coefficients of the basis functions expansion approximating the free energy + defined without external restraints. + + corr_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy without external restraints + in the CV domain defined by the grid. + fun_rstr: Fun + Coefficients of the basis functions expansion approximating the free energy + of the external restraints. + + rstr_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy of theexternal restraints + in the CV domain defined by the grid. + + corrected_force: JaxArray + Generalized mean forces without restraints at each bin in the CV grid. + + corrected_energy: JaxArray + Free energy without restraints at each bin in the CV grid. + + restraint_force: JaxArray + Generalized mean forces of the restraints at each bin in the CV grid. + + restraint_energy: JaxArray + Free energy of restraints at each bin in the CV grid. + + + NOTE: + For multiple-replicas runs we return a list (one item per-replica) for each attribute. + """ + method = result.method + fit = build_fitter(method.model) + grid = method.grid + mesh = compute_mesh(grid) + evaluate = build_evaluator(method.model) + + def average_forces(hist, Fsum): + hist = np.expand_dims(hist, hist.ndim) + return Fsum / np.maximum(hist, 1) + + def fit_corr(state): + shape = (*state.Fsum.shape[:-1], 1) + force = (state.Fsum + state.Frestr) / np.maximum(state.hist.reshape(shape), 1) + return fit(force) + + def fit_restr(state): + shape = (*state.Fsum.shape[:-1], 1) + force = (state.Frestr) / np.maximum(state.hist.reshape(shape), 1) + return fit(force) + + def build_fes_fn(fun): + def fes_fn(x): + A = evaluate(fun, x) + return A.max() - A + + return jit(fes_fn) + + hists = [] + mean_forces = [] + free_energies = [] + funs = [] + fes_fns = [] + forces_corrected = [] + corrected_energies = [] + funs_corr = [] + fes_corr = [] + restraint_forces = [] + restraint_energies = [] + funs_rstr = [] + fes_rstr = [] + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + d = mesh.shape[-1] + + for s in result.states: + fun_corr = fit_corr(s) + fun_rstr = fit_restr(s) + fes_fn = build_fes_fn(s.fun) + corr_fn = build_fes_fn(fun_corr) + restr_fn = build_fes_fn(fun_rstr) + hists.append(transpose(s.hist)) + mean_forces.append(transpose(average_forces(s.hist, s.Fsum))) + free_energies.append(transpose(fes_fn(mesh))) + funs.append(s.fun) + fes_fns.append(fes_fn) + forces_corrected.append(transpose(average_forces(s.hist, s.Fsum + s.Frestr))) + corrected_energies.append(transpose(corr_fn(mesh))) + funs_corr.append(fun_corr) + fes_corr.append(corr_fn) + restraint_forces.append(transpose(average_forces(s.hist, s.Frestr))) + restraint_energies.append(transpose(restr_fn(mesh))) + funs_rstr.append(fun_rstr) + fes_rstr.append(restr_fn) + + ana_result = { + "histogram": first_or_all(hists), + "mean_force": first_or_all(mean_forces), + "free_energy": first_or_all(free_energies), + "mesh": transpose(mesh).reshape(-1, d).squeeze(), + "fun": first_or_all(funs), + "fes_fn": first_or_all(fes_fns), + "corrected_force": first_or_all(forces_corrected), + "corrected_energy": first_or_all(corrected_energies), + "fun_corr": first_or_all(funs_corr), + "corr_fn": first_or_all(fes_corr), + "restraint_force": first_or_all(restraint_forces), + "restraint_energy": first_or_all(restraint_energies), + "fun_rstr": first_or_all(funs_rstr), + "rstr_fn": first_or_all(fes_rstr), + } + + return numpyfy_vals(ana_result) diff --git a/pysages/methods/funnel_static.py b/pysages/methods/funnel_static.py new file mode 100644 index 00000000..330792a7 --- /dev/null +++ b/pysages/methods/funnel_static.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Funnel Adaptive Biasing Force (Funnel_ABF) sampling method. + +ABF partitions the collective variable space into bins determined by a user +provided grid, and keeps a tabulation of the number of visits to each bin +as well as the sum of generalized forces experienced by the system at each +configuration bin. These provide an estimate for the mean generalized force, +which can be integrated to yield the free energy. This update allows to include + external restraint and accurately remove their contribution in the free energy + calculations. + +""" + +from jax import jit +from jax import numpy as np +from jax.lax import cond + +from pysages.grids import build_indexer +from pysages.methods.core import GriddedSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.typing import JaxArray, NamedTuple +import pysages +from pysages.utils import dispatch + +class StaticState(NamedTuple): + """ + Funnel_ABF internal state. + + Parameters + ---------- + + xi: JaxArray (CV shape) + Last collective variable recorded in the simulation. + + bias: JaxArray (Nparticles, d) + Array with biasing forces for each particle. + + """ + + xi: JaxArray + bias: JaxArray + + def __repr__(self): + return repr("PySAGES " + type(self).__name__) + + +class Static(GriddedSamplingMethod): + """ + Funnel with Static. + + Attributes + ---------- + + snapshot_flags: + Indicate the system properties required from a snapshot. + + Parameters + ---------- + + cvs: Union[List, Tuple] + Set of user selected collective variable. + + grid: Grid + Specifies the collective variables domain and number of bins for + discretizing the CV space along each CV dimension. + + restart_file: str + .pkl file from previous pysages simulation. + + restraints: Optional[CVRestraints] = None + If provided, indicate that harmonic restraints will be applied when any + collective variable lies outside the box from `restraints.lower` to + `restraints.upper`. + + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, grid, restart_file, **kwargs): + super().__init__(cvs, grid, **kwargs) + run_result = pysages.load(restart_file) + fe_result = pysages.analyze(run_result) + self.force_array = np.asarray(-fe_result['mean_force']) + + def build(self, snapshot, helpers, *args, **kwargs): + """ + Build the functions for the execution of ABF + + Parameters + ---------- + + snapshot: + PySAGES snapshot of the simulation (backend dependent). + + helpers: + Helper function bundle as generated by + `SamplingMethod.context[0].get_backend().build_helpers`. + + Returns + ------- + + Tuple `(snapshot, initialize, update)` to run ABF simulations. + """ + self.ext_force = self.kwargs.get("ext_force", None) + return _static(self, snapshot, helpers) + + +def _static(method, snapshot, helpers): + """ + Internal function that generates the init and update functions. + + Parameters + ---------- + + method: ABF + Class that generates the functions. + snapshot: + PySAGES snapshot of the simulation (backend dependent). + helpers + Helper function bundle as generated by + `SamplingMethod.context[0].get_backend().build_helpers`. + + Returns + ------- + Tuple `(snapshot, initialize, update)` to run static bias simulations. + """ + cv = method.cv + grid = method.grid + force_array = method.force_array + dims = grid.shape.size + natoms = np.size(snapshot.positions, 0) + get_grid_index = build_indexer(grid) + estimate_force = build_force_estimator(method) + ext_force = method.ext_force + + def initialize(): + """ + Internal function that generates the first FABFState + with correctly shaped JaxArrays. + + Returns + ------- + FABFState + Initialized State + """ + xi, _ = cv(helpers.query(snapshot)) + bias = np.zeros((natoms, helpers.dimensionality())) + return StaticState(xi, bias) + + def update(state, data): + """ + Advance the state of the Funnel_ABF simulation. + + Parameters + ---------- + + state: FABFstate + Old FABFstate from the previous simutlation step. + data: JaxArray + Snapshot to access simulation data. + + Returns + ------- + FABFState + Updated internal state. + """ + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + + # Restraint force and logger + e_f, proj = ext_force(data) + # Second order backward finite difference + I_xi = get_grid_index(xi) + force = estimate_force(xi, I_xi, force_array).reshape(dims) + bias = np.reshape(-Jxi.T @ force, state.bias.shape) + np.reshape(-e_f, state.bias.shape) + # Restraint contribution to force + + return StaticState(xi, bias) + + return snapshot, initialize, generalize(update, helpers) + + +@dispatch +def build_force_estimator(method: Static): + """ + Returns a function that computes the average forces + (or the harmonic restraints forces if provided). + """ + grid = method.grid + + def bias_force(data): + _, I_xi, force_array = data + return force_array[I_xi] + + if method.restraints is None: + estimate_force = jit(lambda *args: bias_force(args)) + else: + lo, hi, kl, kh = method.restraints + + def restraints_force(data): + xi, *_ = data + xi = xi.reshape(grid.shape.size) + return apply_restraints(lo, hi, kl, kh, xi) + + def estimate_force(xi, I_xi, force_array): + ob = np.any(np.array(I_xi) == grid.shape) # Out of bounds condition + data = (xi, I_xi, force_array) + return cond(ob, restraints_force, bias_force, data) + + return estimate_force diff --git a/pysages/methods/funnel_unbiased.py b/pysages/methods/funnel_unbiased.py new file mode 100644 index 00000000..3758abcc --- /dev/null +++ b/pysages/methods/funnel_unbiased.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Funnel Adaptive Biasing Force (Funnel_ABF) sampling method. + +ABF partitions the collective variable space into bins determined by a user +provided grid, and keeps a tabulation of the number of visits to each bin +as well as the sum of generalized forces experienced by the system at each +configuration bin. These provide an estimate for the mean generalized force, +which can be integrated to yield the free energy. This update allows to include + external restraint and accurately remove their contribution in the free energy + calculations. + +""" + +from jax import jit +from jax import numpy as np +from jax.lax import cond + +from pysages.grids import build_indexer +from pysages.methods.core import GriddedSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.typing import JaxArray, NamedTuple +import pysages +from pysages.utils import dispatch + +class UnbiasState(NamedTuple): + """ + Funnel_ABF internal state. + + Parameters + ---------- + + xi: JaxArray (CV shape) + Last collective variable recorded in the simulation. + + bias: JaxArray (Nparticles, d) + Array with biasing forces for each particle. + + """ + + xi: JaxArray + bias: JaxArray + zero: JaxArray + + def __repr__(self): + return repr("PySAGES " + type(self).__name__) + + +class Unbias(GriddedSamplingMethod): + """ + Funnel with Unbias. + + Attributes + ---------- + + snapshot_flags: + Indicate the system properties required from a snapshot. + + Parameters + ---------- + + cvs: Union[List, Tuple] + Set of user selected collective variable. + + grid: Grid + Specifies the collective variables domain and number of bins for + discretizing the CV space along each CV dimension. + + restraints: Optional[CVRestraints] = None + If provided, indicate that harmonic restraints will be applied when any + collective variable lies outside the box from `restraints.lower` to + `restraints.upper`. + + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, grid, **kwargs): + super().__init__(cvs, grid, **kwargs) + + def build(self, snapshot, helpers, *args, **kwargs): + """ + Build the functions for the execution of ABF + + Parameters + ---------- + + snapshot: + PySAGES snapshot of the simulation (backend dependent). + + helpers: + Helper function bundle as generated by + `SamplingMethod.context[0].get_backend().build_helpers`. + + Returns + ------- + + Tuple `(snapshot, initialize, update)` to run ABF simulations. + """ + self.ext_force = self.kwargs.get("ext_force", None) + return _Unbias(self, snapshot, helpers) + + +def _Unbias(method, snapshot, helpers): + """ + Internal function that generates the init and update functions. + + Parameters + ---------- + + method: ABF + Class that generates the functions. + snapshot: + PySAGES snapshot of the simulation (backend dependent). + helpers + Helper function bundle as generated by + `SamplingMethod.context[0].get_backend().build_helpers`. + + Returns + ------- + Tuple `(snapshot, initialize, update)` to run Unbias bias simulations. + """ + cv = method.cv + grid = method.grid + dims = grid.shape.size + natoms = np.size(snapshot.positions, 0) + get_grid_index = build_indexer(grid) + estimate_force = build_force_estimator(method) + ext_force = method.ext_force + + def initialize(): + """ + Internal function that generates the first FABFState + with correctly shaped JaxArrays. + + Returns + ------- + FABFState + Initialized State + """ + xi, _ = cv(helpers.query(snapshot)) + bias = np.zeros((natoms, helpers.dimensionality())) + zero = np.zeros(dims) + return UnbiasState(xi, bias, zero) + + def update(state, data): + """ + Advance the state of the Funnel_ABF simulation. + + Parameters + ---------- + + state: FABFstate + Old FABFstate from the previous simutlation step. + data: JaxArray + Snapshot to access simulation data. + + Returns + ------- + FABFState + Updated internal state. + """ + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + + # Restraint force and logger + e_f, proj = ext_force(data) + # Second order backward finite difference + I_xi = get_grid_index(xi) + force = estimate_force(xi, I_xi, state.zero).reshape(dims) + bias = np.reshape(-Jxi.T @ force, state.bias.shape) + np.reshape(-e_f, state.bias.shape) + # Restraint contribution to force + + return UnbiasState(xi, bias, state.zero) + + return snapshot, initialize, generalize(update, helpers) + + +@dispatch +def build_force_estimator(method: Unbias): + """ + Returns a function that computes the average forces + (or the harmonic restraints forces if provided). + """ + grid = method.grid + + def bias_force(data): + *_, zero = data + return zero + + if method.restraints is None: + estimate_force = jit(lambda *args: bias_force(args)) + else: + lo, hi, kl, kh = method.restraints + + def restraints_force(data): + xi, *_ = data + xi = xi.reshape(grid.shape.size) + return apply_restraints(lo, hi, kl, kh, xi) + + def estimate_force(xi, I_xi, zero): + ob = np.any(np.array(I_xi) == grid.shape) # Out of bounds condition + data = (xi, I_xi, zero) + return cond(ob, restraints_force, bias_force, data) + + return estimate_force diff --git a/pysages/methods/utils.py b/pysages/methods/utils.py index a52894c1..5add537f 100644 --- a/pysages/methods/utils.py +++ b/pysages/methods/utils.py @@ -169,6 +169,45 @@ def __call__(self, snapshot, state, timestep): self.counter += 1 +class Funnel_Logger: + """ + Logs the state of the collective variable and other parameters in Funnel. + Parameters + ---------- + funnel_file: + Name of the output funnel log file. + log_period: + Time steps between logging of collective variables and Funnel parameters. + """ + + def __init__(self, funnel_file, log_period): + """ + Funnel_Logger constructor. + """ + self.funnel_file = funnel_file + self.log_period = log_period + self.counter = 0 + + def save_work(self, xi): + """ + Append the funnel_cv, perp_funnel, and funnel_restraints to log file. + """ + with open(self.funnel_file, "a+", encoding="utf8") as f: + f.write(str(self.counter) + "\t") + f.write("\t".join(map(str, xi.flatten())) + "\n") + # f.write("\t".join(map(str, restr.flatten())) + "\t") + # f.write(str(proj) + "\n") + + def __call__(self, snapshot, state, timestep): + """ + Implements the logging itself. Interface as expected for Callbacks. + """ + if self.counter >= self.log_period and self.counter % self.log_period == 0: + self.save_work(state.xi) + + self.counter += 1 + + def listify(arg, replicas, name, dtype): """ Returns a list of with length `replicas` of `arg` if `arg` is not a list, diff --git a/pysages/serialization.py b/pysages/serialization.py index c83f2fbd..2efcea04 100644 --- a/pysages/serialization.py +++ b/pysages/serialization.py @@ -20,7 +20,7 @@ import dill as pickle from pysages.backends.snapshot import Box, Snapshot -from pysages.methods import Metadynamics +from pysages.methods import Metadynamics, Unbias from pysages.methods.core import GriddedSamplingMethod, Result from pysages.typing import Callable from pysages.utils import dispatch, identity @@ -98,6 +98,10 @@ def _ncalls_estimator(_) -> Callable: # Fallback case. We leave ncalls as zero. return identity +@dispatch +def _ncalls_estimator(_: Unbias) -> Callable: + # Fallback case. We leave ncalls as zero. + return identity @dispatch def _ncalls_estimator(_: Metadynamics) -> Callable: