Why JAX for molecular simulation
In md_lj.ipynb you wrote an LJ MD code by hand. Half the complexity sat in one line:
```python
f_over_r = 48 * r6_inv * (r6_inv - 0.5) * r2_inv
```

That's the hand-derived radial force divided by the distance, $-u'(r)/r$, for the LJ pair potential $u(r) = 4(r^{-12} - r^{-6})$ in reduced units. We had to derive it. Change the potential (to WCA, Morse, Stillinger–Weber, or a learned neural-network potential) and we'd have to derive a new force formula and re-implement it. Fifty years of molecular simulation codes have lived with this friction.
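A quick way to see the point, as a minimal sketch: write only the scalar pair energy and let jax.grad supply $u'(r)$, then compare $-u'(r)/r$ with the hand-derived expression. The helper names (lj_pair, f_over_r_hand, f_over_r_autodiff) are hypothetical and exist only for this check; reduced units $\sigma = \varepsilon = 1$ are assumed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 so the comparison is tight

def lj_pair(r):
    """Hypothetical scalar LJ pair potential u(r) = 4 (r^-12 - r^-6), reduced units."""
    inv6 = 1.0 / r ** 6
    return 4.0 * (inv6 ** 2 - inv6)

def f_over_r_hand(r):
    """The hand-derived expression -u'(r)/r used in md_lj."""
    r2_inv = 1.0 / r ** 2
    r6_inv = r2_inv ** 3
    return 48 * r6_inv * (r6_inv - 0.5) * r2_inv

def f_over_r_autodiff(r):
    """The same quantity, with u'(r) supplied by jax.grad."""
    return -jax.grad(lj_pair)(r) / r

# Compare on a grid of separations up to the usual cutoff r_c = 2.5.
r_grid = jnp.linspace(0.9, 2.5, 50)
hand = f_over_r_hand(r_grid)
auto = jax.vmap(f_over_r_autodiff)(r_grid)
rel_err = float(jnp.max(jnp.abs(hand - auto)) / jnp.max(jnp.abs(hand)))
print(f"max |hand - autodiff| / max |hand| = {rel_err:.1e}")
```

The error is normalized by the largest force rather than pointwise, because the force crosses zero near $r = 2^{1/6}$ and a pointwise relative error would blow up there.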
JAX removes it. Define the potential energy as a pure function of positions and ask JAX for its gradient. The negative gradient is the force, exact up to floating-point rounding (there is no finite-difference step) and jit-compiled into fast XLA code.
This is more than coding convenience. Once everything is differentiable, so is the dependence on parameters: $\partial E/\partial\theta$ is one grad call away for any potential parameter $\theta$. Force-field fitting, sensitivity analysis, and end-to-end differentiable simulation are all built on that same call. This notebook works through the basics; production frameworks like JAX-MD extend the same pattern to neighbor lists, neural-network potentials, and full GPU pipelines.
```python
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# JAX defaults to float32; enable float64 so derivatives match numpy
# to full double precision (matters for the autodiff sanity checks below).
jax.config.update('jax_enable_x64', True)

rng = np.random.default_rng(0)
```

The energy is the only thing you write
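The full LJ pair potential, in reduced units ($\sigma = \varepsilon = 1$),

$$u(r) = 4\left(r^{-12} - r^{-6}\right), \qquad E = \sum_{i<j,\ r_{ij} < r_c} u(r_{ij}),$$

is implemented as a pure function of positions and box length. Periodic boundaries enter through the minimum-image convention $\mathbf{r}_{ij} \leftarrow \mathbf{r}_{ij} - L\,\mathrm{round}(\mathbf{r}_{ij}/L)$, and pairs beyond the cutoff $r_c = 2.5$ are masked. Forces are then just $\mathbf{F} = -\nabla E$, i.e. -grad(energy).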
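One detail of the masking deserves a note: lj_energy replaces masked entries of $r^2$ by 1.0 (safe_r2) before dividing, instead of only masking the final pair energies. The minimal sketch below, on a hypothetical 3-atom configuration without a periodic box, shows why. Masking only at the end gives the right energy, but the self-pairs still evaluate $1/0$ inside the computation graph, and the backward pass multiplies a zero cotangent by an infinite local derivative, which is NaN.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def pair_energy_naive(pos):
    """Mask only at the end: the self-pairs still compute 1/0 inside the graph."""
    diff = pos[:, None, :] - pos[None, :, :]
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = r2 > 1e-6                      # exclude self-pairs (r2 == 0 on the diagonal)
    inv6 = (1.0 / r2) ** 3                # inf on the diagonal
    pair_u = 4.0 * (inv6 ** 2 - inv6)     # nan on the diagonal
    return 0.5 * jnp.sum(jnp.where(mask, pair_u, 0.0))

def pair_energy_safe(pos):
    """Replace masked r2 by 1.0 before dividing (the safe_r2 pattern)."""
    diff = pos[:, None, :] - pos[None, :, :]
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = r2 > 1e-6
    safe_r2 = jnp.where(mask, r2, 1.0)    # finite everywhere, so the backward pass stays finite
    inv6 = (1.0 / safe_r2) ** 3
    pair_u = 4.0 * (inv6 ** 2 - inv6)
    return 0.5 * jnp.sum(jnp.where(mask, pair_u, 0.0))

# Hypothetical 3-atom configuration, no periodic box
toy_pos = jnp.array([[0.0, 0.0, 0.0],
                     [1.1, 0.0, 0.0],
                     [0.0, 1.2, 0.0]])

grad_naive = jax.grad(pair_energy_naive)(toy_pos)
grad_safe = jax.grad(pair_energy_safe)(toy_pos)
print("energies:", pair_energy_naive(toy_pos), pair_energy_safe(toy_pos))
print("naive gradient has NaN:", bool(jnp.any(jnp.isnan(grad_naive))))
print("safe gradient finite:  ", bool(jnp.all(jnp.isfinite(grad_safe))))
```

Both functions return the same energy; only the gradient differs, which is why this bug is easy to miss until forces turn into NaN.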
```python
def lj_energy(pos, L, rcut=2.5):
    """Total LJ energy (reduced units) of an (N, 3) configuration in a periodic cubic box of side L."""
    diff = pos[:, None, :] - pos[None, :, :]
    diff = diff - L * jnp.round(diff / L)               # minimum image
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = (r2 < rcut * rcut) & (r2 > 1e-6)             # inside cutoff, excluding self-pairs
    safe_r2 = jnp.where(mask, r2, 1.0)                  # avoid 1/0 so gradients stay finite
    inv2 = 1.0 / safe_r2
    inv6 = inv2 ** 3
    pair_u = 4.0 * (inv6 ** 2 - inv6)
    return 0.5 * jnp.sum(jnp.where(mask, pair_u, 0.0))  # 0.5: each pair appears twice in the N x N sum

# Forces fall out of autodiff. No hand derivation.
lj_force = jit(lambda pos, L: -grad(lj_energy)(pos, L))

# Velocity Verlet (unit masses), jit-compiled
@jit
def vv_step(pos, vel, F, dt, L):
    vel_half = vel + 0.5 * dt * F
    pos_new = (pos + dt * vel_half) % L                 # wrap positions back into the box
    F_new = lj_force(pos_new, L)
    vel_new = vel_half + 0.5 * dt * F_new
    return pos_new, vel_new, F_new
```

A short NVE run from FCC initial conditions
Standard MD setup: a 4-atom-basis FCC lattice in a cubic box, Gaussian (Maxwell–Boltzmann) velocities with zero net momentum, rescaled so the kinetic energy equals $\tfrac{3}{2}NT$ for a target temperature $T$, then velocity-Verlet integration. Energy conservation is the basic test that the integrator and force routine are mutually consistent.
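To see what "energy conservation" should look like before reading the LJ plot, here is a minimal sketch on a stand-in system chosen for illustration: a 1D harmonic oscillator with unit mass and spring constant, forces from jax.grad, and the same velocity-Verlet structure as vv_step. All names (sho_energy, sho_vv_step, and so on) are hypothetical and distinct from the notebook's so nothing is overwritten.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def sho_energy(x):
    """Potential energy of a hypothetical 1D harmonic oscillator (unit mass and spring)."""
    return 0.5 * jnp.sum(x ** 2)

sho_force = jax.jit(lambda x: -jax.grad(sho_energy)(x))   # force from autodiff

@jax.jit
def sho_vv_step(x, v, f, dt):
    """One velocity-Verlet step, same structure as vv_step."""
    v_half = v + 0.5 * dt * f
    x_new = x + dt * v_half
    f_new = sho_force(x_new)
    v_new = v_half + 0.5 * dt * f_new
    return x_new, v_new, f_new

def sho_total_energy(x, v):
    return float(sho_energy(x) + 0.5 * jnp.sum(v ** 2))

x_sho, v_sho = jnp.array([1.0]), jnp.array([0.0])
f_sho = sho_force(x_sho)
dt_sho = 0.05
E0_sho = sho_total_energy(x_sho, v_sho)

sho_energies = []
for _ in range(2000):
    x_sho, v_sho, f_sho = sho_vv_step(x_sho, v_sho, f_sho, dt_sho)
    sho_energies.append(sho_total_energy(x_sho, v_sho))

# Velocity Verlet is symplectic: E fluctuates by O(dt^2) but does not drift.
max_rel_dev = max(abs(E - E0_sho) for E in sho_energies) / E0_sho
print(f"max relative energy deviation over 2000 steps: {max_rel_dev:.2e}")
```

For this oscillator the fluctuation is bounded by $\Delta t^2/4 \approx 6 \times 10^{-4}$ no matter how long you run; a steady drift instead would point to an inconsistency between the force and the energy.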
```python
def fcc_init(N_cell=3, rho=0.7, T=1.0, seed=0):
    """FCC lattice of N = 4 * N_cell**3 atoms at density rho, velocities rescaled to temperature T."""
    N = 4 * N_cell ** 3
    L = (N / rho) ** (1 / 3)
    base = np.array([[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]])
    pos = np.array([b + [i, j, k]
                    for i in range(N_cell)
                    for j in range(N_cell)
                    for k in range(N_cell)
                    for b in base]) * (L / N_cell)
    r = np.random.default_rng(seed)
    vel = r.standard_normal((N, 3))
    vel -= vel.mean(axis=0)                          # zero total momentum
    vel *= np.sqrt(3 * N * T / np.sum(vel ** 2))     # rescale so KE = (3/2) N T
    return pos, vel, L

pos_np, vel_np, L_box = fcc_init()
pos = jnp.asarray(pos_np)
vel = jnp.asarray(vel_np)
F = lj_force(pos, L_box)  # initial forces (the first call also triggers JIT compilation)

n_steps, save_every, dt = 800, 20, 0.005
energies = []
for step in range(n_steps):
    pos, vel, F = vv_step(pos, vel, F, dt, L_box)
    if step % save_every == 0:
        E = float(lj_energy(pos, L_box) + 0.5 * jnp.sum(vel ** 2))  # potential + kinetic
        energies.append(E)

t_axis = np.arange(len(energies)) * save_every * dt
fig, ax = plt.subplots(figsize=(7, 3.5))
ax.plot(t_axis, energies, '-o', color='seagreen', ms=4)
ax.set(xlabel='time (LJ units)', ylabel='total energy',
       title='JAX MD (NVE): forces from autodiff, jit-compiled')
ax.grid(True, ls=':', alpha=0.5)
fig.tight_layout()
```

Three concrete demos: forces, pressure, parameter sensitivity
The whole MD engine fits in about 15 lines, but does it work? Three checks make the autodiff payoff visible: forces, pressure, and parameter sensitivities all fall out of the same one-line mechanism.

The perfect FCC lattice we started from has zero force on every atom by symmetry, which would make the comparison trivial, so we use the JAX-evolved positions from the run above. We also need a small numpy reference implementation of the hand-derived LJ force routine for the comparison.
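The comment in Demo 2 compresses a short derivation. Written out for the truncated pair sum used here (reduced units with $k_B = 1$; the kinetic term $\rho T$ in the last line is the standard ideal-gas contribution, which Demo 2 does not compute), scaling every separation by $\lambda$ gives:

```latex
\begin{aligned}
E(\lambda) &= \sum_{i<j} u(\lambda r_{ij}), \\
\left.\frac{dE}{d\lambda}\right|_{\lambda=1} &= \sum_{i<j} u'(r_{ij})\, r_{ij}
  = -\sum_{i<j} \mathbf{F}_{ij}\cdot\mathbf{r}_{ij} \equiv -W,
  \qquad \mathbf{F}_{ij} = -u'(r_{ij})\,\frac{\mathbf{r}_{ij}}{r_{ij}}, \\
p_{\mathrm{config}} &= \frac{W}{3V} = -\frac{1}{3V}\left.\frac{dE}{d\lambda}\right|_{\lambda=1},
  \qquad p = \rho T + p_{\mathrm{config}}.
\end{aligned}
```

The numpy routine computes $W$ directly as the sum of $\mathbf{F}_{ij}\cdot\mathbf{r}_{ij}$, while the JAX path only ever sees $E(\lambda)$.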
```python
# Hand-derived LJ forces in numpy (the routine you wrote in md_lj)
def lj_forces_numpy(pos, L, rcut=2.5):
    N = len(pos)
    I, J = np.triu_indices(N, k=1)            # each pair i < j once
    r_vec = pos[I] - pos[J]
    r_vec -= L * np.round(r_vec / L)          # minimum image
    r_sq = np.sum(r_vec ** 2, axis=1)
    mask = r_sq < rcut ** 2
    r_vec, r_sq = r_vec[mask], r_sq[mask]
    r2_inv = 1.0 / r_sq
    r6_inv = r2_inv ** 3
    f_over_r = 48 * r6_inv * (r6_inv - 0.5) * r2_inv
    F_ij = f_over_r[:, None] * r_vec          # force on i due to j
    forces = np.zeros_like(pos)
    np.add.at(forces, I[mask], F_ij)
    np.add.at(forces, J[mask], -F_ij)         # Newton's third law
    return forces, np.sum(F_ij * r_vec)       # forces and virial sum W

# === Demo 1: forces from grad(energy) match the hand-derived numpy code ===
F_jax = -np.asarray(grad(lj_energy)(pos, L_box))
F_np, virial_np = lj_forces_numpy(np.asarray(pos), L_box)
print('=== Demo 1: forces ===')
print(f'  max |F_jax - F_numpy| = {np.max(np.abs(F_jax - F_np)):.3e}')
print(f'  typical |F|           = {np.abs(F_np).mean():.3e}')
print('  -> autodiff agrees with hand-derived forces to machine precision.')

# === Demo 2: configurational pressure as a *different* gradient ===
# Scaling pos and L by lambda multiplies all separations by lambda.
# dE/dlambda at lambda=1 = -W (virial sum), so p_config = -dE/dlambda / (3V).
def E_at_scale(scale, pos, L):
    return lj_energy(scale * pos, scale * L)

dE_dscale = float(jax.grad(E_at_scale)(1.0, pos, L_box))
V = L_box ** 3
p_config_jax = -dE_dscale / (3 * V)
p_config_np = virial_np / (3 * V)
print('\n=== Demo 2: configurational pressure ===')
print(f'  p_config (autodiff): {p_config_jax:.6f}')
print(f'  p_config (virial):   {p_config_np:.6f}')
print('  -> same number, two completely different code paths.')
```

Differentiating through parameters of the potential
Demos 1 and 2 differentiated with respect to positions and to a geometric scale factor. Nothing in JAX cares whether the input is a coordinate or a parameter: make $\sigma$ a function argument and $\partial E/\partial\sigma$ is one grad call away.

This is the workhorse of force-field fitting and sensitivity analysis. The plot below sweeps $\sigma$ at a fixed configuration, tracing the nonlinear curve $E(\sigma)$, and overlays the tangent at $\sigma = 1$ whose slope comes from autodiff.
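Before trusting the slope on a 108-atom system, it helps to check the parameter derivative where it can be done by hand. A minimal sketch on a single LJ pair (the helpers dimer_energy and dE_dsigma_hand are hypothetical): jax.grad with argnums=1 differentiates with respect to $\sigma$, and the analytic result is $\partial u/\partial\sigma = 4\varepsilon\,(12\sigma^{11}/r^{12} - 6\sigma^{5}/r^{6})$. Because $u$ depends on $\sigma$ and $r$ only through $\sigma/r$, it also satisfies $\sigma\,\partial u/\partial\sigma = -r\,\partial u/\partial r$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def dimer_energy(r, sigma, eps=1.0):
    """LJ energy of a single pair at separation r (hypothetical helper)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * eps * (sr6 ** 2 - sr6)

def dE_dsigma_hand(r, sigma, eps=1.0):
    """Analytic d/dsigma of 4 eps [(sigma/r)^12 - (sigma/r)^6]."""
    return 4.0 * eps * (12 * sigma ** 11 / r ** 12 - 6 * sigma ** 5 / r ** 6)

r_d, sigma_d = 1.3, 1.0
dEds_auto = jax.grad(dimer_energy, argnums=1)(r_d, sigma_d)   # derivative w.r.t. sigma
dEds_hand = dE_dsigma_hand(r_d, sigma_d)
dEdr_auto = jax.grad(dimer_energy, argnums=0)(r_d, sigma_d)   # derivative w.r.t. r

# u depends on sigma and r only through sigma / r, so sigma dE/dsigma = -r dE/dr.
print(float(dEds_auto), dEds_hand, float(sigma_d * dEds_auto + r_d * dEdr_auto))
```

Summed over all pairs, the same identity means that at $\sigma = 1$ the slope $\partial E/\partial\sigma$ from the cell below should equal the virial sum $W$ from Demo 2, up to rounding.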
```python
def lj_energy_sigma(pos, L, sigma, eps=1.0, rcut=2.5):
    """lj_energy with sigma and eps exposed as arguments (rcut stays in absolute units)."""
    diff = pos[:, None, :] - pos[None, :, :]
    diff = diff - L * jnp.round(diff / L)
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = (r2 < rcut * rcut) & (r2 > 1e-6)
    safe_r2 = jnp.where(mask, r2, 1.0)
    sr2 = sigma * sigma / safe_r2
    sr6 = sr2 ** 3
    pair_u = 4.0 * eps * (sr6 ** 2 - sr6)
    return 0.5 * jnp.sum(jnp.where(mask, pair_u, 0.0))

sigma_grid = np.linspace(0.85, 1.15, 40)
E_grid = np.array([float(lj_energy_sigma(pos, L_box, s)) for s in sigma_grid])

sigma0 = 1.0
E0 = float(lj_energy_sigma(pos, L_box, sigma0))
dE_dsigma = float(jax.grad(lj_energy_sigma, argnums=2)(pos, L_box, sigma0))  # d/d(sigma)
print(f'E(sigma=1)     = {E0:.4f}')
print(f'dE/dsigma at 1 = {dE_dsigma:.4f}   (one grad call, no finite differences)')

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(sigma_grid, E_grid, color='steelblue', lw=2,
        label=r'$E(\sigma)$ at fixed positions')
ax.scatter([sigma0], [E0], color='crimson', s=50, zorder=5)
tan_s = np.array([sigma0 - 0.06, sigma0 + 0.06])
ax.plot(tan_s, E0 + dE_dsigma * (tan_s - sigma0), color='crimson', lw=2.2,
        label=fr'tangent from autodiff, slope $= {dE_dsigma:.2f}$')
ax.set(xlabel=r'$\sigma$', ylabel='total LJ energy',
       title=r'$E(\sigma)$: autodiff gives the local slope for free')
ax.legend()
ax.grid(True, ls=':', alpha=0.5)
fig.tight_layout()
```

Many simulations at once: vmap
JAX's vmap maps a function over a new batch axis (by default the leading axis of each input). Wrap vv_step with vmap and you can integrate many independent trajectories in parallel (different initial velocities, temperatures, or random seeds) without writing a Python loop over them. On a GPU this can give large speedups; even on a CPU, the whole batch is processed inside one compiled XLA call instead of one Python-level call per trajectory.
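The in_axes argument is what decides which inputs are batched and which are shared. A minimal sketch with hypothetical helpers that act on one system at a time (batch and system sizes are arbitrary): in_axes defaults to 0 for every argument, while None broadcasts an argument to all batch members, which is how the batched integrator below shares dt and L.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def kinetic_energy(vel):
    """Kinetic energy of one system; vel has shape (N, 3), unit masses."""
    return 0.5 * jnp.sum(vel ** 2)

def scale_to_temperature(vel, T):
    """Rescale one system's velocities so that KE = (3/2) N T."""
    n = vel.shape[0]
    return vel * jnp.sqrt(3 * n * T / jnp.sum(vel ** 2))

n_demo, n_atoms_demo = 5, 32                       # arbitrary sizes for the sketch
demo_key = jax.random.PRNGKey(0)
vels = jax.random.normal(demo_key, (n_demo, n_atoms_demo, 3))
temps = jnp.linspace(0.5, 2.5, n_demo)

# in_axes defaults to 0 for every argument: one (vel, T) pair per batch member.
scaled = jax.vmap(scale_to_temperature)(vels, temps)

# in_axes=(0, None): batch over velocities, share a single T across the batch.
scaled_shared = jax.vmap(scale_to_temperature, in_axes=(0, None))(vels, 1.0)

ke_per_T = jax.vmap(kinetic_energy)(scaled)        # shape (n_demo,)
ke_shared = jax.vmap(kinetic_energy)(scaled_shared)
print(ke_per_T / (1.5 * n_atoms_demo))             # recovers temps
```

The per-system functions never mention the batch dimension; vmap adds it from the outside.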
```python
# Build a batch of 8 trajectories with different initial velocities
n_batch = 8
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, n_batch)

def init_batch_member(seed_key):
    pos0 = jnp.asarray(pos_np)                                   # same FCC start
    v = jax.random.normal(seed_key, (pos0.shape[0], 3))          # random velocities
    v = v - v.mean(axis=0)                                       # zero total momentum
    v = v * jnp.sqrt(3 * pos0.shape[0] * 1.0 / jnp.sum(v ** 2))  # rescale to T = 1
    F0 = lj_force(pos0, L_box)
    return pos0, v, F0

batch_state = vmap(init_batch_member)(keys)

# vmap'd integrator step: maps over the leading (batch) axis of pos / vel / F;
# dt and L (in_axes=None) are shared by every trajectory
batched_step = jit(vmap(vv_step, in_axes=(0, 0, 0, None, None)))

n_batch_steps = 600
for step in range(n_batch_steps):
    batch_state = batched_step(*batch_state, dt, L_box)

batch_pos, batch_vel, _ = batch_state
batched_energy = jit(vmap(lj_energy, in_axes=(0, None)))
final_PE = batched_energy(batch_pos, L_box)
final_KE = 0.5 * jnp.sum(batch_vel ** 2, axis=(1, 2))
final_E_total = np.asarray(final_PE + final_KE)

# Every member starts from the same lattice (same PE) with KE rescaled to T = 1,
# so all trajectories share one initial total energy: equal bars mean NVE holds.
fig, ax = plt.subplots(figsize=(7, 3.5))
ax.bar(np.arange(n_batch), final_E_total, color='cornflowerblue', edgecolor='k')
ax.set(xlabel='trajectory index', ylabel='total energy after run',
       title=f'{n_batch} independent NVE trajectories run in parallel via vmap')
ax.grid(True, ls=':', alpha=0.4)
fig.tight_layout()
print('Final energies (independent runs, each NVE-conserved):',
      np.round(final_E_total, 3))
```

A toy fitting problem: tuning a parameter by gradient descent
The endgame of differentiable simulation: take the gradient of a loss (against experimental data, against a target ensemble average, …) with respect to potential parameters, and use it to fit. To keep the demo fast we use the simplest target imaginable: tune $\varepsilon$ so that the LJ energy at this configuration equals a chosen value $E_{\mathrm{target}}$. Because $E$ is linear in $\varepsilon$, this is a one-parameter quadratic problem with a closed-form answer, $\varepsilon^\star = E_{\mathrm{target}}/E(\varepsilon = 1)$, but the machinery (value_and_grad plus a gradient-descent loop) is exactly what you'd use for a real fit.
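Because the loss is an exact quadratic, the learning rate has a sharp stability limit. A minimal sketch, assuming a hypothetical value $E(\varepsilon = 1) = -500$ in place of the real configuration: the loss $(\varepsilon E_1 - E_{\mathrm{target}})^2$ has curvature $2E_1^2$, and each gradient-descent step multiplies the error by $1 - \eta \cdot 2E_1^2$, so plain gradient descent converges only for $\eta < 1/E_1^2$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

E1_toy = -500.0          # hypothetical E(eps = 1) at a frozen configuration
E_target_toy = -700.0

def toy_energy(eps):
    """At fixed positions and sigma the LJ energy is proportional to eps."""
    return eps * E1_toy

def toy_loss(eps):
    return (toy_energy(eps) - E_target_toy) ** 2

toy_loss_and_grad = jax.jit(jax.value_and_grad(toy_loss))

# The loss is quadratic in eps with curvature 2 * E1^2. Gradient descent multiplies the
# error by (1 - lr * curvature) per step, so it converges only for lr < 1 / E1^2.
curvature = 2.0 * E1_toy ** 2
toy_lr = 0.5 / curvature      # halves the error each step

eps_toy = 0.5
for _ in range(40):
    loss_val_toy, g_toy = toy_loss_and_grad(eps_toy)
    eps_toy = eps_toy - toy_lr * g_toy

eps_star_toy = E_target_toy / E1_toy   # closed-form answer, 1.4 here
print(float(eps_toy), eps_star_toy)
```

With a fixed learning rate of 1e-6, the same condition says the descent is stable only if $|E(\varepsilon = 1)| < 1000$ at the chosen configuration; for larger systems the learning rate must shrink accordingly.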
```python
def lj_energy_eps(pos, L, eps, sigma=1.0, rcut=2.5):
    """LJ energy with eps exposed as an argument (E is linear in eps)."""
    diff = pos[:, None, :] - pos[None, :, :]
    diff = diff - L * jnp.round(diff / L)
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = (r2 < rcut * rcut) & (r2 > 1e-6)
    safe_r2 = jnp.where(mask, r2, 1.0)
    sr2 = sigma * sigma / safe_r2
    sr6 = sr2 ** 3
    pair_u = 4.0 * eps * (sr6 ** 2 - sr6)
    return 0.5 * jnp.sum(jnp.where(mask, pair_u, 0.0))

E_target = -700.0

def loss(eps, pos, L):
    return (lj_energy_eps(pos, L, eps) - E_target) ** 2

loss_and_grad = jit(jax.value_and_grad(loss))   # differentiates w.r.t. the first argument, eps

eps = 0.5    # starting guess
lr = 1e-6    # learning rate
history = []
for step in range(40):
    loss_val, g = loss_and_grad(eps, pos, L_box)
    history.append((float(eps), float(loss_val)))            # loss at the current eps
    eps = eps - lr * g
history.append((float(eps), float(loss(eps, pos, L_box))))   # final point
eps_traj, loss_traj = zip(*history)
final_E = float(lj_energy_eps(pos, L_box, eps))

fig, axes = plt.subplots(1, 2, figsize=(11, 3.8))
axes[0].plot(eps_traj, '-o', color='steelblue', ms=4)
axes[0].axhline(eps_traj[-1], color='r', ls='--', lw=0.8,
                label=fr'$\varepsilon^\star = {eps_traj[-1]:.4f}$')
axes[0].set(xlabel='gradient-descent step', ylabel=r'$\varepsilon$',
            title='Parameter trajectory')
axes[0].legend(); axes[0].grid(True, ls=':', alpha=0.5)
axes[1].semilogy(loss_traj, '-o', color='crimson', ms=4)
axes[1].set(xlabel='gradient-descent step', ylabel='loss',
            title='Loss (log scale)')
axes[1].grid(True, ls=':', alpha=0.5, which='both')
fig.tight_layout()
print(f'\nFinal eps = {eps_traj[-1]:.4f}, E = {final_E:.4f} (target {E_target})')
```

Takeaways
- Stop writing forces; write energies. -grad(energy) gives forces to machine precision. The same one-line mechanism gives pressure (gradient with respect to box scaling) and parameter sensitivities (gradient with respect to $\sigma$).
- vmap parallelizes for free. Wrap vv_step with vmap and integrate many trajectories simultaneously without writing a batch loop.
- Differentiable simulation enables fitting. With value_and_grad, a gradient-descent loop over potential parameters takes a few lines. This is the foundation of modern force-field optimization, including neural-network potentials trained against ab-initio reference data.
- jit makes it fast. The first call traces and compiles; later calls with the same input shapes and dtypes reuse the optimized XLA code. The same code runs unchanged on a GPU, though the all-pairs energy here scales as $O(N^2)$, so much larger systems need neighbor lists as in JAX-MD.
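The "first call traces and compiles" behavior can be observed directly. Python side effects inside a jitted function run only while JAX traces it, so a plain counter (a hypothetical helper for illustration, not a JAX API) shows when a retrace happens: a call with the same shape and dtype reuses the compiled code, while a new shape triggers a fresh trace.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def sum_of_squares(x):
    global trace_count
    trace_count += 1          # Python side effect: runs only while JAX traces the function
    return jnp.sum(x ** 2)

sum_of_squares(jnp.ones(3))    # first call: trace and compile
sum_of_squares(jnp.zeros(3))   # same shape and dtype: reuses the compiled code
after_same_shape = trace_count
sum_of_squares(jnp.ones(4))    # new shape: traces and compiles again
after_new_shape = trace_count
print(after_same_shape, after_new_shape)
```

This is also why the MD loops above keep array shapes fixed from step to step: every step after the first hits the compiled cache.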
Problems

1. Replace LJ with WCA. WCA is the purely repulsive part of LJ: $u_{\mathrm{WCA}}(r) = 4\varepsilon\left[(\sigma/r)^{12} - (\sigma/r)^{6}\right] + \varepsilon$ for $r < r_{\min} = 2^{1/6}\sigma$, and 0 otherwise. Modify lj_energy accordingly. Verify that forces from grad are continuous at the cutoff (unlike LJ truncated at $r_c = 2.5\sigma$, which has a force jump there).
2. Where finite differences fail. Compute $\partial E/\partial\sigma$ at $\sigma = 1$ via both jax.grad and central finite differences. Make a log-log plot of the absolute difference between the two vs. the step size $h$. You should see the classic V-shape: truncation error dominates at large $h$, round-off error at small $h$. Autodiff has no step size, so it has neither.
3. Caloric curve from a batched temperature scan. Build a batch where each trajectory is initialized at a different temperature (rescale velocities accordingly inside init_batch_member). Run them in parallel with vmap, time-average the potential energy after equilibration, and plot $\langle U \rangle$ vs. $T$ (measured from the time-averaged kinetic energy): the LJ caloric curve.
References

- S. S. Schoenholz and E. D. Cubuk, JAX, M.D.: A framework for differentiable physics, NeurIPS 2020. github.com/jax-md/jax-md
- A. G. Baydin, B. A. Pearlmutter, A. A. Radul, and J. M. Siskind, Automatic differentiation in machine learning: a survey, J. Mach. Learn. Res. 18, 1 (2018). arXiv:1502.05767