From 4f06d92b88de27f4f863ddd761222c4b30150ca4 Mon Sep 17 00:00:00 2001 From: qindrew <168153413+qindrew@users.noreply.github.com> Date: Thu, 20 Nov 2025 18:30:30 -0500 Subject: [PATCH 01/13] Fix lammps restart When restarting, the positions, velocities, etc. must be overwritten in the lammps context because that is what is run. Overwriting the snapshot in the Sampler is still needed for the correct bias at the first step. --- pysages/backends/lammps.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pysages/backends/lammps.py b/pysages/backends/lammps.py index b357a8f4..b995adeb 100644 --- a/pysages/backends/lammps.py +++ b/pysages/backends/lammps.py @@ -30,6 +30,8 @@ from pysages.typing import Callable, Optional from pysages.utils import copy, identity +from ctypes import c_double + kConversionFactor = {"real": 2390.0573615334906, "metal": 1.0364269e-4, "electron": 1.06657236} kDefaultLocation = dlext.kOnHost if not hasattr(ExecutionSpace, "kOnDevice") else dlext.kOnDevice @@ -115,7 +117,19 @@ def _update_snapshot(self): def restore(self, prev_snapshot): """Replaces this sampler's snapshot with `prev_snapshot`.""" self._restore(self.snapshot, prev_snapshot) - + + positions = self.snapshot.positions[self.snapshot.ids].ravel() + x_ctypes = (len(positions) * c_double)(*positions) + self.context.scatter_atoms("x", 1, 3, x_ctypes) + + forces = self.snapshot.forces[self.snapshot.ids].ravel() + f_ctypes = (len(forces) * c_double)(*forces) + self.context.scatter_atoms("f", 1, 3, f_ctypes) + + velocities = self.snapshot.vel_mass[0][self.snapshot.ids].ravel() + v_ctypes = (len(velocities) * c_double)(*velocities) + self.context.scatter_atoms("v", 1, 3, v_ctypes) + def take_snapshot(self): """Returns a copy of the current snapshot of the system.""" s = self._partial_snapshot(include_masses=True) From 63be7cd05aa391e0fbdb60859965de72d951f7eb Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Tue, 25 Nov 2025 16:47:36 -0500 Subject: [PATCH 02/13] Sync pysages/ with installed version from site-packages --- pysages/backends/lammps.py | 6 +- pysages/colvars/ML_committor.py | 64 ++++ pysages/colvars/train_committor_dist.py | 390 +++++++++++++++++++++ pysages/methods/__init__.py | 2 + pysages/methods/analysis.py | 192 +++++++++++ pysages/methods/constraints.txt | 3 + pysages/methods/core.py | 8 + pysages/methods/funnel_function.py | 44 +++ pysages/methods/funnel_sabf.py | 440 ++++++++++++++++++++++++ pysages/methods/utils.py | 39 +++ 10 files changed, 1186 insertions(+), 2 deletions(-) create mode 100644 pysages/colvars/ML_committor.py create mode 100644 pysages/colvars/train_committor_dist.py create mode 100644 pysages/methods/constraints.txt create mode 100644 pysages/methods/funnel_function.py create mode 100644 pysages/methods/funnel_sabf.py diff --git a/pysages/backends/lammps.py b/pysages/backends/lammps.py index b995adeb..6a8d791f 100644 --- a/pysages/backends/lammps.py +++ b/pysages/backends/lammps.py @@ -30,6 +30,7 @@ from pysages.typing import Callable, Optional from pysages.utils import copy, identity +import numpy as onp from ctypes import c_double kConversionFactor = {"real": 2390.0573615334906, "metal": 1.0364269e-4, "electron": 1.06657236} @@ -117,7 +118,7 @@ def _update_snapshot(self): def restore(self, prev_snapshot): """Replaces this sampler's snapshot with `prev_snapshot`.""" self._restore(self.snapshot, prev_snapshot) - + positions = self.snapshot.positions[self.snapshot.ids].ravel() x_ctypes = (len(positions) * c_double)(*positions) self.context.scatter_atoms("x", 1, 3, x_ctypes) @@ -129,7 +130,8 @@ def restore(self, prev_snapshot): velocities = self.snapshot.vel_mass[0][self.snapshot.ids].ravel() v_ctypes = (len(velocities) * c_double)(*velocities) self.context.scatter_atoms("v", 1, 3, v_ctypes) - + + def take_snapshot(self): """Returns a copy of the current snapshot of the system.""" s = self._partial_snapshot(include_masses=True) diff --git a/pysages/colvars/ML_committor.py b/pysages/colvars/ML_committor.py new file mode 100644 index 00000000..6769067e --- /dev/null +++ b/pysages/colvars/ML_committor.py @@ -0,0 +1,64 @@ +import jax +from jax import numpy as np +from functools import partial +from pysages.colvars.core import CollectiveVariable +from .train_committor_dist import CommittorNN_Dist_Lip, make_forward_eval, CommittorNN_PIV +from pysages.typing import JaxArray, List, Sequence +from flax import serialization + +class Committor_CV_dist_lipschitz(CollectiveVariable): + def __init__(self, indices: List, params_path: str, tri_idx1: JaxArray, tri_idx2: JaxArray): + super().__init__(indices) + + model = CommittorNN_Dist_Lip(indices=np.arange(len(indices)), + tri_idx1=tri_idx1, + tri_idx2=tri_idx2, + h1=16, h2=16, h3=8, out_dim=1, sig_k=3.0) + + rng = jax.random.PRNGKey(0) + dummy_pos = np.zeros((1, len(indices), 3)) + params = model.init(rng, dummy_pos, training=False) + with open(params_path, "rb") as f: + params = serialization.from_bytes(params, f.read()) + params = jax.tree.map( + lambda x: x.astype(np.float64) if hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating) else x, + params, + ) + self.params = params + self.forward_eval = make_forward_eval(model) + + @property + def function(self): + def wrapped_forward(pos): + y = self.forward_eval(self.params, pos[None, :, :]) + return np.squeeze(y) + return wrapped_forward + + +class Committor_CV_PIV(CollectiveVariable): + def __init__(self, indices: List, params_path: str, blocks: Sequence): + super().__init__(indices) + + model = CommittorNN_PIV(indices=np.arange(len(indices)), blocks=blocks, h1=32, h2=16, h3=8, out_dim=1, sig_k=3.0) + + rng = jax.random.PRNGKey(0) + dummy_pos = np.zeros((1, len(indices), 3)) + params = model.init(rng, dummy_pos, training=False) + with open(params_path, "rb") as f: + params = serialization.from_bytes(params, f.read()) + params = jax.tree.map( + lambda x: x.astype(np.float64) if hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating) else x, + params, + ) + self.params = params + self.forward_eval = make_forward_eval(model) + + @property + def function(self): + def wrapped_forward(pos): + y = self.forward_eval(self.params, pos[None, :, :]) + return np.squeeze(y) + return wrapped_forward + +def cartesian(idx1, idx2): + return np.stack(np.broadcast_arrays(idx1[:, None], idx2[None, :]), axis=-1).reshape(-1, 2) diff --git a/pysages/colvars/train_committor_dist.py b/pysages/colvars/train_committor_dist.py new file mode 100644 index 00000000..b654f6ba --- /dev/null +++ b/pysages/colvars/train_committor_dist.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +JAX/Flax implementation of your committor NN with Lipschitz-normalized +linear layers, boundary + gradient + Lipschitz losses, and an Optax +training loop. This version fixes the earlier JIT issue by closing over +`loss_fn` instead of passing it as an argument to a jitted function. + +Usage (example): + # (optional) to avoid accidental GPU probing / driver issues + # export JAX_PLATFORMS=cpu + python run_committor_flax.py + +Notes: +- Keep this file name NOT equal to 'jax.py' to avoid shadowing the JAX lib. +- Labels: 0=A (reactant), 1=sim/intermediate, 2=B (product). +- Saves params to distance_flax.params and a histogram to fig_jax.png. +""" + +import os +# Respect external setting; uncomment the next line to force CPU: +# os.environ.setdefault("JAX_PLATFORMS", "cpu") +os.environ.setdefault("JAX_ENABLE_X64", "0") # float32 for speed + +from typing import Any, Dict, Tuple, Iterable, Sequence +import numpy as onp +import jax +import jax.numpy as np +from flax import linen as nn +from flax.training import train_state +import optax + +# ------------------------------ +# Layers +# ------------------------------ +class NormalizedLinear(nn.Module): + in_features: int + out_features: int + + @nn.compact + def __call__(self, x: np.ndarray) -> np.ndarray: + kernel = self.param( + "kernel", nn.initializers.lecun_normal(), (self.out_features, self.in_features) + ) + bias = self.param("bias", nn.initializers.zeros, (self.out_features,)) + ci = self.param("ci", nn.initializers.ones, ()) # trainable scalar + + absrowsum = np.sum(np.abs(kernel), axis=1, keepdims=True) + scale = np.minimum(1.0, nn.softplus(ci) / (absrowsum + 1e-12)) + w_norm = kernel * scale + return np.dot(x, w_norm.T) + bias + + +class CommittorNN_Dist_Lip(nn.Module): + indices: np.ndarray # (M,) + tri_idx1: np.ndarray # (K,) + tri_idx2: np.ndarray # (K,) + h1: int = 16 + h2: int = 16 + h3: int = 8 + out_dim: int = 1 + sig_k: float = 3.0 + + @nn.compact + def __call__(self, pos: np.ndarray, training: bool = False) -> np.ndarray: + # pos: (B, N, 3) + r = pos[:, self.indices, :] # (B, M, 3) + diffs = r[:, self.tri_idx1, :] - r[:, self.tri_idx2, :] # (B, K, 3) + dists = np.linalg.norm(diffs, ord=2, axis=-1) # (B, K) + + x = dists + x = NormalizedLinear(dists.shape[-1], self.h1)(x); x = np.tanh(x) + x = NormalizedLinear(self.h1, self.h2)(x); x = np.tanh(x) + x = NormalizedLinear(self.h2, self.h3)(x); x = np.tanh(x) + x = NormalizedLinear(self.h3, self.out_dim)(x) # (B,1) + x = np.squeeze(x, axis=-1) # (B,) + return jax.nn.sigmoid(self.sig_k * x) if training else x + + +def get_block_dist(pos, block): + return np.sort(np.linalg.norm(pos[:,block[:,0]] - pos[:,block[:,1]], axis=-1), axis=-1) + +def get_x(pos, blocks): + return np.concatenate([get_block_dist(pos, block) for block in blocks], axis=-1) + +class CommittorNN_PIV(nn.Module): + indices: np.ndarray # (M,) + blocks: Sequence + h1: int = 32 + h2: int = 16 + h3: int = 8 + out_dim: int = 1 + sig_k: float = 3.0 + + @nn.compact + def __call__(self, pos: np.ndarray, training: bool = False) -> np.ndarray: + # pos: (B, N, 3) + x = get_x(pos[:, self.indices, :], self.blocks) + x = NormalizedLinear(x.shape[-1], self.h1)(x); x = np.tanh(x) + x = NormalizedLinear(self.h1, self.h2)(x); x = np.tanh(x) + x = NormalizedLinear(self.h2, self.h3)(x); x = np.tanh(x) + x = NormalizedLinear(self.h3, self.out_dim)(x) # (B,1) + x = np.squeeze(x, axis=-1) # (B,) + return jax.nn.sigmoid(self.sig_k * x) if training else x + +# ------------------------------ +# Losses +# ------------------------------ +@jax.jit +def lipschitz_loss_from_params(params: Dict[str, Any]) -> np.ndarray: + prod = 1.0 + + def walk(p): + nonlocal prod + if isinstance(p, dict): + if "ci" in p: + prod = prod * nn.softplus(p["ci"]) # multiply scalar + for v in p.values(): + walk(v) + + walk(params) + return prod + + +def make_loss_fn(model: nn.Module, + masses: np.ndarray, + boundary_weight: float = 100.0, + lipschitz_weight: float = 3e-3, + gradient_weight: float = 1.0): + # masses: (1, N, 1) broadcast to (B, N, 1) + def loss_fn(params, pos_batch: np.ndarray, labels: np.ndarray, weights: np.ndarray): + q = model.apply(params, pos_batch, training=True) # (B,) + + # boundary term (average over non-1 labels) + is_A = (labels == 0) + is_B = (labels == 2) + is_1 = (labels == 1) + num_1 = np.maximum(1, np.sum(is_1)) + num_not1 = np.maximum(1, pos_batch.shape[0] - np.sum(is_1)) + boundary = (np.sum((q**2) * is_A) + np.sum(((q - 1.0)**2) * is_B)) / num_not1 + + # gradient term (only on label==1) + def q_sum_over_batch_pos(pos): + return np.sum(model.apply(params, pos, training=True)) + grad_pos = jax.grad(q_sum_over_batch_pos)(pos_batch) # (B,N,3) + grad_sq = (grad_pos**2) / masses + grad_per_sample = np.sum(grad_sq, axis=(1, 2)) # (B,) + grad_loss = np.sum(np.where(is_1, weights * grad_per_sample, 0.0)) / num_1 + + # lipschitz product of softplus(ci) + lip = lipschitz_loss_from_params(params) + + total = 1e4 * (gradient_weight * grad_loss + boundary_weight * boundary + lipschitz_weight * lip) + return total, (1e4 * gradient_weight * grad_loss, + 1e4 * boundary_weight * boundary, + 1e4 * lipschitz_weight * lip) + + return loss_fn + + +# ------------------------------ +# Train state & steps +# ------------------------------ +class TrainState(train_state.TrainState): + pass + + +def create_train_state(rng, model, learning_rate: float, pos_shape: Tuple[int, int, int]): + params = model.init(rng, np.zeros(pos_shape, onp.float32), training=True) + tx = optax.chain( + optax.clip_by_global_norm(1.0), + optax.adamw(learning_rate=learning_rate, weight_decay=1e-5), + ) + return TrainState.create(apply_fn=model.apply, params=params, tx=tx) + + +def make_train_step(loss_fn): + @jax.jit + def train_step(state: TrainState, pos_b, labels_b, weights_b): + (loss, parts), grads = jax.value_and_grad(loss_fn, has_aux=True)( + state.params, pos_b, labels_b, weights_b + ) + state = state.apply_gradients(grads=grads) + return state, loss, parts + return train_step + + +def make_eval_step(loss_fn): + @jax.jit + def eval_step(params, pos_b, labels_b, weights_b): + loss, parts = loss_fn(params, pos_b, labels_b, weights_b) + return loss, parts + return eval_step + + +# ------------------------------ +# Data utils +# ------------------------------ +def make_batches(pos: onp.ndarray, labels: onp.ndarray, weights: onp.ndarray, + batch_size: int, shuffle: bool = True, + rng: onp.random.Generator | None = None) -> Iterable[Tuple[onp.ndarray, onp.ndarray, onp.ndarray]]: + N = pos.shape[0] + idx = onp.arange(N) + if shuffle: + (rng or onp.random.default_rng()).shuffle(idx) + n_full = N // batch_size # drop_last + idx = idx[: n_full * batch_size].reshape(n_full, batch_size) + for row in idx: + yield pos[row], labels[row], weights[row] + + +# ------------------------------ +# Training loop +# ------------------------------ + +def train(model: nn.Module, + masses: onp.ndarray, + train_data: Tuple[onp.ndarray, onp.ndarray, onp.ndarray], + val_data: Tuple[onp.ndarray, onp.ndarray, onp.ndarray], + batch_size: int = 1024, + num_epochs: int = 100, + lr: float = 1e-3, + seed: int | None = 61982): + rng = jax.random.PRNGKey(seed if seed is not None else 0) + pos_train, y_train, w_train = train_data + pos_val, y_val, w_val = val_data + + _, N, D = pos_train.shape + state = create_train_state(rng, model, lr, (batch_size, N, D)) + + loss_fn = make_loss_fn(model, np.asarray(masses), + boundary_weight=100.0, + lipschitz_weight=3e-3, + gradient_weight=1.0) + train_step = make_train_step(loss_fn) + eval_step = make_eval_step(loss_fn) + + for epoch in range(num_epochs): + tr_losses = []; tr_grad = []; tr_bound = []; tr_lip = [] + for xb, yb, wb in make_batches(pos_train, y_train, w_train, batch_size, shuffle=True): + xb = np.asarray(xb, dtype=onp.float32) + yb = np.asarray(yb, dtype=onp.int32) + wb = np.asarray(wb, dtype=onp.float32) + state, loss, parts = train_step(state, xb, yb, wb) + tr_losses.append(float(loss)); g,b,l = parts; tr_grad.append(float(g)); tr_bound.append(float(b)); tr_lip.append(float(l)) + + va_losses = []; va_grad = []; va_bound = []; va_lip = [] + for xb, yb, wb in make_batches(pos_val, y_val, w_val, batch_size, shuffle=False): + xb = np.asarray(xb, dtype=onp.float32) + yb = np.asarray(yb, dtype=onp.int32) + wb = np.asarray(wb, dtype=onp.float32) + loss, parts = eval_step(state.params, xb, yb, wb) + va_losses.append(float(loss)); g,b,l = parts; va_grad.append(float(g)); va_bound.append(float(b)); va_lip.append(float(l)) + + print(f"Epoch {epoch}") + if tr_losses: + print(" avg train loss:", onp.mean(tr_losses)) + print(" avg train grad loss:", onp.mean(tr_grad)) + print(" avg train bound loss:", onp.mean(tr_bound)) + print(" avg train lipschitz loss:", onp.mean(tr_lip)) + if va_losses: + print(" avg val loss:", onp.mean(va_losses)) + print(" avg val grad loss:", onp.mean(va_grad)) + print(" avg val bound loss:", onp.mean(va_bound)) + print(" avg val lipschitz loss:", onp.mean(va_lip)) + print() + + return state + + +# ------------------------------ +# Inference helper +# ------------------------------ +def make_forward_eval(model): + @jax.jit + def _forward_eval(params, pos_b): + return model.apply(params, pos_b, training=False) + return _forward_eval + + +# ------------------------------ +# Main: mirrors your IO (MDAnalysis/pysages) +# ------------------------------ +if __name__ == "__main__": + import MDAnalysis as mda + import pysages + from ase.io import read as ase_read + import matplotlib.pyplot as plt + + # ----- Load CV and weights ----- + xi = onp.loadtxt("/scratch/aq2212/CV.log")[:, 1][::10] + run_result = pysages.load("/scratch/aq2212/production_abf2_equil15/restart.pkl") + fe_result = pysages.analyze(run_result) + biases = fe_result['fes_fn'](xi) + biases /= (8.6173e-5 * 600) + biases = biases[:, 0] + weightsSim = onp.exp(-biases).astype(onp.float32) + + # ----- Extract downsampled sim frames ----- + u = mda.Universe("sampled.xyz") + posSim = [u.atoms.positions.copy() for _ in u.trajectory] + posSim = onp.asarray(posSim, dtype=onp.float32) + + # ----- Load endpoints ----- + def load_xyz(path): + uu = mda.Universe(path) + return onp.asarray([uu.atoms.positions.copy() for _ in uu.trajectory], dtype=onp.float32) + + posA = load_xyz("/scratch/projects/depablolab/acqin2/lipschitz/visnet/model_for_sim1/A.xyz") + posB = load_xyz("/scratch/projects/depablolab/acqin2/lipschitz/visnet/model_for_sim1/B.xyz") + posC = load_xyz("/scratch/projects/depablolab/acqin2/lipschitz/visnet/model_for_sim1/C.xyz") + + labelsA = onp.zeros((len(posA),), dtype=onp.int32) + labelsB = onp.full((len(posB),), 2, dtype=onp.int32) + labelsC = onp.full((len(posC),), 1, dtype=onp.int32) + labelsSim = onp.full((len(posSim),), 1, dtype=onp.int32) + + weightsA = onp.ones((len(posA),), dtype=onp.float32) + weightsB = onp.ones((len(posB),), dtype=onp.float32) + weightsC = onp.full((len(posC),), 1e-3, dtype=onp.float32) + + # masses shape (1, N, 1) + atoms0 = ase_read('/scratch/projects/depablolab/acqin2/lipschitz/visnet/model_for_sim1/A.xyz', index=0) + masses = onp.asarray(atoms0.get_masses(), dtype=onp.float32)[None, :, None] + + # ----- Train/val split (A+B+Sim) + pos_all = onp.concatenate([posA, posB, posSim], axis=0) + lab_all = onp.concatenate([labelsA, labelsB, labelsSim], axis=0) + w_all = onp.concatenate([weightsA, weightsB, weightsSim], axis=0) + + rng_np = onp.random.default_rng(61982) + idx = onp.arange(len(pos_all)); rng_np.shuffle(idx) + split = int(0.7 * len(idx)) + tr_idx, va_idx = idx[:split], idx[split:] + + train_data = (pos_all[tr_idx], lab_all[tr_idx], w_all[tr_idx]) + val_data = (pos_all[va_idx], lab_all[va_idx], w_all[va_idx]) + + # ----- Model config (indices / triplets) + indices = np.arange(50) + tri_idx1 = np.asarray([44, 47, 47, 6, 5, 47, 7, 5]) + tri_idx2 = np.asarray([46, 46, 49, 49, 44, 49, 46, 7]) + + model = CommittorNN_Dist_Lip(indices=indices, + tri_idx1=tri_idx1, + tri_idx2=tri_idx2, + h1=16, h2=16, h3=8, out_dim=1, sig_k=3.0) + + state = train(model, + masses=masses, + train_data=train_data, + val_data=val_data, + batch_size=10000, + num_epochs=200, # consider reducing for quick tests + lr=1e-3, + seed=61982) + + # ----- Save params ----- + import flax.serialization as serialization + param_bytes = serialization.to_bytes(state.params) + with open("distance_flax.params", "wb") as f: + f.write(param_bytes) + + forward_eval = make_forward_eval(model) + + # ----- Evaluate on A/B/C and plot ----- + def batched_preds(pos_arr, bs=2048): + outs = [] + for xb, _, _ in make_batches(pos_arr, + onp.zeros(len(pos_arr), onp.int32), + onp.ones(len(pos_arr), onp.float32), + bs, shuffle=False): + xb = np.asarray(xb, dtype=onp.float32) + y = forward_eval(state.params, xb) + outs.append(onp.asarray(y)) + return onp.concatenate(outs, axis=0) if outs else onp.zeros((0,), dtype=onp.float32) + + As = batched_preds(posA) + Bs = batched_preds(posB) + Cs = batched_preds(posC) + + import matplotlib + plt.figure() + plt.hist([As, Bs, Cs], bins=100, label=['Reactant', 'Product', 'Intermediate'], + alpha=0.3, histtype='stepfilled', density=True) + plt.legend(); plt.xlabel('Output value'); plt.ylabel('Probability Density') + plt.title('Histogram of outputs by class') + plt.savefig('fig_jax.png', dpi=200) + + print("Saved: distance_flax.params, fig_jax.png") diff --git a/pysages/methods/__init__.py b/pysages/methods/__init__.py index 875baa91..19859394 100644 --- a/pysages/methods/__init__.py +++ b/pysages/methods/__init__.py @@ -71,12 +71,14 @@ from .restraints import CVRestraints from .sirens import Sirens from .spectral_abf import SpectralABF +from .funnel_sabf import Funnel_SpectralABF from .spline_string import SplineString from .umbrella_integration import UmbrellaIntegration from .unbiased import Unbiased from .utils import ( HistogramLogger, MetaDLogger, + Funnel_Logger, ReplicasConfiguration, SerialExecutor, methods_dispatch, diff --git a/pysages/methods/analysis.py b/pysages/methods/analysis.py index 609a25fb..906bfcba 100644 --- a/pysages/methods/analysis.py +++ b/pysages/methods/analysis.py @@ -173,3 +173,195 @@ def average_forces(hist, Fsum): "fes_fn": first_or_all(fes_fns), "mesh": transpose(mesh).reshape(-1, d).squeeze(), } + + +@dispatch +def _funnelanalyze(result: Result, strategy: GradientLearning, topology): + """ + Computes the free energy from the result of an `FunnelABF`-based run. + Integrates the forces via a gradient learning strategy. + + Parameters + ---------- + + result: Result: + Result bundle containing method, final FunnelABF-like state, and callback. + + strategy: GradientLearning + + topology: Tuple[int] + Defines the architecture of the neural network + (number of nodes in each hidden layer). + + Returns + ------- + + dict: A dictionary with the following keys: + + histogram: JaxArray + Histogram for the states visited during the method. + + mean_force: JaxArray + Average force at each bin of the CV grid. + + free_energy: JaxArray + Free Energy at each bin of the CV grid. + + mesh: JaxArray + Grid used in the method. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the + CV domain defined by the grid. + + NOTE: + For multiple-replicas runs we return a list (one item per-replica) + for each attribute. + """ + + # The ForceNN based analysis occurs in two stages: + # + # 1. The data is smoothed and a first quick fitting is performed to obtain + # an approximate set of network parameters. + # 2. A second training pass is then performed over the raw data starting + # with the parameters from previous step. + + method = result.method + states = result.states + grid = method.grid + mesh = inputs = (compute_mesh(grid) + 1) * grid.size / 2 + grid.lower + + model = MLP(grid.shape.size, 1, topology, transform=partial(_scale, grid=grid)) + loss = GradientsSSE() + regularizer = L2Regularization(1e-4) + + # Stage 1 optimizer + pre_optimizer = LevenbergMarquardt(loss=loss, max_iters=250, reg=regularizer) + pre_fit = build_fitting_function(model, pre_optimizer) + + # Stage 2 optimizer + optimizer = LevenbergMarquardt(loss=loss, max_iters=1000, reg=regularizer) + fit = build_fitting_function(model, optimizer) + + @vmap + def smooth(data, conv_dtype=np.float32): + data_dtype = data.dtype + boundary = "wrap" if grid.is_periodic else "edge" + kernel = np.asarray(blackman_kernel(grid.shape.size, 7), dtype=conv_dtype) + data = np.asarray(data, dtype=conv_dtype) + return np.asarray(convolve(data.T, kernel, boundary=boundary), dtype=data_dtype).T + + @jit + def pre_train(nn, data): + params = pre_fit(nn.params, inputs, smooth(data)).params + return NNData(params, nn.mean, nn.std) + + @jit + def train(nn, data): + params = fit(nn.params, inputs, data).params + return NNData(params, nn.mean, nn.std) + + def build_fes_fn(state): + hist = np.expand_dims(state.hist, state.hist.ndim) + F = state.Fsum / np.maximum(hist, 1) + + # Scale the mean forces + s = np.abs(F).max() + F = F / s + + ps, layout = unpack(model.parameters) + nn = pre_train(NNData(ps, 0.0, s), F) + nn = train(nn, F) + + def fes_fn(x): + params = pack(nn.params, layout) + A = nn.std * model.apply(params, x) + nn.mean + return A.max() - A + + return jit(fes_fn) + + def average_forces(hist, Fsum): + shape = (*Fsum.shape[:-1], 1) + return Fsum / np.maximum(hist.reshape(shape), 1) + + def build_corr_fn(state): + hist = np.expand_dims(state.hist, state.hist.ndim) + F = (state.Fsum + state.Frestr) / np.maximum(hist, 1) + + # Scale the mean forces + s = np.abs(F).max() + F = F / s + + ps, layout = unpack(model.parameters) + nn = pre_train(NNData(ps, 0.0, s), F) + nn = train(nn, F) + + def fes_fn(x): + params = pack(nn.params, layout) + A = nn.std * model.apply(params, x) + nn.mean + return A.max() - A + + return jit(fes_fn) + + def build_restr_fn(state): + hist = np.expand_dims(state.hist, state.hist.ndim) + F = state.Frestr / np.maximum(hist, 1) + + # Scale the mean forces + s = np.abs(F).max() + F = F / s + + ps, layout = unpack(model.parameters) + nn = pre_train(NNData(ps, 0.0, s), F) + nn = train(nn, F) + + def fes_fn(x): + params = pack(nn.params, layout) + A = nn.std * model.apply(params, x) + nn.mean + return A.max() - A + + return jit(fes_fn) + + hists = [] + mean_forces = [] + free_energies = [] + fes_fns = [] + corr_forces = [] + corrected_energies = [] + corr_fns = [] + restr_forces = [] + restraint_energies = [] + restr_fns = [] + + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + d = mesh.shape[-1] + + for state in states: + fes_fn = build_fes_fn(state) + hists.append(transpose(state.hist)) + mean_forces.append(transpose(average_forces(state.hist, state.Fsum))) + free_energies.append(transpose(fes_fn(mesh))) + fes_fns.append(fes_fn) + corr_fn = build_corr_fn(state) + corr_forces.append(transpose(average_forces(state.hist, state.Fsum + state.Frestr))) + corrected_energies.append(transpose(corr_fn(mesh))) + corr_fns.append(corr_fn) + restr_fn = build_restr_fn(state) + restr_forces.append(transpose(average_forces(state.hist, state.Frestr))) + restraint_energies.append(transpose(restr_fn(mesh))) + restr_fns.append(restr_fn) + + return { + "histogram": first_or_all(hists), + "mean_force": first_or_all(mean_forces), + "free_energy": first_or_all(free_energies), + "fes_fn": first_or_all(fes_fns), + "mesh": transpose(mesh).reshape(-1, d).squeeze(), + "corrected_force": first_or_all(corr_forces), + "corrected_energy": first_or_all(corrected_energies), + "corr_fn": first_or_all(corr_fns), + "restraint_force": first_or_all(restr_forces), + "restraint_energy": first_or_all(restraint_energies), + "restr_fn": first_or_all(restr_fns), + } diff --git a/pysages/methods/constraints.txt b/pysages/methods/constraints.txt new file mode 100644 index 00000000..4f166c1e --- /dev/null +++ b/pysages/methods/constraints.txt @@ -0,0 +1,3 @@ +jax==0.4.29 +jaxlib==0.4.29 +numpy==1.26.4 diff --git a/pysages/methods/core.py b/pysages/methods/core.py index 13e004d1..0fffdcc2 100644 --- a/pysages/methods/core.py +++ b/pysages/methods/core.py @@ -25,6 +25,8 @@ has_method, identity, ) +import numpy as onp +from ctypes import c_double # Base Classes # ============ @@ -405,6 +407,12 @@ def _run( # noqa: F811 # pylint: disable=C0116,E0102 sampler.restore(prev_snapshot) sampler.state = result.states + #positions = sampler.snapshot.positions.ravel() + #x = sampling_context.context.gather_atoms('x',1,3) + #for i in range(len(positions)): + # x[i] = positions[i] + #sampling_context.context.scatter_atoms('x',1,3,x) + with sampling_context: sampling_context.run(timesteps, **kwargs) if post_run_action: diff --git a/pysages/methods/funnel_function.py b/pysages/methods/funnel_function.py new file mode 100644 index 00000000..5924501b --- /dev/null +++ b/pysages/methods/funnel_function.py @@ -0,0 +1,44 @@ +# funnel functions +from functools import partial + +import jax.numpy as np +from jax import jit, grad +from jax.numpy import linalg + +def distance(r, cell_size): + diff = r[1:] - r[0] + #diff = diff - np.round(diff / cell_size) * cell_size + return np.linalg.norm(diff,axis=1) + +def coordnum_energy(r, cell_size, c_mins, k, idx = np.asarray([ [0,1],[0,2],[0,3],[1,2],[1,3],[2,3] ])): + total = 0 #all O somewhat close to C + c_min = c_mins[0] + dist = distance(r[:5],cell_size) + total += 0.5*(np.where(dist > c_min, dist - c_min, 0.0)**2).sum() + + c_min = c_mins[1] #Os are all close to each other + rO = r[1:5] + dists = np.linalg.norm( rO[idx[:,0]] - rO[idx[:,1]],axis=1 ) + total += 0.5*(np.where(dists > c_min, dists - c_min, 0.0)**2).sum() + return k * total + + +def intermediate_funnel(pos, ids, indexes, cell_size, c_mins, k): + r = pos[ids[indexes]] + return coordnum_energy(r, cell_size, c_mins, k) + +def log_funnel(): + return 0.0 + +def external_funnel(data, indexes, cell_size, c_mins, k): + pos = data.positions[:, :3] + ids = data.indices + bias = grad(intermediate_funnel)(pos, ids, indexes, cell_size, c_mins, k) + proj = log_funnel() + return bias, proj + +def get_funnel_force(indexes, cell_size, c_mins, k): + funnel_force = partial( + external_funnel, + indexes=indexes, cell_size=cell_size, c_mins=c_mins, k=k) + return jit(funnel_force) diff --git a/pysages/methods/funnel_sabf.py b/pysages/methods/funnel_sabf.py new file mode 100644 index 00000000..b8597124 --- /dev/null +++ b/pysages/methods/funnel_sabf.py @@ -0,0 +1,440 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Funnel Spectral Adaptive Biasing Force Sampling Method. + +""" + +from jax import jit +from jax import numpy as np +from jax.lax import cond + +from pysages.approxfun import ( + Fun, + SpectralGradientFit, + build_evaluator, + build_fitter, + build_grad_evaluator, + compute_mesh, +) +from pysages.grids import Chebyshev, Grid, build_indexer, convert, grid_transposer +from pysages.methods.core import GriddedSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.methods.utils import numpyfy_vals +from pysages.typing import JaxArray, NamedTuple, Tuple +from pysages.utils import dispatch, first_or_all, linear_solver + + +class SpectralFABFState(NamedTuple): + """ + Funnel_SpectralABF internal state. + + Parameters + ---------- + + xi: JaxArray [cvs.shape] + Last collective variable recorded in the simulation. + + bias: JaxArray [natoms, 3] + Array with biasing forces for each particle. + + hist: JaxArray [grid.shape] + Histogram of visits to the bins in the collective variable grid. + + Fsum: JaxArray [grid.shape, cvs.shape] + The cumulative force recorded at each bin of the CV grid. + + force: JaxArray [grid.shape, cvs.shape] + Average force at each bin of the CV grid. + + Wp: JaxArray [cvs.shape] + Estimate of the product $W p$ where `p` is the matrix of momenta and + `W` the Moore-Penrose inverse of the Jacobian of the CVs. + + Wp_: JaxArray [cvs.shape] + The value of `Wp` for the previous integration step. + + fun: Fun + Object that holds the coefficients of the basis functions + approximation to the free energy. + + restr: JaxArray [grid.shape, cvs.shape] + Instantaneous restraint force at each bin of the CV grid. + + proj: JaxArray [cvs.shape] + Last collective variable from restraints recorded in the simulation. + + Frestr: JaxArray [grid.shape, cvs.shape] + The cumulative restraint force recorded at each bin of the CV grid. + + ncalls: int + Counts the number of times the method's update has been called. + """ + + xi: JaxArray + bias: JaxArray + hist: JaxArray + Fsum: JaxArray + force: JaxArray + Wp: JaxArray + Wp_: JaxArray + fun: Fun + restr: JaxArray + proj: JaxArray + Frestr: JaxArray + ncalls: int + + def __repr__(self): + return repr("PySAGES " + type(self).__name__) + + +class PartialSpectralFABFState(NamedTuple): + xi: JaxArray + hist: JaxArray + Fsum: JaxArray + ind: Tuple + fun: Fun + pred: bool + + +class Funnel_SpectralABF(GriddedSamplingMethod): + """ + Implementation of the Funnel_Spectral ABF method described in + + Parameters + ---------- + + cvs: Union[List, Tuple] + Set of user selected collective variable. + + grid: Grid + Specifies the collective variables domain and number of bins for discretizing + the CV space along each CV dimension. For non-periodic grids this will be + converted to a Chebyshev-distributed grid. + + N: Optional[int] = 500 + Threshold parameter before accounting for the full average + of the adaptive biasing force. + + fit_freq: Optional[int] = 100 + Fitting frequency. + + fit_threshold: Optional[int] = 500 + Number of time steps after which fitting starts to take place. + + restraints: Optional[CVRestraints] = None + If provided, indicate that harmonic restraints will be applied when any + collective variable lies outside the box from `restraints.lower` to + `restraints.upper`. + + use_pinv: Optional[Bool] = False + If set to True, the product `W @ p` will be estimated using + `np.linalg.pinv` rather than using the `scipy.linalg.solve` function. + This is computationally more expensive but numerically more stable. + + ext_force: Optional[ext_force] = None + If provided, indicate the geometric restraints will be applied for any + collective variables. + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, grid, **kwargs): + super().__init__(cvs, grid, **kwargs) + self.N = np.asarray(self.kwargs.get("N", 500)) + self.fit_freq = self.kwargs.get("fit_freq", 100) + self.fit_threshold = self.kwargs.get("fit_threshold", 500) + self.fit_threshold_upper = self.kwargs.get("fit_threshold_upper", 1e7) + self.grid = self.grid if self.grid.is_periodic else convert(self.grid, Grid[Chebyshev]) + self.model = SpectralGradientFit(self.grid) + self.use_pinv = self.kwargs.get("use_pinv", False) + + def build(self, snapshot, helpers, *_args, **_kwargs): + """ + Returns the `initialize` and `update` functions for the sampling method. + """ + self.ext_force = self.kwargs.get("ext_force", None) + return _spectral_abf(self, snapshot, helpers) + + +def _spectral_abf(method, snapshot, helpers): + cv = method.cv + grid = method.grid + fit_freq = method.fit_freq + fit_threshold = method.fit_threshold + fit_threshold_upper = method.fit_threshold_upper + + dt = snapshot.dt + dims = grid.shape.size + natoms = np.size(snapshot.positions, 0) + + # Helper methods + tsolve = linear_solver(method.use_pinv) + get_grid_index = build_indexer(grid) + fit = build_fitter(method.model) + fit_forces = build_free_energy_fitter(method, fit) + estimate_force = build_force_estimator(method) + ext_force = method.ext_force + + query, dimensionality, to_force_units = helpers + + def initialize(): + xi, _ = cv(query(snapshot)) + bias = np.zeros((natoms, dimensionality())) + hist = np.zeros(grid.shape, dtype=np.uint32) + Fsum = np.zeros((*grid.shape, dims)) + force = np.zeros(dims) + Wp = np.zeros(dims) + Wp_ = np.zeros(dims) + fun = fit(Fsum) + restr = np.zeros(dims) + proj = 0.0 + Frestr = np.zeros((*grid.shape, dims)) + return SpectralFABFState(xi, bias, hist, Fsum, force, Wp, Wp_, fun, restr, proj, Frestr, 0) + + def update(state, data): + # During the intial stage use ABF + ncalls = state.ncalls + 1 + in_fitting_regime = ncalls > fit_threshold + in_fitting_step = in_fitting_regime & (ncalls % fit_freq == 1) & (ncalls < fit_threshold_upper) + # Fit forces + fun = fit_forces(state, in_fitting_step) + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + # Restraint force and logger + e_f, proj = ext_force(data) + # + p = data.momenta + Wp = tsolve(Jxi, p) + # Second order backward finite difference + dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt + # + I_xi = get_grid_index(xi) + hist = state.hist.at[I_xi].add(1) + Fsum = state.Fsum.at[I_xi].add(to_force_units(dWp_dt) + state.force) + # + force = estimate_force( + PartialSpectralFABFState(xi, hist, Fsum, I_xi, fun, in_fitting_regime) + ) + bias = np.reshape(-Jxi.T @ force, state.bias.shape) + np.reshape(-e_f, state.bias.shape) + # Restraint contribution to force + restr = tsolve(Jxi, e_f.reshape(p.shape)) + Frestr = state.Frestr.at[I_xi].add(restr) + return SpectralFABFState( + xi, bias, hist, Fsum, force, Wp, state.Wp, fun, restr, proj, Frestr, ncalls + ) + + return snapshot, initialize, generalize(update, helpers) + + +def build_free_energy_fitter(_method: Funnel_SpectralABF, fit): + """ + Returns a function that given a `SpectralFABFState` performs a least squares fit of the + generalized average forces for finding the coefficients of a basis functions expansion + of the free energy. + """ + + def _fit_forces(state): + shape = (*state.Fsum.shape[:-1], 1) + force = state.Fsum / np.maximum(state.hist.reshape(shape), 1) + return fit(force) + + def skip_fitting(state): + return state.fun + + def fit_forces(state, in_fitting_step): + return cond(in_fitting_step, _fit_forces, skip_fitting, state) + + return fit_forces + + +@dispatch +def build_force_estimator(method: Funnel_SpectralABF): + """ + Returns a function that given the coefficients of basis functions expansion and a CV + value, evaluates the function approximation to the gradient of the free energy. + """ + N = method.N + grid = method.grid + dims = grid.shape.size + model = method.model + get_grad = build_grad_evaluator(model) + + def average_force(state): + i = state.ind + return 0*state.Fsum[i] #state.Fsum[i] / np.maximum(state.hist[i], N) + + def interpolate_force(state): + return get_grad(state.fun, state.xi).reshape(grid.shape.size) + + def _estimate_force(state): + return cond(state.pred, interpolate_force, average_force, state) + + if method.restraints is None: + ob_force = jit(lambda state: np.zeros(dims)) + else: + lo, hi, kl, kh = method.restraints + + def ob_force(state): + xi = state.xi.reshape(grid.shape.size) + return apply_restraints(lo, hi, kl, kh, xi) + + def estimate_force(state): + ob = np.any(np.array(state.ind) == grid.shape) # Out of bounds condition + return cond(ob, ob_force, _estimate_force, state) + + return estimate_force + + +@dispatch +def analyze(result: Result[Funnel_SpectralABF]): + """ + Parameters + ---------- + + result: Result[FunnelSpectralABF] + Result bundle containing the method, final states, and callbacks. + + dict: + A dictionary with the following keys: + + histogram: JaxArray + A histogram of the visits to each bin in the CV grid. + + mean_force: JaxArray + Generalized mean forces at each bin in the CV grid. + + free_energy: JaxArray + Free energy at each bin in the CV grid. + + mesh: JaxArray + These are the values of the CVs that are used as inputs for training. + + fun: Fun + Coefficients of the basis functions expansion approximating the free energy. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the CV domain defined + by the grid. + + fun: Fun + Coefficients of the basis functions expansion approximating the free energy. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the CV domain defined + by the grid. + + fun_corr: Fun + Coefficients of the basis functions expansion approximating the free energy + defined without external restraints. + + corr_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy without external restraints + in the CV domain defined by the grid. + fun_rstr: Fun + Coefficients of the basis functions expansion approximating the free energy + of the external restraints. + + rstr_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy of theexternal restraints + in the CV domain defined by the grid. + + corrected_force: JaxArray + Generalized mean forces without restraints at each bin in the CV grid. + + corrected_energy: JaxArray + Free energy without restraints at each bin in the CV grid. + + restraint_force: JaxArray + Generalized mean forces of the restraints at each bin in the CV grid. + + restraint_energy: JaxArray + Free energy of restraints at each bin in the CV grid. + + + NOTE: + For multiple-replicas runs we return a list (one item per-replica) for each attribute. + """ + method = result.method + fit = build_fitter(method.model) + grid = method.grid + mesh = compute_mesh(grid) + evaluate = build_evaluator(method.model) + + def average_forces(hist, Fsum): + hist = np.expand_dims(hist, hist.ndim) + return Fsum / np.maximum(hist, 1) + + def fit_corr(state): + shape = (*state.Fsum.shape[:-1], 1) + force = (state.Fsum + state.Frestr) / np.maximum(state.hist.reshape(shape), 1) + return fit(force) + + def fit_restr(state): + shape = (*state.Fsum.shape[:-1], 1) + force = (state.Frestr) / np.maximum(state.hist.reshape(shape), 1) + return fit(force) + + def build_fes_fn(fun): + def fes_fn(x): + A = evaluate(fun, x) + return A.max() - A + + return jit(fes_fn) + + hists = [] + mean_forces = [] + free_energies = [] + funs = [] + fes_fns = [] + forces_corrected = [] + corrected_energies = [] + funs_corr = [] + fes_corr = [] + restraint_forces = [] + restraint_energies = [] + funs_rstr = [] + fes_rstr = [] + # We transpose the data for convenience when plotting + transpose = grid_transposer(grid) + d = mesh.shape[-1] + + for s in result.states: + fun_corr = fit_corr(s) + fun_rstr = fit_restr(s) + fes_fn = build_fes_fn(s.fun) + corr_fn = build_fes_fn(fun_corr) + restr_fn = build_fes_fn(fun_rstr) + hists.append(transpose(s.hist)) + mean_forces.append(transpose(average_forces(s.hist, s.Fsum))) + free_energies.append(transpose(fes_fn(mesh))) + funs.append(s.fun) + fes_fns.append(fes_fn) + forces_corrected.append(transpose(average_forces(s.hist, s.Fsum + s.Frestr))) + corrected_energies.append(transpose(corr_fn(mesh))) + funs_corr.append(fun_corr) + fes_corr.append(corr_fn) + restraint_forces.append(transpose(average_forces(s.hist, s.Frestr))) + restraint_energies.append(transpose(restr_fn(mesh))) + funs_rstr.append(fun_rstr) + fes_rstr.append(restr_fn) + + ana_result = { + "histogram": first_or_all(hists), + "mean_force": first_or_all(mean_forces), + "free_energy": first_or_all(free_energies), + "mesh": transpose(mesh).reshape(-1, d).squeeze(), + "fun": first_or_all(funs), + "fes_fn": first_or_all(fes_fns), + "corrected_force": first_or_all(forces_corrected), + "corrected_energy": first_or_all(corrected_energies), + "fun_corr": first_or_all(funs_corr), + "corr_fn": first_or_all(fes_corr), + "restraint_force": first_or_all(restraint_forces), + "restraint_energy": first_or_all(restraint_energies), + "fun_rstr": first_or_all(funs_rstr), + "rstr_fn": first_or_all(fes_rstr), + } + + return numpyfy_vals(ana_result) diff --git a/pysages/methods/utils.py b/pysages/methods/utils.py index a52894c1..5add537f 100644 --- a/pysages/methods/utils.py +++ b/pysages/methods/utils.py @@ -169,6 +169,45 @@ def __call__(self, snapshot, state, timestep): self.counter += 1 +class Funnel_Logger: + """ + Logs the state of the collective variable and other parameters in Funnel. + Parameters + ---------- + funnel_file: + Name of the output funnel log file. + log_period: + Time steps between logging of collective variables and Funnel parameters. + """ + + def __init__(self, funnel_file, log_period): + """ + Funnel_Logger constructor. + """ + self.funnel_file = funnel_file + self.log_period = log_period + self.counter = 0 + + def save_work(self, xi): + """ + Append the funnel_cv, perp_funnel, and funnel_restraints to log file. + """ + with open(self.funnel_file, "a+", encoding="utf8") as f: + f.write(str(self.counter) + "\t") + f.write("\t".join(map(str, xi.flatten())) + "\n") + # f.write("\t".join(map(str, restr.flatten())) + "\t") + # f.write(str(proj) + "\n") + + def __call__(self, snapshot, state, timestep): + """ + Implements the logging itself. Interface as expected for Callbacks. + """ + if self.counter >= self.log_period and self.counter % self.log_period == 0: + self.save_work(state.xi) + + self.counter += 1 + + def listify(arg, replicas, name, dtype): """ Returns a list of with length `replicas` of `arg` if `arg` is not a list, From ecf69d802ebb68f2bf8ea825697235a17216580c Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Tue, 25 Nov 2025 16:56:42 -0500 Subject: [PATCH 03/13] clean up changes --- pysages/backends/lammps.py | 1 - pysages/methods/constraints.txt | 3 --- pysages/methods/core.py | 8 -------- 3 files changed, 12 deletions(-) delete mode 100644 pysages/methods/constraints.txt diff --git a/pysages/backends/lammps.py b/pysages/backends/lammps.py index 6a8d791f..4659c29d 100644 --- a/pysages/backends/lammps.py +++ b/pysages/backends/lammps.py @@ -30,7 +30,6 @@ from pysages.typing import Callable, Optional from pysages.utils import copy, identity -import numpy as onp from ctypes import c_double kConversionFactor = {"real": 2390.0573615334906, "metal": 1.0364269e-4, "electron": 1.06657236} diff --git a/pysages/methods/constraints.txt b/pysages/methods/constraints.txt deleted file mode 100644 index 4f166c1e..00000000 --- a/pysages/methods/constraints.txt +++ /dev/null @@ -1,3 +0,0 @@ -jax==0.4.29 -jaxlib==0.4.29 -numpy==1.26.4 diff --git a/pysages/methods/core.py b/pysages/methods/core.py index 0fffdcc2..13e004d1 100644 --- a/pysages/methods/core.py +++ b/pysages/methods/core.py @@ -25,8 +25,6 @@ has_method, identity, ) -import numpy as onp -from ctypes import c_double # Base Classes # ============ @@ -407,12 +405,6 @@ def _run( # noqa: F811 # pylint: disable=C0116,E0102 sampler.restore(prev_snapshot) sampler.state = result.states - #positions = sampler.snapshot.positions.ravel() - #x = sampling_context.context.gather_atoms('x',1,3) - #for i in range(len(positions)): - # x[i] = positions[i] - #sampling_context.context.scatter_atoms('x',1,3,x) - with sampling_context: sampling_context.run(timesteps, **kwargs) if post_run_action: From 97ffc43e24b7cc819b23e7e23df9fa74c6bfd894 Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Tue, 25 Nov 2025 17:46:14 -0500 Subject: [PATCH 04/13] add funnel_metadabf --- pysages/methods/funnel_metadabf.py | 477 +++++++++++++++++++++++++++++ 1 file changed, 477 insertions(+) create mode 100644 pysages/methods/funnel_metadabf.py diff --git a/pysages/methods/funnel_metadabf.py b/pysages/methods/funnel_metadabf.py new file mode 100644 index 00000000..7e4d0cb9 --- /dev/null +++ b/pysages/methods/funnel_metadabf.py @@ -0,0 +1,477 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Implementation of Standard and Well-tempered MetaD-ABF +both with support for grids. +""" + +from jax import grad, jit +from jax import numpy as np +from jax import value_and_grad, vmap +from jax.lax import cond + +from pysages.approxfun import compute_mesh +from pysages.colvars import get_periods, wrap +from pysages.grids import build_indexer +from pysages.methods.analysis import GradientLearning, _funnelanalyze +from pysages.methods.core import GriddedSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.methods.utils import numpyfy_vals +from pysages.typing import JaxArray, NamedTuple, Optional +from pysages.utils import dispatch, gaussian, identity, linear_solver + + +class FMetaDABFState(NamedTuple): + """ + MetaDABF helper state + + Parameters + ---------- + + xi: JaxArray + Collective variable value in the last simulation step. + + bias: JaxArray + Array of Metadynamics bias forces for each particle in the simulation. + + heights: JaxArray + Height values for all accumulated Gaussians (zeros for not yet added Gaussians). + + centers: JaxArray + Centers of the accumulated Gaussians. + + sigmas: JaxArray + Widths of the accumulated Gaussians. + + grid_potential: Optional[JaxArray] + Array of Metadynamics bias potentials stored on a grid. + + grid_gradient: Optional[JaxArray] + Array of Metadynamics bias gradients evaluated on a grid. + + idx: int + Index of the next Gaussian to be deposited. + + idg: int + Index of the CV in the forces array. + + hist: JaxArray (grid.shape) + Histogram of visits to the bins in the collective variable grid. + + Fsum: JaxArray (grid.shape, CV shape) + Cumulative forces at each bin in the CV grid. + + force: JaxArray (grid.shape, CV shape) + Average force at each bin of the CV grid. + + Wp: JaxArray (CV shape) + Product of W matrix and momenta matrix for the current step. + + Wp_: JaxArray (CV shape) + Product of W matrix and momenta matrix for the previous step. + + ncalls: int + Counts the number of times `method.update` has been called. + + perp: JaxArray + Collective variable perpendicular to the Funnel_CV + """ + + xi: JaxArray + bias: JaxArray + heights: JaxArray + centers: JaxArray + sigmas: JaxArray + grid_potential: Optional[JaxArray] + grid_gradient: Optional[JaxArray] + idx: int + idg: int + Wp: JaxArray + Wp_: JaxArray + force: JaxArray + Fsum: JaxArray + Frestr: JaxArray + hist: JaxArray + ncalls: int + perp: JaxArray + + def __repr__(self): + return repr("PySAGES" + type(self).__name__) + + +class PartialMetadynamicsState(NamedTuple): + """ + Helper intermediate Metadynamics state + """ + + xi: JaxArray + heights: JaxArray + centers: JaxArray + sigmas: JaxArray + grid_potential: Optional[JaxArray] + grid_gradient: Optional[JaxArray] + idx: int + grid_idx: Optional[JaxArray] + + +class Funnel_MetaDABF(GriddedSamplingMethod): + """ + Implementation of Standard and Well-tempered Funnel MetaDABF as described in + arXiv preprint arXiv:2504.13575. + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, height, sigma, stride, ngaussians, deltaT=None, **kwargs): + """ + Parameters + ---------- + + cvs: + Set of user selected collective variable. + + height: + Initial height of the deposited Gaussians. + + sigma: + Initial standard deviation of the to-be-deposit Gaussians. + + stride: int + Bias potential deposition frequency. + + ngaussians: int + Total number of expected Gaussians (`timesteps // stride + 1`). + + deltaT: Optional[float] = None + Well-tempered Metadynamics :math:`\\Delta T` parameter + (if `None` standard Metadynamics is used). + + grid: Optional[Grid] = None + If provided, it will be used to accelerate the computation by + approximating the bias potential and its gradient over its centers. + + kB: Optional[float] + Boltzmann constant. Must be provided for well-tempered Metadynamics + simulations and should match the internal units of the backend. + + restraints: Optional[CVRestraints] = None + If provided, it will be used to restraint CV space inside the grid. + + external_force: + External restraint to be used for funnel calculations. + """ + + if deltaT is not None and "kB" not in kwargs: + raise KeyError( + "For well-tempered Metadynamics a keyword argument `kB` for " + "the value of the Boltzmann constant (that matches the " + "internal units of the backend) must be provided." + ) + + kwargs["grid"] = kwargs.get("grid", None) + kwargs["restraints"] = kwargs.get("restraints", None) + super().__init__(cvs, **kwargs) + + self.height = height + self.sigma = sigma + self.stride = stride + self.ngaussians = ngaussians # NOTE: infer from timesteps and stride + self.deltaT = deltaT + self.kB = kwargs.get("kB", None) + self.use_pinv = self.kwargs.get("use_pinv", False) + + def build(self, snapshot, helpers, *args, **kwargs): + self.external_force = self.kwargs.get("external_force", None) + return _metadynamics(self, snapshot, helpers) + + +def _metadynamics(method, snapshot, helpers): + # Initialization and update of biasing forces. Interface expected for methods. + cv = method.cv + stride = method.stride + dims = method.grid.shape.size + dt = snapshot.dt + ngaussians = method.ngaussians + get_grid_index = build_indexer(method.grid) + external_force = method.external_force + tsolve = linear_solver(method.use_pinv) + natoms = np.size(snapshot.positions, 0) + + deposit_gaussian = build_gaussian_accumulator(method) + evaluate_bias_grad = build_bias_grad_evaluator(method) + + def initialize(): + xi, _ = cv(helpers.query(snapshot)) + bias = np.zeros((natoms, helpers.dimensionality())) + perp = 0.0 + # NOTE: for restart; use hills file to initialize corresponding arrays. + heights = np.zeros(ngaussians, dtype=np.float64) + centers = np.zeros((ngaussians, xi.size), dtype=np.float64) + sigmas = np.array(method.sigma, dtype=np.float64, ndmin=2) + hist = np.zeros(method.grid.shape, dtype=np.uint32) + Fsum = np.zeros((*method.grid.shape, dims)) + force = np.zeros(dims) + Wp = np.zeros(dims) + Wp_ = np.zeros(dims) + restr = np.zeros(dims) + Frestr = np.zeros((*method.grid.shape, dims)) + # Arrays to store forces and bias potential on a grid. + if method.grid is None: + grid_potential = grid_gradient = None + else: + shape = method.grid.shape + grid_potential = np.zeros((*shape,), dtype=np.float64) if method.deltaT else None + grid_gradient = np.zeros((*shape, shape.size), dtype=np.float64) + + return FMetaDABFState( + xi, + bias, + heights, + centers, + sigmas, + grid_potential, + grid_gradient, + 0, + 0, + Wp, + Wp_, + force, + Fsum, + Frestr, + hist, + 0, + perp, + ) + + def update(state, data): + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + + # Deposit Gaussian depending on the stride + ncalls = state.ncalls + 1 + in_deposition_step = (ncalls > 1) & (ncalls % stride == 1) + partial_state = deposit_gaussian(xi, state, in_deposition_step) + + # Evaluate gradient of biasing potential (or generalized force) + force = evaluate_bias_grad(partial_state) + + # Calculate biasing forces + bias = -Jxi.T @ force.flatten() + eforce, perp = external_force(data) + bias = bias.reshape(state.bias.shape) - eforce.reshape(state.bias.shape) + p = data.momenta + Wp = tsolve(Jxi, p) + # Second order backward finite difference + dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt + # + I_xi = get_grid_index(xi) + hist = state.hist.at[I_xi].add(1) + Fsum = state.Fsum.at[I_xi].add(dWp_dt + state.force) + restr = tsolve(Jxi, eforce.reshape(p.shape)) + # + Frestr = state.Frestr.at[I_xi].add(restr) + + return FMetaDABFState( + xi, + bias, + *partial_state[1:-1], + I_xi, + Wp, + state.Wp, + force, + Fsum, + Frestr, + hist, + ncalls, + perp + ) + + return snapshot, initialize, generalize(update, helpers, jit_compile=True) + + +def build_gaussian_accumulator(method: Funnel_MetaDABF): + """ + Returns a function that given a `FMetaDABFState`, and the value of the CV, + stores the next Gaussian that is added to the biasing potential. + """ + periods = get_periods(method.cvs) + height_0 = method.height + deltaT = method.deltaT + grid = method.grid + kB = method.kB + + if deltaT is None: + next_height = jit(lambda *args: height_0) + else: # if well-tempered + if grid is None: + evaluate_potential = jit(lambda pstate: sum_of_gaussians(*pstate[:4], periods)) + else: + evaluate_potential = jit(lambda pstate: pstate.grid_potential[pstate.grid_idx]) + + def next_height(pstate): + V = evaluate_potential(pstate) + return height_0 * np.exp(-V / (deltaT * kB)) + + if grid is None: + get_grid_index = jit(lambda arg: None) + update_grids = jit(lambda *args: (None, None)) + should_deposit = jit(lambda pred, _: pred) + else: + grid_mesh = compute_mesh(grid) + get_grid_index = build_indexer(grid) + # Reshape so the dimensions are compatible + accum = jit(lambda total, val: total + val.reshape(total.shape)) + + if deltaT is None: + transform = grad + pack = jit(lambda x: (x,)) + # No need to accumulate values for the potential (V is None) + update = jit(lambda V, dV, vals: (V, accum(dV, vals))) + else: + transform = value_and_grad + pack = identity + update = jit(lambda V, dV, vals, grads: (accum(V, vals), accum(dV, grads))) + + def update_grids(pstate, height, xi, sigma): + # We use `sum_of_gaussians` since it already takes care of the wrapping + current_gaussian = jit(lambda x: sum_of_gaussians(x, height, xi, sigma, periods)) + # Evaluate gradient of bias (and bias potential for WT version) + grid_values = pack(vmap(transform(current_gaussian))(grid_mesh)) + return update(pstate.grid_potential, pstate.grid_gradient, *grid_values) + + def should_deposit(in_deposition_step, I_xi): + in_bounds = ~(np.any(np.array(I_xi) == grid.shape)) + return in_deposition_step & in_bounds + + def deposit_gaussian(pstate): + xi, idx = pstate.xi, pstate.idx + current_height = next_height(pstate) + heights = pstate.heights.at[idx].set(current_height) + centers = pstate.centers.at[idx].set(xi.flatten()) + sigmas = pstate.sigmas + grid_potential, grid_gradient = update_grids(pstate, current_height, xi, sigmas) + return PartialMetadynamicsState( + xi, heights, centers, sigmas, grid_potential, grid_gradient, idx + 1, pstate.grid_idx + ) + + def _deposit_gaussian(xi, state, in_deposition_step): + I_xi = get_grid_index(xi) + pstate = PartialMetadynamicsState(xi, *state[2:-9], I_xi) + predicate = should_deposit(in_deposition_step, I_xi) + return cond(predicate, deposit_gaussian, identity, pstate) + + return _deposit_gaussian + + +def build_bias_grad_evaluator(method: Funnel_MetaDABF): + """ + Returns a function that given the deposited Gaussians parameters, computes the + gradient of the biasing potential with respect to the CVs. + """ + grid = method.grid + restraints = method.restraints + if grid is None: + periods = get_periods(method.cvs) + evaluate_bias_grad = jit(lambda pstate: grad(sum_of_gaussians)(*pstate[:4], periods)) + else: + if restraints: + + def ob_force(pstate): # out-of-bounds force + lo, hi, kl, kh = restraints + xi, *_ = pstate + xi = pstate.xi.reshape(grid.shape.size) + force = apply_restraints(lo, hi, kl, kh, xi) + return force + + else: + + def ob_force(pstate): # out-of-bounds force + return np.zeros(grid.shape.size) + + def get_force(pstate): + return pstate.grid_gradient[pstate.grid_idx] + + def evaluate_bias_grad(pstate): + ob = np.any(np.array(pstate.grid_idx) == grid.shape) # out of bounds + return cond(ob, ob_force, get_force, pstate) + + return evaluate_bias_grad + + +# Helper function to evaluate bias potential -- may be moved to analysis part +def sum_of_gaussians(xi, heights, centers, sigmas, periods): + """ + Sum of n-dimensional Gaussians potential. + """ + delta_x = wrap(xi - centers, periods) + return gaussian(heights, sigmas, delta_x).sum() + + +@dispatch +def analyze(result: Result[Funnel_MetaDABF], **kwargs): + """ + Computes the free energy from the result of an `Funnel_MetaDABF` run. + Integrates the forces via a gradient learning strategy. + + Parameters + ---------- + + result: Result[Funnel_ABF]: + Result bundle containing method, final ABF state, and callback. + + topology: Optional[Tuple[int]] = (8, 8) + Defines the architecture of the neural network + (number of nodes in each hidden layer). + + Returns + ------- + + dict: + A dictionary with the following keys: + + histogram: JaxArray + Histogram for the states visited during the method. + + mean_force: JaxArray + Average force at each bin of the CV grid. + + free_energy: JaxArray + Free Energy at each bin of the CV grid. + + mesh: JaxArray + Grid used in the method. + + fes_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy in the + CV domain defined by the grid. + + corrected_force: JaxArray + Average mean force without restraint at each bin of the CV grid. + + corrected_energy: JaxArray + Free Energy without restraint at each bin of the CV grid. + + corr_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy without restraints + in the CV domain defined by the grid. + + restraint_force: JaxArray + Average mean force of the restraints at each bin of the CV grid. + + restraint_energy: JaxArray + Free Energy of the restraints at each bin of the CV grid. + + restr_fn: Callable[[JaxArray], JaxArray] + Function that allows to interpolate the free energy of the restraints + in the CV domain defined by the grid. + + + + NOTE: + For multiple-replicas runs we return a list (one item per-replica) + for each attribute. + """ + topology = kwargs.get("topology", (8, 8)) + _result = _funnelanalyze(result, GradientLearning(), topology) + return numpyfy_vals(_result) From 9ff80e3091941c1f755c59a1845690fbbca6cd22 Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Wed, 26 Nov 2025 12:02:45 -0500 Subject: [PATCH 05/13] add static bias method --- pysages/methods/__init__.py | 1 + pysages/methods/funnel_static.py | 212 +++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 pysages/methods/funnel_static.py diff --git a/pysages/methods/__init__.py b/pysages/methods/__init__.py index 19859394..caf26814 100644 --- a/pysages/methods/__init__.py +++ b/pysages/methods/__init__.py @@ -75,6 +75,7 @@ from .spline_string import SplineString from .umbrella_integration import UmbrellaIntegration from .unbiased import Unbiased +from .funnel_static import Static from .utils import ( HistogramLogger, MetaDLogger, diff --git a/pysages/methods/funnel_static.py b/pysages/methods/funnel_static.py new file mode 100644 index 00000000..307d45d7 --- /dev/null +++ b/pysages/methods/funnel_static.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Funnel Adaptive Biasing Force (Funnel_ABF) sampling method. + +ABF partitions the collective variable space into bins determined by a user +provided grid, and keeps a tabulation of the number of visits to each bin +as well as the sum of generalized forces experienced by the system at each +configuration bin. These provide an estimate for the mean generalized force, +which can be integrated to yield the free energy. This update allows to include + external restraint and accurately remove their contribution in the free energy + calculations. + +""" + +from jax import jit +from jax import numpy as np +from jax.lax import cond + +from pysages.grids import build_indexer +from pysages.methods.core import GriddedSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.typing import JaxArray, NamedTuple +import pysages + +class StaticState(NamedTuple): + """ + Funnel_ABF internal state. + + Parameters + ---------- + + xi: JaxArray (CV shape) + Last collective variable recorded in the simulation. + + bias: JaxArray (Nparticles, d) + Array with biasing forces for each particle. + + """ + + xi: JaxArray + bias: JaxArray + + def __repr__(self): + return repr("PySAGES " + type(self).__name__) + + +class Static(GriddedSamplingMethod): + """ + Funnel with Static. + + Attributes + ---------- + + snapshot_flags: + Indicate the system properties required from a snapshot. + + Parameters + ---------- + + cvs: Union[List, Tuple] + Set of user selected collective variable. + + grid: Grid + Specifies the collective variables domain and number of bins for + discretizing the CV space along each CV dimension. + + restart_file: str + .pkl file from previous pysages simulation. + + restraints: Optional[CVRestraints] = None + If provided, indicate that harmonic restraints will be applied when any + collective variable lies outside the box from `restraints.lower` to + `restraints.upper`. + + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, grid, restart_file, **kwargs): + super().__init__(cvs, grid, **kwargs) + run_result = pysages.load(restart_file) + fe_result = pysages.analyze(run_result) + self.force_array = np.asarray(-fe_result['mean_force']) + + def build(self, snapshot, helpers, *args, **kwargs): + """ + Build the functions for the execution of ABF + + Parameters + ---------- + + snapshot: + PySAGES snapshot of the simulation (backend dependent). + + helpers: + Helper function bundle as generated by + `SamplingMethod.context[0].get_backend().build_helpers`. + + Returns + ------- + + Tuple `(snapshot, initialize, update)` to run ABF simulations. + """ + self.ext_force = self.kwargs.get("ext_force", None) + return _static(self, snapshot, helpers) + + +def _static(method, snapshot, helpers): + """ + Internal function that generates the init and update functions. + + Parameters + ---------- + + method: ABF + Class that generates the functions. + snapshot: + PySAGES snapshot of the simulation (backend dependent). + helpers + Helper function bundle as generated by + `SamplingMethod.context[0].get_backend().build_helpers`. + + Returns + ------- + Tuple `(snapshot, initialize, update)` to run static bias simulations. + """ + cv = method.cv + grid = method.grid + force_array = method.force_array + dims = grid.shape.size + natoms = np.size(snapshot.positions, 0) + get_grid_index = build_indexer(grid) + estimate_force = build_force_estimator(method) + ext_force = method.ext_force + + def initialize(): + """ + Internal function that generates the first FABFState + with correctly shaped JaxArrays. + + Returns + ------- + FABFState + Initialized State + """ + xi, _ = cv(helpers.query(snapshot)) + bias = np.zeros((natoms, helpers.dimensionality())) + return StaticState(xi, bias) + + def update(state, data): + """ + Advance the state of the Funnel_ABF simulation. + + Parameters + ---------- + + state: FABFstate + Old FABFstate from the previous simutlation step. + data: JaxArray + Snapshot to access simulation data. + + Returns + ------- + FABFState + Updated internal state. + """ + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + + # Restraint force and logger + e_f, proj = ext_force(data) + # Second order backward finite difference + I_xi = get_grid_index(xi) + force = estimate_force(xi, I_xi, force_array).reshape(dims) + bias = np.reshape(-Jxi.T @ force, state.bias.shape) + np.reshape(-e_f, state.bias.shape) + # Restraint contribution to force + + return StaticState(xi, bias) + + return snapshot, initialize, generalize(update, helpers) + + +@dispatch +def build_force_estimator(method: Static): + """ + Returns a function that computes the average forces + (or the harmonic restraints forces if provided). + """ + grid = method.grid + + def bias_force(data): + _, I_xi, force_array = data + return force_array[I_xi] + + if method.restraints is None: + estimate_force = jit(lambda *args: bias_force(args)) + else: + lo, hi, kl, kh = method.restraints + + def restraints_force(data): + xi, *_ = data + xi = xi.reshape(grid.shape.size) + return apply_restraints(lo, hi, kl, kh, xi) + + def estimate_force(xi, I_xi, force_array): + ob = np.any(np.array(I_xi) == grid.shape) # Out of bounds condition + data = (xi, I_xi, force_array) + return cond(ob, restraints_force, bias_force, data) + + return estimate_force From e545f51b5bdd95e79db0c70935a159cc5e332edf Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Wed, 26 Nov 2025 15:31:05 -0500 Subject: [PATCH 06/13] update methods/__init__.py to include funnel_metadabf.py --- pysages/methods/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pysages/methods/__init__.py b/pysages/methods/__init__.py index caf26814..d5e3792b 100644 --- a/pysages/methods/__init__.py +++ b/pysages/methods/__init__.py @@ -76,6 +76,7 @@ from .umbrella_integration import UmbrellaIntegration from .unbiased import Unbiased from .funnel_static import Static +from .funnel_metadabf.py import Funnel_MetaDABF from .utils import ( HistogramLogger, MetaDLogger, From 521e668963b27601ee8804327f2a8d20960ef5f8 Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Wed, 26 Nov 2025 15:45:47 -0500 Subject: [PATCH 07/13] bug fix --- pysages/methods/funnel_static.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pysages/methods/funnel_static.py b/pysages/methods/funnel_static.py index 307d45d7..330792a7 100644 --- a/pysages/methods/funnel_static.py +++ b/pysages/methods/funnel_static.py @@ -23,6 +23,7 @@ from pysages.methods.restraints import apply_restraints from pysages.typing import JaxArray, NamedTuple import pysages +from pysages.utils import dispatch class StaticState(NamedTuple): """ From ed1aa8ee7c76903efb3523b5c8d79b3fcfb1595a Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Sun, 4 Jan 2026 03:47:36 -0500 Subject: [PATCH 08/13] include shiftable sigmoid as CV --- pysages/colvars/ML_committor.py | 27 ++++++++++++++++++++++++- pysages/colvars/train_committor_dist.py | 27 +++++++++++++++++++++++++ pysages/methods/__init__.py | 2 +- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/pysages/colvars/ML_committor.py b/pysages/colvars/ML_committor.py index 6769067e..84a8def6 100644 --- a/pysages/colvars/ML_committor.py +++ b/pysages/colvars/ML_committor.py @@ -2,7 +2,7 @@ from jax import numpy as np from functools import partial from pysages.colvars.core import CollectiveVariable -from .train_committor_dist import CommittorNN_Dist_Lip, make_forward_eval, CommittorNN_PIV +from .train_committor_dist import CommittorNN_Dist_Lip, make_forward_eval, CommittorNN_PIV, CommittorNN_PIV_shiftsig from pysages.typing import JaxArray, List, Sequence from flax import serialization @@ -60,5 +60,30 @@ def wrapped_forward(pos): return np.squeeze(y) return wrapped_forward +class Committor_CV_PIV_shiftsig(CollectiveVariable): + def __init__(self, indices: List, params_path: str, blocks: Sequence): + super().__init__(indices) + + model = CommittorNN_PIV_shiftsig(indices=np.arange(len(indices)), blocks=blocks, h1=32, h2=16, h3=8, out_dim=1, sig_k=3.0) + + rng = jax.random.PRNGKey(0) + dummy_pos = np.zeros((1, len(indices), 3)) + params = model.init(rng, dummy_pos, training=False) + with open(params_path, "rb") as f: + params = serialization.from_bytes(params, f.read()) + params = jax.tree.map( + lambda x: x.astype(np.float64) if hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating) else x, + params, + ) + self.params = params + self.forward_eval = make_forward_eval(model) + + @property + def function(self): + def wrapped_forward(pos): + y = self.forward_eval(self.params, pos[None, :, :]) + return np.squeeze(y) + return wrapped_forward + def cartesian(idx1, idx2): return np.stack(np.broadcast_arrays(idx1[:, None], idx2[None, :]), axis=-1).reshape(-1, 2) diff --git a/pysages/colvars/train_committor_dist.py b/pysages/colvars/train_committor_dist.py index b654f6ba..5c29a179 100644 --- a/pysages/colvars/train_committor_dist.py +++ b/pysages/colvars/train_committor_dist.py @@ -103,6 +103,33 @@ def __call__(self, pos: np.ndarray, training: bool = False) -> np.ndarray: x = np.squeeze(x, axis=-1) # (B,) return jax.nn.sigmoid(self.sig_k * x) if training else x +class CommittorNN_PIV_shiftsig(nn.Module): + indices: np.ndarray # (M,) + blocks: Sequence + h1: int = 32 + h2: int = 16 + h3: int = 8 + out_dim: int = 1 + sig_k: float = 4.0 + + @nn.compact + def __call__(self, pos: np.ndarray, training: bool = False) -> np.ndarray: + # pos: (B, N, 3) + sig_shift1 = 1e-3*self.param("sig_shift1", nn.initializers.constant(0.0), ()) + sig_shift2 = 1e-3*self.param("sig_shift2", nn.initializers.constant(0.0), ()) + + x = get_x(pos[:, self.indices, :], self.blocks) + x = NormalizedLinear(x.shape[-1], self.h1)(x); x = np.tanh(x) + x = NormalizedLinear(self.h1, self.h2)(x); x = np.tanh(x) + x = NormalizedLinear(self.h2, self.h3)(x); x = np.tanh(x) + x = NormalizedLinear(self.h3, self.out_dim)(x) # (B,1) + x = np.squeeze(x, axis=-1) # (B,) + q = 0.5*jax.nn.sigmoid(self.sig_k * (x-sig_shift1) ) + 0.5*jax.nn.sigmoid(self.sig_k * (x-sig_shift2) ) + if training: + return q, x + else: + return x + # ------------------------------ # Losses # ------------------------------ diff --git a/pysages/methods/__init__.py b/pysages/methods/__init__.py index d5e3792b..9bbd04fa 100644 --- a/pysages/methods/__init__.py +++ b/pysages/methods/__init__.py @@ -76,7 +76,7 @@ from .umbrella_integration import UmbrellaIntegration from .unbiased import Unbiased from .funnel_static import Static -from .funnel_metadabf.py import Funnel_MetaDABF +from .funnel_metadabf import Funnel_MetaDABF from .utils import ( HistogramLogger, MetaDLogger, From a9ea6df46f2245d1698ad07de7ba220572135c36 Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Tue, 13 Jan 2026 11:34:43 -0500 Subject: [PATCH 09/13] fix bug with ncalls, now uses hist[I_xi] to determine fitting regime --- pysages/methods/funnel_sabf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pysages/methods/funnel_sabf.py b/pysages/methods/funnel_sabf.py index b8597124..269664ae 100644 --- a/pysages/methods/funnel_sabf.py +++ b/pysages/methods/funnel_sabf.py @@ -194,23 +194,23 @@ def initialize(): return SpectralFABFState(xi, bias, hist, Fsum, force, Wp, Wp_, fun, restr, proj, Frestr, 0) def update(state, data): + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + # Restraint force and logger + e_f, proj = ext_force(data) # During the intial stage use ABF + I_xi = get_grid_index(xi) ncalls = state.ncalls + 1 - in_fitting_regime = ncalls > fit_threshold + in_fitting_regime = state.hist[I_xi] > fit_threshold in_fitting_step = in_fitting_regime & (ncalls % fit_freq == 1) & (ncalls < fit_threshold_upper) # Fit forces fun = fit_forces(state, in_fitting_step) - # Compute the collective variable and its jacobian - xi, Jxi = cv(data) - # Restraint force and logger - e_f, proj = ext_force(data) # p = data.momenta Wp = tsolve(Jxi, p) # Second order backward finite difference dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt # - I_xi = get_grid_index(xi) hist = state.hist.at[I_xi].add(1) Fsum = state.Fsum.at[I_xi].add(to_force_units(dWp_dt) + state.force) # From d8f46898d21e9ed422d9bf5791f0e7650de2495a Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Tue, 13 Jan 2026 17:17:45 -0500 Subject: [PATCH 10/13] add optional NN hyperparameters in CV definition and new unbaised funnel simulation method --- pysages/colvars/ML_committor.py | 4 +- pysages/methods/funnel_unbiased.py | 208 +++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+), 2 deletions(-) create mode 100644 pysages/methods/funnel_unbiased.py diff --git a/pysages/colvars/ML_committor.py b/pysages/colvars/ML_committor.py index 84a8def6..7e596638 100644 --- a/pysages/colvars/ML_committor.py +++ b/pysages/colvars/ML_committor.py @@ -61,10 +61,10 @@ def wrapped_forward(pos): return wrapped_forward class Committor_CV_PIV_shiftsig(CollectiveVariable): - def __init__(self, indices: List, params_path: str, blocks: Sequence): + def __init__(self, indices: List, params_path: str, blocks: Sequence, h1=32, h2=16, h3=8, sig_k=3.0): super().__init__(indices) - model = CommittorNN_PIV_shiftsig(indices=np.arange(len(indices)), blocks=blocks, h1=32, h2=16, h3=8, out_dim=1, sig_k=3.0) + model = CommittorNN_PIV_shiftsig(indices=np.arange(len(indices)), blocks=blocks, h1=h1, h2=h2, h3=h3, out_dim=1, sig_k=3.0) rng = jax.random.PRNGKey(0) dummy_pos = np.zeros((1, len(indices), 3)) diff --git a/pysages/methods/funnel_unbiased.py b/pysages/methods/funnel_unbiased.py new file mode 100644 index 00000000..717f7267 --- /dev/null +++ b/pysages/methods/funnel_unbiased.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +Funnel Adaptive Biasing Force (Funnel_ABF) sampling method. + +ABF partitions the collective variable space into bins determined by a user +provided grid, and keeps a tabulation of the number of visits to each bin +as well as the sum of generalized forces experienced by the system at each +configuration bin. These provide an estimate for the mean generalized force, +which can be integrated to yield the free energy. This update allows to include + external restraint and accurately remove their contribution in the free energy + calculations. + +""" + +from jax import jit +from jax import numpy as np +from jax.lax import cond + +from pysages.grids import build_indexer +from pysages.methods.core import GriddedSamplingMethod, Result, generalize +from pysages.methods.restraints import apply_restraints +from pysages.typing import JaxArray, NamedTuple +import pysages +from pysages.utils import dispatch + +class UnbiasState(NamedTuple): + """ + Funnel_ABF internal state. + + Parameters + ---------- + + xi: JaxArray (CV shape) + Last collective variable recorded in the simulation. + + bias: JaxArray (Nparticles, d) + Array with biasing forces for each particle. + + """ + + xi: JaxArray + bias: JaxArray + zero: JaxArray + + def __repr__(self): + return repr("PySAGES " + type(self).__name__) + + +class Unbias(GriddedSamplingMethod): + """ + Funnel with Unbias. + + Attributes + ---------- + + snapshot_flags: + Indicate the system properties required from a snapshot. + + Parameters + ---------- + + cvs: Union[List, Tuple] + Set of user selected collective variable. + + grid: Grid + Specifies the collective variables domain and number of bins for + discretizing the CV space along each CV dimension. + + restraints: Optional[CVRestraints] = None + If provided, indicate that harmonic restraints will be applied when any + collective variable lies outside the box from `restraints.lower` to + `restraints.upper`. + + """ + + snapshot_flags = {"positions", "indices", "momenta"} + + def __init__(self, cvs, grid, **kwargs): + super().__init__(cvs, grid, **kwargs) + + def build(self, snapshot, helpers, *args, **kwargs): + """ + Build the functions for the execution of ABF + + Parameters + ---------- + + snapshot: + PySAGES snapshot of the simulation (backend dependent). + + helpers: + Helper function bundle as generated by + `SamplingMethod.context[0].get_backend().build_helpers`. + + Returns + ------- + + Tuple `(snapshot, initialize, update)` to run ABF simulations. + """ + self.ext_force = self.kwargs.get("ext_force", None) + return _Unbias(self, snapshot, helpers) + + +def _Unbias(method, snapshot, helpers): + """ + Internal function that generates the init and update functions. + + Parameters + ---------- + + method: ABF + Class that generates the functions. + snapshot: + PySAGES snapshot of the simulation (backend dependent). + helpers + Helper function bundle as generated by + `SamplingMethod.context[0].get_backend().build_helpers`. + + Returns + ------- + Tuple `(snapshot, initialize, update)` to run Unbias bias simulations. + """ + cv = method.cv + grid = method.grid + dims = grid.shape.size + natoms = np.size(snapshot.positions, 0) + get_grid_index = build_indexer(grid) + estimate_force = build_force_estimator(method) + ext_force = method.ext_force + + def initialize(): + """ + Internal function that generates the first FABFState + with correctly shaped JaxArrays. + + Returns + ------- + FABFState + Initialized State + """ + xi, _ = cv(helpers.query(snapshot)) + bias = np.zeros((natoms, helpers.dimensionality())) + zero = np.zeros((natoms, helpers.dimensionality())) + return UnbiasState(xi, bias, zero) + + def update(state, data): + """ + Advance the state of the Funnel_ABF simulation. + + Parameters + ---------- + + state: FABFstate + Old FABFstate from the previous simutlation step. + data: JaxArray + Snapshot to access simulation data. + + Returns + ------- + FABFState + Updated internal state. + """ + # Compute the collective variable and its jacobian + xi, Jxi = cv(data) + + # Restraint force and logger + e_f, proj = ext_force(data) + # Second order backward finite difference + I_xi = get_grid_index(xi) + force = estimate_force(xi, I_xi, state.zero).reshape(dims) + bias = np.reshape(-Jxi.T @ force, state.bias.shape) + np.reshape(-e_f, state.bias.shape) + # Restraint contribution to force + + return UnbiasState(xi, bias, state.zero) + + return snapshot, initialize, generalize(update, helpers) + + +@dispatch +def build_force_estimator(method: Unbias): + """ + Returns a function that computes the average forces + (or the harmonic restraints forces if provided). + """ + grid = method.grid + + def bias_force(data): + _, zero = data + return zero + + if method.restraints is None: + estimate_force = jit(lambda *args: bias_force(args)) + else: + lo, hi, kl, kh = method.restraints + + def restraints_force(data): + xi, _ = data + xi = xi.reshape(grid.shape.size) + return apply_restraints(lo, hi, kl, kh, xi) + + def estimate_force(xi, I_xi, zero): + ob = np.any(np.array(I_xi) == grid.shape) # Out of bounds condition + data = (xi, zero) + return cond(ob, restraints_force, bias_force, data) + + return estimate_force From 638f7ed158c953c443d7a1c81d29d566ed84d9c3 Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Tue, 13 Jan 2026 17:52:34 -0500 Subject: [PATCH 11/13] bug fix --- pysages/methods/__init__.py | 1 + pysages/methods/funnel_unbiased.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pysages/methods/__init__.py b/pysages/methods/__init__.py index 9bbd04fa..736bc65c 100644 --- a/pysages/methods/__init__.py +++ b/pysages/methods/__init__.py @@ -77,6 +77,7 @@ from .unbiased import Unbiased from .funnel_static import Static from .funnel_metadabf import Funnel_MetaDABF +from .funnel_unbiased import Unbias from .utils import ( HistogramLogger, MetaDLogger, diff --git a/pysages/methods/funnel_unbiased.py b/pysages/methods/funnel_unbiased.py index 717f7267..4210d555 100644 --- a/pysages/methods/funnel_unbiased.py +++ b/pysages/methods/funnel_unbiased.py @@ -187,7 +187,7 @@ def build_force_estimator(method: Unbias): grid = method.grid def bias_force(data): - _, zero = data + *_, zero = data return zero if method.restraints is None: @@ -196,13 +196,13 @@ def bias_force(data): lo, hi, kl, kh = method.restraints def restraints_force(data): - xi, _ = data + xi, *_ = data xi = xi.reshape(grid.shape.size) return apply_restraints(lo, hi, kl, kh, xi) def estimate_force(xi, I_xi, zero): ob = np.any(np.array(I_xi) == grid.shape) # Out of bounds condition - data = (xi, zero) + data = (xi, I_xi, zero) return cond(ob, restraints_force, bias_force, data) return estimate_force From f2ceabb6d60a681b14a9d8a1ab3e041dfabc3f2c Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Tue, 13 Jan 2026 18:59:35 -0500 Subject: [PATCH 12/13] bug fix --- pysages/methods/funnel_unbiased.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysages/methods/funnel_unbiased.py b/pysages/methods/funnel_unbiased.py index 4210d555..3758abcc 100644 --- a/pysages/methods/funnel_unbiased.py +++ b/pysages/methods/funnel_unbiased.py @@ -142,7 +142,7 @@ def initialize(): """ xi, _ = cv(helpers.query(snapshot)) bias = np.zeros((natoms, helpers.dimensionality())) - zero = np.zeros((natoms, helpers.dimensionality())) + zero = np.zeros(dims) return UnbiasState(xi, bias, zero) def update(state, data): From 72b322a2382756497bce11392905c3d4beb196bd Mon Sep 17 00:00:00 2001 From: Andrew Qin Date: Wed, 14 Jan 2026 15:37:25 -0500 Subject: [PATCH 13/13] fix restarting for funnel unbiased --- pysages/serialization.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pysages/serialization.py b/pysages/serialization.py index c83f2fbd..2efcea04 100644 --- a/pysages/serialization.py +++ b/pysages/serialization.py @@ -20,7 +20,7 @@ import dill as pickle from pysages.backends.snapshot import Box, Snapshot -from pysages.methods import Metadynamics +from pysages.methods import Metadynamics, Unbias from pysages.methods.core import GriddedSamplingMethod, Result from pysages.typing import Callable from pysages.utils import dispatch, identity @@ -98,6 +98,10 @@ def _ncalls_estimator(_) -> Callable: # Fallback case. We leave ncalls as zero. return identity +@dispatch +def _ncalls_estimator(_: Unbias) -> Callable: + # Fallback case. We leave ncalls as zero. + return identity @dispatch def _ncalls_estimator(_: Metadynamics) -> Callable: