# 2D Ising model simulations using ML libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm
# PyTorch implementation
def ising2d_torch(N=20,
                  T=1.0,
                  J=1.0,
                  B=0.0,
                  n_steps=20000,
                  out_freq=10,
                  seed=None):
    '''
    Metropolis Monte Carlo simulator for the 2D Ising model using PyTorch.

    Performs single-spin-flip Metropolis updates on an N x N lattice with
    periodic boundary conditions, for the energy
    E = -J * sum_<ij> s_i s_j - B * sum_i s_i.

    Parameters:
    N (int): dimension of spin lattice, returning N*N spins
    T (float): Temperature.
    J (float): Interaction strength between spins.
    B (float): External magnetic field.
    n_steps (int): Number of Monte Carlo steps (one single-spin proposal each).
    out_freq (int): Output frequency for saving spin configurations, energy, and magnetization.
    seed (int or None): If given, seeds a private torch.Generator so the run is
        reproducible; None uses torch's global RNG.

    Returns:
    tuple: Numpy arrays (S, E, M). S stacks the saved spin configurations
    (shape (n_saved, N, N)); E and M are the energy and magnetization per
    spin at the same steps. A snapshot is taken after the update of every
    step with step % out_freq == 0.
    '''
    device = torch.device("cpu")  # For small lattices that we simulate CPU is more efficient
    gen = None
    if seed is not None:
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)

    # Initialize spins uniformly at random in {-1, +1}.
    spins = torch.randint(0, 2, (N, N), device=device, generator=gen) * 2 - 1
    M_t = spins.sum()
    # Pairing each site with its "previous" neighbour along both axes counts
    # every bond of the periodic lattice exactly once.
    neighbors = spins.roll(1, dims=0) + spins.roll(1, dims=1)
    E_t = -J * (spins * neighbors).sum() - M_t * B

    S, E, M = [], [], []
    for step in range(n_steps):
        i, j = torch.randint(0, N, (2,), device=device, generator=gen)
        # Sum of the four nearest neighbours with periodic wrap-around.
        z = (spins[(i + 1) % N, j] + spins[(i - 1) % N, j]
             + spins[i, (j + 1) % N] + spins[i, (j - 1) % N])
        s = spins[i, j].item()  # plain int: cannot alias the lattice after the flip
        dE = 2 * s * (J * z + B)
        # Flipping s -> -s changes the total magnetization by -2*s.
        dM = -2 * s
        if torch.rand(1, device=device, generator=gen) < torch.exp(-dE / T):
            spins[i, j] *= -1
            E_t += dE
            M_t += dM
        if step % out_freq == 0:
            # .numpy() shares memory with `spins`, so take an explicit copy.
            S.append(spins.numpy().copy())
            E.append(float(E_t) / N**2)
            M.append(float(M_t) / N**2)
    return np.array(S), np.array(E), np.array(M)
# JAX implementation
import jax
import jax.numpy as jnp
from jax import random, jit
def metropolis_update(spins, i, j, J, B, T, key):
    '''
    Evaluate a single Metropolis spin-flip proposal at lattice site (i, j).

    Parameters:
    spins: Square lattice of +/-1 spins (indexed as spins[row, col]).
    i, j: Row and column of the proposed flip.
    J (float): Interaction strength between spins.
    B (float): External magnetic field.
    T (float): Temperature.
    key: JAX PRNG key used for the acceptance draw.

    Returns:
    tuple: (accept, delta_E). accept is True when a uniform draw falls below
    exp(-delta_E / T); delta_E is the energy change the flip would cause.
    '''
    size = spins.shape[0]
    up, down = (i + 1) % size, (i - 1) % size
    right, left = (j + 1) % size, (j - 1) % size
    # Nearest-neighbour sum with periodic (wrap-around) boundaries.
    local_sum = spins[up, j] + spins[down, j] + spins[i, right] + spins[i, left]
    centre = spins[i, j]
    delta_E = 2 * centre * (J * local_sum + B)
    boltzmann = jnp.exp(-delta_E / T)
    return random.uniform(key) < boltzmann, delta_E
@jit
def metropolis_step(state, key):
    '''
    Perform one single-spin Metropolis update; written as a jax.lax.scan body.

    Parameters:
    state (tuple): (spins, J, B, T) carried between steps.
    key: JAX PRNG key for this step.

    Returns:
    tuple: (new_state, key_out). new_state has the same structure as `state`,
    with at most one spin flipped. key_out is a fresh PRNG key derived from
    `key`; under jax.lax.scan it is stacked as the per-step output.
    '''
    spins, J, B, T = state
    N = spins.shape[0]
    # Use independent subkeys for site selection and for the acceptance draw;
    # feeding the same key to both would correlate them.
    key, site_key, accept_key = random.split(key, 3)
    # Pick a random spin
    i, j = random.randint(site_key, (2,), 0, N)
    accept, _ = metropolis_update(spins, i, j, J, B, T, accept_key)
    # Multiply the chosen spin by -1 if accepted (+1 otherwise). This touches
    # one element instead of building a flipped copy of the whole lattice.
    flip = jnp.where(accept, -1, 1).astype(spins.dtype)
    spins = spins.at[i, j].multiply(flip)
    return (spins, J, B, T), key
def simulate_ising(N, T, J=1.0, B=0.0, n_steps=10000, seed=0):
    '''
    Run a single-spin-flip Metropolis simulation of the 2D Ising model with JAX.

    Parameters:
    N (int): Linear lattice size; the lattice holds N*N spins.
    T (float): Temperature.
    J (float): Interaction strength between spins.
    B (float): External magnetic field.
    n_steps (int): Number of single-spin Metropolis proposals.
    seed (int): Seed for jax.random.PRNGKey.

    Returns:
    jax.Array: Final (N, N) spin configuration with entries -1/+1.
    '''
    key = random.PRNGKey(seed)
    # Derive independent keys for the initial lattice and the update steps.
    # Using `key` for both would make the two random streams dependent.
    init_key, steps_key = random.split(key)
    spins = random.choice(init_key, jnp.array([-1, 1]), shape=(N, N))
    # Pack state
    state = (spins, J, B, T)
    keys = random.split(steps_key, n_steps)  # one key per step, consumed by scan
    # Use scan to apply metropolis_step across all steps
    final_state, _ = jax.lax.scan(metropolis_step, state, keys)
    return final_state[0]  # Return final spins only
# Speed test
# Simulation parameters: lattice size N, coupling J, field B, temperature T,
# number of Monte Carlo steps, and snapshot interval. T=4 lies in the
# disordered phase (T_c ~= 2.269 J for the square-lattice Ising model).
params = dict(
    N=20,
    J=1,
    B=0,
    T=4,
    n_steps=10000,
    out_freq=10,
)