In this tutorial, we explore Equinox, a lightweight and elegant neural network library built on JAX, and show how to use it. We begin by understanding how eqx.Module treats models as PyTrees, which makes parameter handling, transformation, and serialization feel simple and explicit. As we move forward, we work through static fields, filtered transformations such as filter_jit and filter_grad, PyTree manipulation utilities, stateful layers such as BatchNorm, and a complete end-to-end training workflow for a toy regression problem.
Throughout the tutorial, we focus on writing clear, executable code that demonstrates not only how Equinox works but also why it fits so well into the JAX ecosystem for research and practical experimentation. Copy CodeCopiedUse a different Browser!pip install equinox optax jaxtyping matplotlib -q import jax import jax.numpy as jnp import equinox as eqx import optax from jaxtyping import Array, Float, Int, PRNGKeyArray from typing import Optional import matplotlib.pyplot as plt import time print(f"JAX version : {jax.__version__}") print(f"Equinox version: {eqx.__version__}") print(f"Devices : {jax.devices()}") print("\n" + "="*60) print("SECTION 1: eqx.Module basics") print("="*60) class Linear(eqx.Module): weight: Float[Array, "out in"] bias: Float[Array, "out"] def __init__(self, in_size: int, out_size: int, *, key: PRNGKeyArray): wkey, bkey = jax.random.split(key) self.weight = jax.random.normal(wkey, (out_size, in_size)) * 0.1 self.bias = jax.random.normal(bkey, (out_size,)) * 0.01 def __call__(self, x: Float[Array, "in"]) -> Float[Array, "out"]: return self.weight @ x + self.bias key = jax.random.PRNGKey(0) lin = Linear(4, 2, key=key) leaves, treedef = jax.tree_util.tree_flatten(lin) print("Leaves shapes:", [l.shape for l in leaves]) print("Treedef:", treedef) print("\n" + "="*60) print("SECTION 2: Static fields") print("="*60) class Conv1dBlock(eqx.Module): conv: eqx.nn.Conv1d norm: eqx.nn.LayerNorm activation: str = eqx.field(static=True) def __init__(self, channels: int, kernel: int, activation: str, *, key: PRNGKeyArray): self.conv = eqx.nn.Conv1d(channels, channels, kernel, padding="same", key=key) self.norm = eqx.nn.LayerNorm((channels,)) self.activation = activation def __call__(self, x: Float[Array, "C L"]) -> Float[Array, "C L"]: x = self.conv(x) x = jax.vmap(self.norm)(x.T).T if self.activation == "relu": return jax.nn.relu(x) elif self.activation == "gelu": return jax.nn.gelu(x) return x key, subkey = jax.random.split(key) block = Conv1dBlock(8, 3, "gelu", key=subkey) x_seq = jnp.ones((8, 16)) out = block(x_seq) print(f"Conv1dBlock output shape: {out.shape}") We set up the full Equinox environment by installing the required libraries and importing JAX, Equinox, Optax, Jaxtyping, Matplotlib, and other essentials.
We immediately verify the runtime by printing the JAX and Equinox versions and the available devices, which helps us confirm that our Colab environment is ready for execution. We then begin with the foundations of Equinox by defining a simple Linear module, creating an instance of it, and inspecting its PyTree leaves and structure before introducing a Conv1dBlock that demonstrates how static fields and learnable layers work together in practice. Copy CodeCopiedUse a different Browserprint("\n" + "="*60) print("SECTION 3: Filtered transforms") print("="*60) class MLP(eqx.Module): layers: list dropout: eqx.nn.Dropout def __init__(self, in_size, hidden, out_size, *, key: PRNGKeyArray): k1, k2, k3 = jax.random.split(key, 3) self.layers = [ eqx.nn.Linear(in_size, hidden, key=k1), eqx.nn.Linear(hidden, hidden, key=k2), eqx.nn.Linear(hidden, out_size, key=k3), ] self.dropout = eqx.nn.Dropout(p=0.1) def __call__(self, x: Float[Array, "in"], *, key: Optional[PRNGKeyArray] = None) -> Float[Array, "out"]: for layer in self.layers[:-1]: x = jax.nn.relu(layer(x)) if key is not None: key, subkey = jax.random.split(key) x = self.dropout(x, key=subkey) return self.layers[-1](x) key, mk = jax.random.split(key) mlp = MLP(8, 32, 4, key=mk) @eqx.filter_jit def forward(model, x, *, key): return model(x, key=key) x_in = jnp.ones((8,)) key, fk = jax.random.split(key) y_out = forward(mlp, x_in, key=fk) print(f"MLP output: {y_out}") @eqx.filter_jit def loss_fn(model: MLP, x: Float[Array, "B in"], y: Float[Array, "B out"], key: PRNGKeyArray) -> Float[Array, ""]: keys = jax.random.split(key, x.shape[0]) preds = jax.vmap(model)(x) return jnp.mean((preds - y) ** 2) grad_fn = eqx.filter_grad(loss_fn) key, dk = jax.random.split(key) X = jax.random.normal(dk, (16, 8)) Y = jax.random.normal(dk, (16, 4)) grads = grad_fn(mlp, X, Y, dk) print(f"Grad of first layer weight: shape={grads.layers[0].weight.shape}, norm={jnp.linalg.norm(grads.layers[0].weight):.4f}") We focus on Equinox’s filtered transformations by building an MLP that includes both linear layers and dropout. We use filter_jit to compile the forward pass while allowing the model to contain non-array fields, and we use filter_grad to