In this example, we’ll learn to use mlx.nn by implementing a simple multi-layer perceptron (MLP) to classify MNIST digits. This demonstrates how to create custom modules, define training loops, and use MLX’s built-in optimizers.

Setup

First, import the MLX packages we need:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

import numpy as np

Define the Model

The model is defined as the MLP class which inherits from mlx.nn.Module. We follow the standard idiom for creating a new module:
1. Define __init__: set up parameters and submodules. The mlx.nn.Module base class automatically registers parameters.
2. Define __call__: implement the forward computation.
class MLP(nn.Module):
    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = mx.maximum(l(x), 0.0)
        return self.layers[-1](x)
This creates a network with:
  • Multiple hidden layers with ReLU activation (using mx.maximum)
  • A final output layer without activation
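To see how the constructor pairs up layer dimensions, here is the same list-building logic in plain Python, using the dimensions from the setup below (num_layers=2, hidden_dim=32, and MNIST's 784 input features and 10 classes):

```python
num_layers, input_dim, hidden_dim, output_dim = 2, 784, 32, 10

# Same construction as in MLP.__init__
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
pairs = list(zip(layer_sizes[:-1], layer_sizes[1:]))
print(pairs)  # [(784, 32), (32, 32), (32, 10)]
```

Each (idim, odim) pair becomes one nn.Linear layer, so num_layers controls the number of hidden layers, and the input and output dimensions are fixed by the data.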

Loss and Evaluation Functions

Define the loss function which takes the mean of the per-example cross-entropy loss:
def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))
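As a sanity check on what cross-entropy computes, here is the same quantity worked out by hand in NumPy on hypothetical logits (log-softmax of the logits, picked at the true class, negated, then averaged):

```python
import numpy as np

# Hypothetical logits for 2 examples over 3 classes
logits = np.array([[2.0, 0.5, 0.1],
                   [0.3, 1.2, 0.2]])
labels = np.array([0, 1])

# Row-wise log-softmax: logits minus log-sum-exp
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

# Mean negative log-probability of the true class
loss = -log_probs[np.arange(len(labels)), labels].mean()
print(loss)  # ~0.4451
```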
The mlx.nn.losses subpackage provides implementations of commonly used loss functions. We also need a function to compute accuracy on the validation set:
def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)
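The accuracy computation is just "fraction of rows whose largest logit is at the true label". The same logic in NumPy, on hypothetical logits:

```python
import numpy as np

# Hypothetical logits for 4 examples over 3 classes
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.3, 0.2, 0.9],
                   [1.2, 0.4, 0.1]])
labels = np.array([0, 1, 2, 1])  # last prediction is wrong

accuracy = np.mean(np.argmax(logits, axis=1) == labels)
print(accuracy)  # 0.75
```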

Setup and Data Loading

Configure the problem parameters and load the MNIST data:
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-1

# Load the data
import mnist
train_images, train_labels, test_images, test_labels = map(
    mx.array, mnist.mnist()
)
You’ll need the MNIST data loader from the mlx-examples repository.

Batch Iterator

Since we’re using SGD, we need an iterator that shuffles and constructs minibatches:
def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
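To check the batching behavior, here is an equivalent pure-NumPy version run on a toy dataset (10 examples, batch size 4); note the last batch is smaller when the dataset size is not a multiple of the batch size:

```python
import numpy as np

def batch_iterate_np(batch_size, X, y):
    # Shuffle once per epoch, then slice the permutation into batches
    perm = np.random.permutation(y.size)
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]

X = np.zeros((10, 784))
y = np.zeros(10, dtype=int)
shapes = [xb.shape for xb, yb in batch_iterate_np(4, X, y)]
print(shapes)  # [(4, 784), (4, 784), (2, 784)]
```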

Training Loop

Put it all together by instantiating the model, optimizer, and running the training loop:
# Instantiate the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())

# Get a function which gives the loss and gradient of the
# loss with respect to the model's trainable parameters
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Instantiate the optimizer
optimizer = optim.SGD(learning_rate=learning_rate)

for e in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss, grads = loss_and_grad_fn(model, X, y)

        # Update the optimizer state and model parameters
        # in a single call
        optimizer.update(model, grads)

        # Force a graph evaluation
        mx.eval(model.parameters(), optimizer.state)

    accuracy = eval_fn(model, test_images, test_labels)
    print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")
The mlx.nn.value_and_grad() function is a convenience function to get the gradient of a loss with respect to the trainable parameters of a model. This should not be confused with mlx.core.value_and_grad().
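Under the hood, plain SGD (without momentum) applies the rule "parameter minus learning rate times gradient" to every trainable parameter. A one-array NumPy sketch of the update that optimizer.update performs:

```python
import numpy as np

learning_rate = 0.1
w = np.array([1.0, -2.0])      # a parameter
grad = np.array([0.5, -0.5])   # a hypothetical gradient for it

# Plain SGD step (no momentum): w <- w - lr * grad
w = w - learning_rate * grad
print(w)  # [ 0.95 -1.95]
```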

Expected Results

The model should train to about 95% accuracy after just a few passes over the training set:
Epoch 0: Test accuracy 0.891
Epoch 1: Test accuracy 0.921
Epoch 2: Test accuracy 0.934
Epoch 3: Test accuracy 0.942
Epoch 4: Test accuracy 0.947
...

Complete Example

The full MNIST example is available in the MLX examples repository.