# Ising-2D simulations using ML libraries

import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm
# PyTorch implementation

def ising2d_torch(N=20,
                  T=1.0,
                  J=1.0,
                  B=0.0,
                  n_steps=10000,
                  out_freq=10):
    """Metropolis Monte Carlo simulator for the 2D Ising model using PyTorch.

    Parameters
    ----------
    N : int
        Linear dimension of the spin lattice; the lattice holds N*N spins
        with periodic boundary conditions.
    T : float
        Temperature; must be positive.
    J : float
        Interaction strength between nearest-neighbour spins.
    B : float
        External magnetic field.
    n_steps : int
        Number of single-spin-flip Monte Carlo steps.
    out_freq : int
        Record spins, energy and magnetization every ``out_freq`` steps.

    Returns
    -------
    tuple of numpy.ndarray
        ``(S, E, M)``: spin snapshots of shape ``(n_saved, N, N)``, and energy
        per spin and magnetization per spin, each of shape ``(n_saved,)``.
    """
    device = torch.device("cpu")  # For the small lattices simulated here, CPU is more efficient

    # Random initial configuration of +/-1 spins.
    spins = torch.randint(0, 2, (N, N), device=device) * 2 - 1

    # Keep running totals as Python numbers: scalar-tensor arithmetic inside
    # the loop is slow, and float32 accumulation drifts over many steps.
    M_t = int(spins.sum())
    # Rolling once along each axis counts every nearest-neighbour bond exactly once.
    neighbors = spins.roll(1, dims=0) + spins.roll(1, dims=1)
    E_t = -J * int((spins * neighbors).sum()) - B * M_t
    S, E, M = [], [], []

    for step in range(n_steps):
        i, j = torch.randint(0, N, (2,), device=device).tolist()
        s = int(spins[i, j])
        z = int(spins[(i + 1) % N, j] + spins[(i - 1) % N, j]
                + spins[i, (j + 1) % N] + spins[i, (j - 1) % N])

        # Flipping s -> -s changes E by 2 s (J z + B) and M by -2 s.
        dE = 2 * s * (J * z + B)
        dM = -2 * s

        u = torch.rand(1, device=device).item()
        # Downhill moves are always accepted; testing dE <= 0 first also keeps
        # exp() from overflowing for large negative dE.
        if dE <= 0 or u < np.exp(-dE / T):
            spins[i, j] = -s
            E_t += dE
            M_t += dM

        if step % out_freq == 0:
            # .copy(): spins.numpy() shares memory with the tensor mutated above.
            S.append(spins.numpy().copy())
            E.append(E_t / N**2)
            M.append(M_t / N**2)

    return np.array(S), np.array(E), np.array(M)

# JAX implementation

import jax
import jax.numpy as jnp
from jax import random, jit

def metropolis_update(spins, i, j, J, B, T, key):
    '''Calculate the energy change for a proposed spin flip at position (i, j) and decide whether to accept it.

    Parameters
    ----------
    spins : jax.Array
        Square lattice of +/-1 spins; its first dimension is used as N for
        periodic wrap-around.
    i, j : int or traced int
        Row and column of the spin proposed for flipping.
    J, B, T : float
        Coupling constant, external field and temperature.
    key : jax.random.PRNGKey
        Key consumed by the uniform draw of the Metropolis test.

    Returns
    -------
    accept : bool array
        True when the Metropolis criterion accepts the flip.
    delta_E : array
        Energy change the flip would cause: 2 * s * (J * sum_of_neighbours + B).
    '''
    N = spins.shape[0]
    spin = spins[i, j]
    # Manual periodic boundary conditions via modular indexing.
    neighbors = (
        spins[(i + 1) % N, j] + spins[(i - 1) % N, j]
        + spins[i, (j + 1) % N] + spins[i, (j - 1) % N]
    )
    delta_E = 2 * spin * (J * neighbors + B)
    # For downhill moves exp(-dE/T) >= 1, so they are always accepted.
    accept = random.uniform(key) < jnp.exp(-delta_E / T)
    return accept, delta_E

def metropolis_step(state, key):
    '''Perform one Metropolis single-spin-flip attempt.

    Written as the body function for ``jax.lax.scan``: ``state`` is the carry
    and ``key`` is this step's slice of the pre-split key array.

    Parameters
    ----------
    state : tuple
        ``(spins, J, B, T)`` — lattice, coupling, field and temperature.
    key : jax.random.PRNGKey
        Fresh key for this step.

    Returns
    -------
    tuple
        ``((spins, J, B, T), None)``: the updated carry, plus an empty
        per-step scan output (the caller only uses the final state).
    '''
    spins, J, B, T = state
    N = spins.shape[0]

    # Independent subkeys for the site choice and for the acceptance draw;
    # reusing one key for both would correlate the two random numbers.
    site_key, accept_key = random.split(key)

    # Pick a random spin
    i, j = random.randint(site_key, (2,), 0, N)
    accept, _ = metropolis_update(spins, i, j, J, B, T, accept_key)
    # Apply the Metropolis condition with an O(1) functional update of one site.
    spins = spins.at[i, j].set(jnp.where(accept, -spins[i, j], spins[i, j]))

    return (spins, J, B, T), None

def simulate_ising(N, T, J=1.0, B=0.0, n_steps=10000, seed=0):
    '''Run ``n_steps`` Metropolis single-spin-flip steps on an N x N Ising lattice.

    Parameters
    ----------
    N : int
        Linear lattice size.
    T : float
        Temperature.
    J : float
        Coupling constant.
    B : float
        External magnetic field.
    n_steps : int
        Number of Monte Carlo steps (one flip attempt each).
    seed : int
        Seed for the JAX PRNG.

    Returns
    -------
    jax.Array
        Final N x N spin configuration.
    '''
    key = random.PRNGKey(seed)
    # Derive independent keys for initialization and for the Monte Carlo
    # steps instead of reusing the root key for both.
    init_key, steps_key = random.split(key)
    spins = random.choice(init_key, jnp.array([-1, 1]), shape=(N, N))

    # Pack state
    state = (spins, J, B, T)
    keys = random.split(steps_key, n_steps)  # one key per Monte Carlo step

    # Use scan to apply metropolis_step across all steps
    final_state, _ = jax.lax.scan(metropolis_step, state, keys)

    return final_state[0]  # Return final spins only

# Speed test

# Shared keyword arguments for the speed comparison runs.
params = dict(
    N=20,
    T=4,
    n_steps=10000,
    out_freq=10,
)