
Differentiable MD with JAX

Why JAX for molecular simulation

In md_lj.ipynb you wrote an LJ MD code by hand. Half the complexity sat in one line:

f_over_r = 48 * r6_inv * (r6_inv - 0.5) * r2_inv

That’s the analytical derivative of the LJ pair potential, written as $-u'(r)/r$ so it can multiply the separation vector directly. We had to derive it. Change the potential (to WCA, Morse, Stillinger–Weber, or a learned neural-network potential) and we’d have to derive a new force formula and re-implement it. Fifty years of molecular simulation codes have lived with this friction.
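For reference, the line above comes from differentiating the reduced-units potential $u(r) = 4\,(r^{-12} - r^{-6})$. The force on atom $i$ from atom $j$ is $\mathbf{F}_{ij} = -u'(r)\,\mathbf{r}_{ij}/r$, and the scalar prefactor works out to

$$
-\frac{1}{r}\frac{du}{dr} \;=\; \frac{48}{r^{14}} - \frac{24}{r^{8}} \;=\; 48\, r^{-6}\left(r^{-6} - \tfrac{1}{2}\right) r^{-2},
$$

which is `48 * r6_inv * (r6_inv - 0.5) * r2_inv` with r6_inv $= r^{-6}$ and r2_inv $= r^{-2}$.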

JAX removes it. Define the potential energy as a pure function of positions and ask JAX for its gradient. The negative gradient is the force. It is exact up to floating-point rounding (there is no finite-difference step) and is jit-compiled into fast XLA code.

This is more than coding convenience. Once everything is differentiable, parameters are too: $\partial E/\partial \varepsilon$, or $\partial \langle A\rangle / \partial \theta$ for any potential parameter $\theta$. Force-field fitting, sensitivity analysis, and end-to-end differentiable simulation reduce to a few extra lines of code. This notebook works through the basics; production frameworks like JAX-MD extend the same pattern to neighbor lists, neural-network potentials, and full GPU pipelines.
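Here is a minimal sketch of the pattern applied to a different potential. It uses a Morse pair energy whose parameter values (D, a, r0) are arbitrary illustrations, not taken from this notebook. The force comes from `jax.grad` and is checked against the hand-derived expression. The same call with `argnums=1` gives a derivative with respect to a parameter instead of the distance.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def morse(r, D=1.0, a=1.5, r0=1.2):
    """Morse pair energy; D, a, r0 are illustrative values only."""
    return D * (1.0 - jnp.exp(-a * (r - r0))) ** 2


def morse_force_analytic(r, D=1.0, a=1.5, r0=1.2):
    """Hand-derived F(r) = -du/dr, for comparison."""
    e = jnp.exp(-a * (r - r0))
    return -2.0 * D * a * e * (1.0 - e)


# Force = -du/dr from autodiff; vmap applies the scalar derivative elementwise
morse_force_ad = jax.vmap(jax.grad(lambda r: -morse(r)))

r = jnp.linspace(0.9, 3.0, 50)
F_ad = morse_force_ad(r)
F_exact = morse_force_analytic(r)
print("max |F_ad - F_exact| =", float(jnp.max(jnp.abs(F_ad - F_exact))))

# Derivative with respect to a *parameter* (the well depth D) at r = 2.0
dU_dD = jax.grad(morse, argnums=1)(2.0, 1.0)
print("dU/dD at r = 2.0:", float(dU_dD))
```

Switching to another pair potential only means rewriting `morse`; the force line does not change.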

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# JAX defaults to float32; enable float64 so derivatives match numpy
# to full double precision (matters for the autodiff sanity checks below).
jax.config.update('jax_enable_x64', True)

rng = np.random.default_rng(0)

The energy is the only thing you write

The full LJ pair potential

$$
u(r) \;=\; 4\varepsilon\!\left[\left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^6\right]
$$

is implemented (in reduced units, $\varepsilon = \sigma = 1$) as a pure function of positions and box length. Periodic boundaries enter through the minimum-image convention $r_{ij} \to r_{ij} - L\,\text{round}(r_{ij}/L)$, and self-pairs and pairs beyond the cutoff $r_c$ (default 2.5) are masked out. Forces are then just minus grad(energy).
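Here is a minimal sketch of the minimum-image step on its own. The box length and displacement are arbitrary example values. Each component is shifted by the nearest multiple of L, so it lands in $[-L/2, L/2]$. `jnp.round` rounds halves to even, so a component of exactly $\pm L/2$ may come out with either sign; both give the same distance.

```python
import jax.numpy as jnp


def minimum_image(dr, L):
    """Wrap a displacement vector into the nearest periodic image (cubic box, side L)."""
    return dr - L * jnp.round(dr / L)


L = 5.0
dr = jnp.array([4.6, -3.0, 0.4])      # raw displacement between two atoms
dr_mi = minimum_image(dr, L)          # approximately [-0.4, 2.0, 0.4]
print(dr_mi)
```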

def lj_energy(pos, L, rcut=2.5):
    diff = pos[:, None, :] - pos[None, :, :]
    diff = diff - L * jnp.round(diff / L)             # minimum image
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = (r2 < rcut * rcut) & (r2 > 1e-6)           # inside cutoff, excluding self-pairs
    safe_r2 = jnp.where(mask, r2, 1.0)                # dummy r2 for masked pairs keeps gradients NaN-free
    inv2 = 1.0 / safe_r2
    inv6 = inv2 ** 3
    pair_u = 4.0 * (inv6 ** 2 - inv6)
    return 0.5 * jnp.sum(jnp.where(mask, pair_u, 0.0))   # 0.5: pair sum is symmetric


# Forces fall out of autodiff. No hand derivation.
lj_force = jit(lambda pos, L: -grad(lj_energy)(pos, L))


# Velocity Verlet, jit-compiled
@jit
def vv_step(pos, vel, F, dt, L):
    vel_half = vel + 0.5 * dt * F
    pos_new = (pos + dt * vel_half) % L
    F_new = lj_force(pos_new, L)
    vel_new = vel_half + 0.5 * dt * F_new
    return pos_new, vel_new, F_new
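Why does lj_energy bother with safe_r2? The sketch below uses a hypothetical 3-atom configuration to show the problem. Suppose `1/r2` is evaluated on the zero self-pair distances and masked out only afterwards. The energy is still finite. But the backward pass multiplies a zero cotangent by an infinite partial derivative, so the gradient fills with NaN. Substituting a harmless dummy value *before* the division (the "double where" pattern) avoids this.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def energy_naive(pos, rcut=2.5):
    """Masks AFTER dividing: self-pairs produce inf, then get masked out."""
    diff = pos[:, None, :] - pos[None, :, :]
    r2 = jnp.sum(diff ** 2, axis=-1)                 # exactly 0 on the diagonal
    mask = (r2 < rcut * rcut) & (r2 > 1e-6)
    inv6 = (1.0 / r2) ** 3                           # inf on the diagonal
    return 0.5 * jnp.sum(jnp.where(mask, 4.0 * (inv6 ** 2 - inv6), 0.0))


def energy_safe(pos, rcut=2.5):
    """Same energy, but masked entries get a dummy r2 = 1 before dividing."""
    diff = pos[:, None, :] - pos[None, :, :]
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = (r2 < rcut * rcut) & (r2 > 1e-6)
    safe_r2 = jnp.where(mask, r2, 1.0)
    inv6 = (1.0 / safe_r2) ** 3
    return 0.5 * jnp.sum(jnp.where(mask, 4.0 * (inv6 ** 2 - inv6), 0.0))


# Hypothetical 3-atom configuration; no periodic box is needed to make the point
pos = jnp.array([[0.0, 0.0, 0.0],
                 [1.1, 0.0, 0.0],
                 [0.0, 1.2, 0.0]])

E_naive, E_safe = energy_naive(pos), energy_safe(pos)   # identical, finite
g_naive = jax.grad(energy_naive)(pos)                  # contains NaN
g_safe = jax.grad(energy_safe)(pos)                    # finite
print(E_naive, E_safe)
print(g_naive)
print(g_safe)
```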

A short NVE run from FCC initial conditions

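Standard MD setup: a 4-atom-basis FCC lattice in a cubic box (by default $3\times3\times3$ unit cells, i.e. $N = 108$ atoms at $\rho = 0.7$). Maxwell–Boltzmann velocities are drawn as Gaussians, shifted so the center-of-mass velocity is zero, and rescaled so that $\sum_i |\mathbf{v}_i|^2 = 3NT$ (unit mass, $k_B = 1$). The system is then advanced with velocity-Verlet integration. Energy conservation is the basic test that the integrator and force routine are mutually consistent.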

def fcc_init(N_cell=3, rho=0.7, T=1.0, seed=0):
    N = 4 * N_cell ** 3
    L = (N / rho) ** (1 / 3)
    base = np.array([[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]])
    pos = np.array([b + [i, j, k]
                    for i in range(N_cell)
                    for j in range(N_cell)
                    for k in range(N_cell)
                    for b in base]) * (L / N_cell)
    r = np.random.default_rng(seed)
    vel = r.standard_normal((N, 3))
    vel -= vel.mean(axis=0)
    vel *= np.sqrt(3 * N * T / np.sum(vel ** 2))      # rescale to target T
    return pos, vel, L


pos_np, vel_np, L_box = fcc_init()
pos = jnp.asarray(pos_np)
vel = jnp.asarray(vel_np)
F = lj_force(pos, L_box)                              # initial forces; first call also compiles

n_steps, save_every, dt = 800, 20, 0.005
energies = []
for step in range(n_steps):
    pos, vel, F = vv_step(pos, vel, F, dt, L_box)
    if step % save_every == 0:
        E = float(lj_energy(pos, L_box) + 0.5 * jnp.sum(vel ** 2))
        energies.append(E)

t_axis = np.arange(len(energies)) * save_every * dt
fig, ax = plt.subplots(figsize=(7, 3.5))
ax.plot(t_axis, energies, '-o', color='seagreen', ms=4)
ax.set(xlabel='time (LJ units)', ylabel='total energy',
       title='JAX MD (NVE): forces from autodiff, jit-compiled')
ax.grid(True, ls=':', alpha=0.5)
fig.tight_layout()
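Energy conservation can also be checked quantitatively. The self-contained sketch below uses a 1D harmonic oscillator ($m = k = 1$) as a toy stand-in for the LJ system. It applies the same velocity-Verlet update, with the force from `jax.grad`. For this oscillator, velocity Verlet exactly conserves the modified energy $\tilde H = \tfrac12 v^2 + \tfrac12(1 - \Delta t^2/4)\,q^2$. Starting from $q = 1$, $v = 0$, the energy error is therefore $E - E_0 = \tfrac{\Delta t^2}{8}(q^2 - 1)$. It never exceeds $\Delta t^2/8$ and should shrink by about $4\times$ when $\Delta t$ is halved, which is the $O(\Delta t^2)$ behavior of a second-order symplectic integrator.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def energy(q):
    """1D harmonic oscillator with m = k = 1 (angular frequency 1)."""
    return 0.5 * q ** 2


force = jax.grad(lambda q: -energy(q))      # F = -dU/dq from autodiff


def max_energy_error(dt, t_final=20.0, q0=1.0, v0=0.0):
    """Run velocity Verlet to t_final and return max |E(t) - E(0)|."""
    n_steps = int(round(t_final / dt))

    def step(state, _):
        q, v = state
        v_half = v + 0.5 * dt * force(q)
        q_new = q + dt * v_half
        v_new = v_half + 0.5 * dt * force(q_new)
        e_new = 0.5 * v_new ** 2 + energy(q_new)
        return (q_new, v_new), e_new

    init = (jnp.array(q0, dtype=jnp.float64), jnp.array(v0, dtype=jnp.float64))
    e0 = 0.5 * v0 ** 2 + energy(q0)
    _, e_traj = jax.lax.scan(step, init, None, length=n_steps)
    return float(jnp.max(jnp.abs(e_traj - e0)))


err_coarse = max_energy_error(0.1)
err_fine = max_energy_error(0.05)
print(f"max |dE|: dt=0.1 -> {err_coarse:.3e}, dt=0.05 -> {err_fine:.3e}, "
      f"ratio {err_coarse / err_fine:.2f}")
```

The same dt-halving test applies to the LJ run above. There the error is not bounded this cleanly, because the truncated potential has a jump at $r_c$.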

Three concrete demos: forces, pressure, parameter sensitivity

The whole MD engine fits in ~15 lines, but does it work? Three checks make the autodiff payoff visible: forces fall out, pressure falls out, and parameter sensitivities fall out — all from the same one-line mechanism.

The trivial FCC starting point has zero forces by symmetry, so we use the JAX-evolved positions from the run above. We also need a small numpy reference implementation of the LJ force routine for the comparison.

# Hand-derived LJ forces in numpy (the routine you wrote in md_lj)
def lj_forces_numpy(pos, L, rcut=2.5):
    N = len(pos)
    I, J = np.triu_indices(N, k=1)
    r_vec = pos[I] - pos[J]
    r_vec -= L * np.round(r_vec / L)
    r_sq = np.sum(r_vec ** 2, axis=1)
    mask = r_sq < rcut ** 2
    r_vec, r_sq = r_vec[mask], r_sq[mask]
    r2_inv = 1.0 / r_sq
    r6_inv = r2_inv ** 3
    f_over_r = 48 * r6_inv * (r6_inv - 0.5) * r2_inv
    F_ij = f_over_r[:, None] * r_vec
    forces = np.zeros_like(pos)
    np.add.at(forces, I[mask],  F_ij)
    np.add.at(forces, J[mask], -F_ij)
    return forces, np.sum(F_ij * r_vec)               # forces and virial sum


# === Demo 1: forces from grad(energy) match the hand-derived numpy code ===
F_jax = -np.asarray(grad(lj_energy)(pos, L_box))
F_np, virial_np = lj_forces_numpy(np.asarray(pos), L_box)

print('=== Demo 1: forces ===')
print(f'  max |F_jax - F_numpy| = {np.max(np.abs(F_jax - F_np)):.3e}')
print(f'  typical |F|           = {np.abs(F_np).mean():.3e}')
print(f'  -> autodiff agrees with hand-derived forces to machine precision.')


# === Demo 2: configurational pressure as a *different* gradient ===
# Scaling pos and L by lambda multiplies all separations by lambda.
# dE/dlambda at lambda=1 = -W (virial sum), so p_config = -dE/dlambda / (3V).
def E_at_scale(scale, pos, L):
    return lj_energy(scale * pos, scale * L)

dE_dscale = float(jax.grad(E_at_scale)(1.0, pos, L_box))
V = L_box ** 3
p_config_jax = -dE_dscale / (3 * V)
p_config_np  = virial_np / (3 * V)

print('\n=== Demo 2: configurational pressure ===')
print(f'  p_config (autodiff): {p_config_jax:.6f}')
print(f'  p_config (virial):   {p_config_np:.6f}')
print(f'  -> same number, two completely different code paths.')
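The comment in the cell above compresses a short derivation. Scaling every separation by $\lambda$ gives

$$
E(\lambda) = \sum_{i<j} u(\lambda r_{ij}), \qquad
\left.\frac{dE}{d\lambda}\right|_{\lambda=1} = \sum_{i<j} u'(r_{ij})\, r_{ij} = -\sum_{i<j} \mathbf{f}_{ij}\cdot\mathbf{r}_{ij} = -W,
$$

using $\mathbf{f}_{ij} = -u'(r_{ij})\,\mathbf{r}_{ij}/r_{ij}$. Hence $p_\text{config} = W/(3V) = -\frac{1}{3V}\,\frac{dE}{d\lambda}$. Adding the kinetic contribution $\rho T$ (reduced units, $k_B = 1$) gives the full pressure $p = \rho T + W/(3V)$.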

Differentiating through parameters of the potential

Demos 1 and 2 took gradients with respect to positions. Nothing in JAX cares whether the input is a coordinate or a parameter: make $\sigma$ a function argument and $\partial E/\partial \sigma$ is one grad call away.

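This is the workhorse of force-field fitting and sensitivity analysis. The plot below sweeps $\sigma$ at fixed configuration (a nonlinear curve) and overlays the autodiff-computed tangent at $\sigma = 1$. The cutoff stays at $r_c = 2.5$ in absolute units, so the set of interacting pairs does not change as $\sigma$ varies.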

def lj_energy_sigma(pos, L, sigma, eps=1.0, rcut=2.5):
    diff = pos[:, None, :] - pos[None, :, :]
    diff = diff - L * jnp.round(diff / L)
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = (r2 < rcut * rcut) & (r2 > 1e-6)
    safe_r2 = jnp.where(mask, r2, 1.0)
    sr2 = sigma * sigma / safe_r2
    sr6 = sr2 ** 3
    pair_u = 4.0 * eps * (sr6 ** 2 - sr6)
    return 0.5 * jnp.sum(jnp.where(mask, pair_u, 0.0))


sigma_grid = np.linspace(0.85, 1.15, 40)
E_grid = np.array([float(lj_energy_sigma(pos, L_box, s)) for s in sigma_grid])

sigma0 = 1.0
E0 = float(lj_energy_sigma(pos, L_box, sigma0))
dE_dsigma = float(jax.grad(lj_energy_sigma, argnums=2)(pos, L_box, sigma0))

print(f'E(sigma=1)      = {E0:.4f}')
print(f'dE/dsigma at 1  = {dE_dsigma:.4f}  (one grad call, no finite differences)')

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(sigma_grid, E_grid, color='steelblue', lw=2,
        label=r'$E(\sigma)$ at fixed positions')
ax.scatter([sigma0], [E0], color='crimson', s=50, zorder=5)
tan_s = np.array([sigma0 - 0.06, sigma0 + 0.06])
ax.plot(tan_s, E0 + dE_dsigma * (tan_s - sigma0), color='crimson', lw=2.2,
        label=fr'tangent from autodiff, slope $= {dE_dsigma:.2f}$')
ax.set(xlabel=r'$\sigma$', ylabel='total LJ energy',
       title=r'$E(\sigma)$: autodiff gives the local slope for free')
ax.legend()
ax.grid(True, ls=':', alpha=0.5)
fig.tight_layout()
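The $\sigma$-derivative here and the box-scaling derivative of Demo 2 are closely related. The LJ energy depends on $\sigma$ and $r$ only through $\sigma/r$, so at $\sigma = \lambda = 1$ one has $\partial E/\partial\sigma = -dE/d\lambda = W$, provided the set of pairs inside the cutoff does not change. The self-contained sketch below checks this on a hypothetical four-atom open cluster (no box, no cutoff). It also compares the result against the hand-derived $\partial E/\partial\sigma = \sum_{i<j}\big(48(\sigma/r)^{12} - 24(\sigma/r)^6\big)/\sigma$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def cluster_energy(pos, sigma=1.0, eps=1.0):
    """LJ energy of a small open cluster: every pair counted once, no box, no cutoff."""
    i, j = jnp.triu_indices(pos.shape[0], k=1)
    r2 = jnp.sum((pos[i] - pos[j]) ** 2, axis=-1)
    sr6 = (sigma ** 2 / r2) ** 3
    return jnp.sum(4.0 * eps * (sr6 ** 2 - sr6))


# Hypothetical 4-atom configuration (reduced units)
pos = jnp.array([[0.0, 0.0, 0.0],
                 [1.1, 0.0, 0.0],
                 [0.0, 1.2, 0.0],
                 [0.3, 0.4, 1.0]])

# Parameter derivative dE/dsigma at sigma = 1
dE_dsigma = jax.grad(cluster_energy, argnums=1)(pos, 1.0)

# Scaling derivative dE/dlambda at lambda = 1 (the Demo 2 construction)
dE_dlambda = jax.grad(lambda lam: cluster_energy(lam * pos))(1.0)

# Hand-derived dE/dsigma at sigma = 1: sum over pairs of 48 r^-12 - 24 r^-6
i, j = jnp.triu_indices(pos.shape[0], k=1)
r2 = jnp.sum((pos[i] - pos[j]) ** 2, axis=-1)
dE_dsigma_exact = jnp.sum(48.0 / r2 ** 6 - 24.0 / r2 ** 3)

print(float(dE_dsigma), float(-dE_dlambda), float(dE_dsigma_exact))
```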

Many simulations at once: vmap

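JAX’s vmap vectorizes a function over a batch axis (by default the leading one), so a function written for one system runs on a stack of systems. Wrap vv_step with vmap and you can integrate many independent trajectories in parallel (different initial velocities, different temperatures, different random seeds) without writing a Python loop. On a GPU this can give a large speedup; even on a CPU it amortizes JIT overhead and removes the Python interpreter from the inner loop over batch elements. In the run below, every trajectory starts from the same FCC positions with its kinetic energy rescaled to $T = 1$. All of them therefore begin with the same total energy, and their final energies should agree to within the integration error.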

# Build a batch of 8 trajectories with different initial velocities
n_batch = 8
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, n_batch)


def init_batch_member(seed_key):
    pos0 = jnp.asarray(pos_np)                                   # same FCC start
    v = jax.random.normal(seed_key, (pos0.shape[0], 3))          # random velocity
    v = v - v.mean(axis=0)
    v = v * jnp.sqrt(3 * pos0.shape[0] * 1.0 / jnp.sum(v ** 2))  # rescale to T = 1
    F0 = lj_force(pos0, L_box)
    return pos0, v, F0


batch_state = vmap(init_batch_member)(keys)

# vmap'd integrator step: maps over the leading (batch) axis of pos / vel / F
batched_step = jit(vmap(vv_step, in_axes=(0, 0, 0, None, None)))

n_batch_steps = 600
for step in range(n_batch_steps):
    batch_state = batched_step(*batch_state, dt, L_box)

batch_pos, batch_vel, _ = batch_state
batched_energy = jit(vmap(lj_energy, in_axes=(0, None)))
final_PE = batched_energy(batch_pos, L_box)
final_KE = 0.5 * jnp.sum(batch_vel ** 2, axis=(1, 2))
final_E_total = np.asarray(final_PE + final_KE)

fig, ax = plt.subplots(figsize=(7, 3.5))
ax.bar(np.arange(n_batch), final_E_total, color='cornflowerblue', edgecolor='k')
ax.set(xlabel='trajectory index', ylabel='total energy after run',
       title=f'{n_batch} independent NVE trajectories run in parallel via vmap')
ax.grid(True, ls=':', alpha=0.4)
fig.tight_layout()
print('Final energies (independent runs, each NVE-conserved):',
      np.round(final_E_total, 3))
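vmap is not limited to trajectories; `in_axes` decides which arguments are batched. For example, the list comprehension that built E_grid in the $\sigma$ sweep could be written as a single vmapped call, with the configuration shared (`None`) and $\sigma$ mapped (`0`). Composing vmap with grad then gives the slope at every $\sigma$ at once. In the sketch below, the four-atom cluster and its energy function are hypothetical stand-ins so that the block runs on its own.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def cluster_energy(pos, sigma, eps=1.0):
    """LJ energy of a small open cluster (each pair once, no box, no cutoff)."""
    i, j = jnp.triu_indices(pos.shape[0], k=1)
    r2 = jnp.sum((pos[i] - pos[j]) ** 2, axis=-1)
    sr6 = (sigma ** 2 / r2) ** 3
    return jnp.sum(4.0 * eps * (sr6 ** 2 - sr6))


pos = jnp.array([[0.0, 0.0, 0.0],
                 [1.1, 0.0, 0.0],
                 [0.0, 1.2, 0.0],
                 [0.3, 0.4, 1.0]])
sigma_grid = jnp.linspace(0.85, 1.15, 40)

# in_axes=(None, 0): pos is shared by every batch member, sigma is mapped
energy_over_sigma = jax.jit(jax.vmap(cluster_energy, in_axes=(None, 0)))
slope_over_sigma = jax.jit(jax.vmap(jax.grad(cluster_energy, argnums=1),
                                    in_axes=(None, 0)))

E_batched = energy_over_sigma(pos, sigma_grid)       # shape (40,)
dE_batched = slope_over_sigma(pos, sigma_grid)       # shape (40,)

# Same energies from an explicit Python loop, for comparison
E_loop = np.array([float(cluster_energy(pos, s)) for s in sigma_grid])
```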

A toy fitting problem: tuning a parameter by gradient descent

The endgame of differentiable simulation: take the gradient of a loss (against experimental data, against a target ensemble average, …) with respect to potential parameters, and use it to fit. To keep the demo fast we use the simplest target imaginable: tune $\varepsilon$ so that the LJ energy at this configuration equals a chosen value $E_\star$ (E_target in the code). This is a one-parameter quadratic problem with a closed-form answer ($E$ is linear in $\varepsilon$, so $\varepsilon_\star = E_\star / E(\varepsilon = 1)$), but the machinery (value_and_grad plus a gradient-descent loop) is exactly what you’d use for a real fit.

def lj_energy_eps(pos, L, eps, sigma=1.0, rcut=2.5):
    diff = pos[:, None, :] - pos[None, :, :]
    diff = diff - L * jnp.round(diff / L)
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = (r2 < rcut * rcut) & (r2 > 1e-6)
    safe_r2 = jnp.where(mask, r2, 1.0)
    sr2 = sigma * sigma / safe_r2
    sr6 = sr2 ** 3
    pair_u = 4.0 * eps * (sr6 ** 2 - sr6)
    return 0.5 * jnp.sum(jnp.where(mask, pair_u, 0.0))


E_target = -700.0


def loss(eps, pos, L):
    return (lj_energy_eps(pos, L, eps) - E_target) ** 2


loss_and_grad = jit(jax.value_and_grad(loss))

eps = 0.5             # starting guess
lr  = 1e-6            # learning rate
L_val, g = loss_and_grad(eps, pos, L_box)
history = [(float(eps), float(L_val))]
for step in range(40):
    eps = eps - lr * g
    L_val, g = loss_and_grad(eps, pos, L_box)   # loss and gradient at the updated eps
    history.append((float(eps), float(L_val)))  # each entry pairs eps with its own loss

eps_traj, loss_traj = zip(*history)
final_E = float(lj_energy_eps(pos, L_box, eps))

fig, axes = plt.subplots(1, 2, figsize=(11, 3.8))
axes[0].plot(eps_traj, '-o', color='steelblue', ms=4)
axes[0].axhline(eps_traj[-1], color='r', ls='--', lw=0.8,
                label=fr'$\varepsilon^\star = {eps_traj[-1]:.4f}$')
axes[0].set(xlabel='gradient-descent step', ylabel=r'$\varepsilon$',
            title='Parameter trajectory')
axes[0].legend(); axes[0].grid(True, ls=':', alpha=0.5)
axes[1].semilogy(loss_traj, '-o', color='crimson', ms=4)
axes[1].set(xlabel='gradient-descent step', ylabel='loss',
            title='Loss (log scale)')
axes[1].grid(True, ls=':', alpha=0.5, which='both')
fig.tight_layout()
print(f'\nFinal eps = {eps_traj[-1]:.4f}, E = {final_E:.4f}  (target {E_target})')
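Because $E$ is linear in $\varepsilon$, this toy problem can be analyzed exactly. Write $E_1 = E(\varepsilon = 1)$ and let $\eta$ be the learning rate. The loss is $L(\varepsilon) = (\varepsilon E_1 - E_\star)^2$ and its gradient is $2E_1(\varepsilon E_1 - E_\star)$. One gradient-descent step then gives $\varepsilon_{k+1} - \varepsilon_\star = (1 - 2\eta E_1^2)(\varepsilon_k - \varepsilon_\star)$ with $\varepsilon_\star = E_\star/E_1$. The iteration converges only for $0 < \eta < 1/E_1^2$, which is why a small learning rate is needed when $|E_1|$ is large. The self-contained sketch below uses a hypothetical four-atom cluster, with the target chosen as $1.5\,E_1$ so that the exact answer is $\varepsilon_\star = 1.5$. It picks $\eta = 1/(4E_1^2)$, which halves the error each step.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def cluster_energy(pos, eps, sigma=1.0):
    """LJ energy of a small open cluster (each pair once, no box, no cutoff)."""
    i, j = jnp.triu_indices(pos.shape[0], k=1)
    r2 = jnp.sum((pos[i] - pos[j]) ** 2, axis=-1)
    sr6 = (sigma ** 2 / r2) ** 3
    return jnp.sum(4.0 * eps * (sr6 ** 2 - sr6))


pos = jnp.array([[0.0, 0.0, 0.0],
                 [1.1, 0.0, 0.0],
                 [0.0, 1.2, 0.0],
                 [0.3, 0.4, 1.0]])

E1 = float(cluster_energy(pos, 1.0))    # energy at eps = 1
E_target = 1.5 * E1                     # hypothetical target, so eps_star = 1.5
eps_star = E_target / E1                # closed form: E(eps) = eps * E1


def loss(eps):
    return (cluster_energy(pos, eps) - E_target) ** 2


loss_and_grad = jax.jit(jax.value_and_grad(loss))
lr = 1.0 / (4.0 * E1 ** 2)              # contraction factor 1 - 2*lr*E1**2 = 1/2

eps = 0.5
errors = []
for _ in range(40):
    _, g = loss_and_grad(eps)
    eps = eps - lr * g
    errors.append(abs(float(eps) - eps_star))   # distance to the exact answer

print(f"eps = {float(eps):.10f}, exact = {eps_star:.10f}")
```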

Takeaways

  • Stop writing forces, write energies. The negative of grad(energy) gives forces to machine precision. The same one-line mechanism gives pressure (gradient w.r.t. box scaling) and parameter sensitivities (gradient w.r.t. $\varepsilon, \sigma, \dots$).

  • vmap parallelizes for free. Wrap vv_step with vmap and integrate many trajectories simultaneously without writing a batch loop.

  • Differentiable simulation enables fitting. With value_and_grad, a gradient-descent loop over potential parameters becomes a few lines. This is the foundation of modern force-field optimization, including neural-network potentials trained against ab-initio reference data.

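  • jit makes it fast. The first call traces and compiles; subsequent calls run as optimized XLA. The same code runs unchanged on a GPU. However, lj_energy builds the full $N \times N$ pair matrix, so cost and memory grow as $O(N^2)$, and much larger systems need neighbor lists, as in JAX-MD.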

Problems

  1. Replace LJ with WCA. WCA is the purely repulsive part of LJ: $u_\text{WCA}(r) = u_\text{LJ}(r) + \varepsilon$ for $r < r_{\min} = 2^{1/6}\sigma$, and $0$ otherwise. Modify lj_energy accordingly. Verify that forces from grad are continuous at the cutoff (unlike LJ truncated at $r_c = 2.5$, which has a force jump there).

  2. Where finite differences fail. Compute $\partial E/\partial \sigma$ at $\sigma = 1$ via both jax.grad and central finite differences. Make a log-log plot of $|\,\text{autodiff} - \text{FD}\,|$ vs. step size $h$. You should see the classic V-shape: truncation error at large $h$, round-off error at small $h$. Autodiff has neither.

  3. Caloric curve from a batched temperature scan. Build a batch where each trajectory is initialized at a different temperature $T \in \{0.5, 0.7, 1.0, 1.5, 2.0\}$ (rescale velocities accordingly inside init_batch_member). Run them in parallel with vmap, time-average the potential energy after equilibration, and plot $\langle U\rangle / N$ vs. $T$ (the LJ caloric curve).

References

  • S. S. Schoenholz and E. D. Cubuk, JAX, M.D.: A framework for differentiable physics, NeurIPS 2020. github.com/jax-md/jax-md

  • A. G. Baydin, B. A. Pearlmutter, A. A. Radul, J. M. Siskind, Automatic differentiation in machine learning: a survey, J. Mach. Learn. Res. 18, 1 (2018). arXiv:1502.05767