The mlx.nn module provides an intuitive way of composing neural network layers, initializing their parameters, and managing trainable weights for training and fine-tuning.

Quick Start

Here’s a simple example of building a multi-layer perceptron:
import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()
        
        self.layers = [
            nn.Linear(in_dims, 128),
            nn.Linear(128, 128),
            nn.Linear(128, out_dims),
        ]
    
    def __call__(self, x):
        for i, l in enumerate(self.layers):
            x = mx.maximum(x, 0) if i > 0 else x
            x = l(x)
        return x

# Create model (parameters are not initialized yet due to lazy evaluation)
mlp = MLP(2, 10)

# Access parameters
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)

# Force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())

Module

nn.Module
class
Base class for building neural networks with MLX. All layers in mlx.nn.layers subclass this class. Your custom models should also inherit from Module. A Module can contain other Module instances or mlx.core.array instances in arbitrary nesting of Python lists or dicts.
Key Features:
  • Recursively extracts all parameters using parameters()
  • Supports trainable and frozen parameters
  • Provides parameter state management

Module Methods

parameters()
method
Recursively return all the mlx.core.array members of this Module as a dict of dicts and lists.
model = nn.Linear(10, 10)
params = model.parameters()
trainable_parameters()
method
Recursively return all the non-frozen mlx.core.array members of this Module.
model = nn.Linear(10, 10)
model.freeze(keys="bias")  # Freeze bias
trainable = model.trainable_parameters()  # Only contains weights
update(parameters)
method
Replace the parameters of this Module with the provided ones. Commonly used by optimizers to update model parameters.
model.update(new_parameters)
freeze(recurse=True, keys=None, strict=False)
method
Freeze the Module’s parameters or some of them. Freezing means not computing gradients.
Parameters:
  • recurse (bool): If True, freeze parameters of submodules as well. Default: True
  • keys (str or list[str]): If provided, only freeze these parameters. Default: None
  • strict (bool): If True, validate that the passed keys exist. Default: False
# Freeze all parameters
model.freeze()

# Only train attention parameters in Transformer
model = nn.Transformer()
model.freeze()
model.apply_to_modules(
    lambda k, v: v.unfreeze() if k.endswith("attention") else None
)
unfreeze(recurse=True, keys=None, strict=False)
method
Unfreeze the Module’s parameters or some of them.
# Only train biases
model = nn.Transformer()
model.freeze()
model.unfreeze(keys="bias")
train(mode=True)
method
Set the model in or out of training mode. Training mode applies to layers like Dropout which behave differently during training vs evaluation.
model.train()  # Training mode
model.eval()   # Evaluation mode
load_weights(file_or_weights, strict=True)
method
Update the model’s weights from a .npz file, a .safetensors file, or a list of (name, array) pairs.
Parameters:
  • file_or_weights: Path to weights file or list of (name, array) pairs
  • strict (bool): If True, checks that weights exactly match parameters. Default: True
# Load from file
model.load_weights("weights.npz")

# Load from .safetensors
model.load_weights("weights.safetensors")

# Load from list
weights = [("weight", mx.random.uniform(shape=(10, 10)))]
model.load_weights(weights, strict=False)
save_weights(file)
method
Save the model’s weights to a file. Supports .npz and .safetensors formats.
model.save_weights("model.safetensors")

Linear Layers

nn.Linear
class
Applies an affine transformation to the input: y = xW^T + b
Parameters:
  • input_dims (int): The dimensionality of the input features
  • output_dims (int): The dimensionality of the output features
  • bias (bool): If False, the layer will not use a bias. Default: True
The values are initialized from the uniform distribution U(-k, k) where k = 1/√(input_dims).
layer = nn.Linear(128, 64)
x = mx.random.normal((10, 128))
y = layer(x)  # Shape: (10, 64)
nn.Bilinear
class
Applies a bilinear transformation to the inputs: y_i = x1^T W_i x2 + b_i
Parameters:
  • input1_dims (int): The dimensionality of the first input
  • input2_dims (int): The dimensionality of the second input
  • output_dims (int): The dimensionality of the output
  • bias (bool): If False, no bias is used. Default: True
layer = nn.Bilinear(20, 30, 40)
x1 = mx.random.normal((128, 20))
x2 = mx.random.normal((128, 30))
y = layer(x1, x2)  # Shape: (128, 40)
nn.Identity
class
A placeholder identity operator that returns the input unchanged.
layer = nn.Identity()
x = mx.array([1, 2, 3])
y = layer(x)  # Returns x unchanged

Convolutional Layers

nn.Conv1d
class
Applies a 1-dimensional convolution over the multi-channel input sequence.
Input shape: NLC where N=batch, L=sequence length, C=channels
Parameters:
  • in_channels (int): Number of input channels
  • out_channels (int): Number of output channels
  • kernel_size (int): Size of the convolution filters
  • stride (int): Stride when applying the filter. Default: 1
  • padding (int): Positions to 0-pad the input. Default: 0
  • dilation (int): Dilation of the convolution. Default: 1
  • groups (int): Number of groups for the convolution. Default: 1
  • bias (bool): If True, add a learnable bias. Default: True
conv = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3)
x = mx.random.normal((8, 100, 16))  # (batch, length, channels)
y = conv(x)
nn.Conv2d
class
Applies a 2-dimensional convolution over the multi-channel input image.
Input shape: NHWC where N=batch, H=height, W=width, C=channels
Parameters:
  • in_channels (int): Number of input channels
  • out_channels (int): Number of output channels
  • kernel_size (int or tuple): Size of the convolution filters
  • stride (int or tuple): Stride when applying the filter. Default: 1
  • padding (int or tuple): Positions to 0-pad the input. Default: 0
  • dilation (int or tuple): Dilation of the convolution. Default: 1
  • groups (int): Number of groups for the convolution. Default: 1
  • bias (bool): If True, add a learnable bias. Default: True
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
x = mx.random.normal((8, 224, 224, 3))  # (batch, height, width, channels)
y = conv(x)
nn.Conv3d
class
Applies a 3-dimensional convolution over the multi-channel input volume.
Input shape: NDHWC where N=batch, D=depth, H=height, W=width, C=channels
Similar parameters to Conv2d but for 3D inputs.

Normalization Layers

nn.LayerNorm
class
Applies layer normalization on the inputs.
Computes: y = (x - E[x]) / sqrt(Var[x] + eps) * γ + β
Parameters:
  • dims (int): The feature dimension to normalize over
  • eps (float): Small constant for numerical stability. Default: 1e-5
  • affine (bool): If True, learn an affine transform. Default: True
  • bias (bool): If True, include bias in affine transform. Default: True
layer = nn.LayerNorm(128)
x = mx.random.normal((10, 128))
y = layer(x)
nn.RMSNorm
class
Applies Root Mean Square normalization.
Computes: y = x / sqrt(mean(x^2) + eps) * γ
Parameters:
  • dims (int): The feature dimension to normalize over
  • eps (float): Small constant for numerical stability. Default: 1e-5
layer = nn.RMSNorm(128)
x = mx.random.normal((10, 128))
y = layer(x)
nn.BatchNorm
class
Applies batch normalization over the inputs.
Parameters:
  • num_features (int): Number of features in the input
  • eps (float): Small constant for numerical stability. Default: 1e-5
  • momentum (float): Momentum for running statistics. Default: 0.1
  • affine (bool): If True, learn scale and shift. Default: True
  • track_running_stats (bool): Track running mean and variance. Default: True
bn = nn.BatchNorm(64)
x = mx.random.normal((32, 64))
y = bn(x)
nn.InstanceNorm
class
Applies instance normalization on the inputs.
Parameters:
  • dims (int): Number of features in the input
  • eps (float): Small constant for numerical stability. Default: 1e-5
  • affine (bool): If True, learn scale and shift. Default: False
inorm = nn.InstanceNorm(dims=16)
x = mx.random.normal((8, 4, 4, 16))
y = inorm(x)
nn.GroupNorm
class
Applies group normalization on the inputs.
Parameters:
  • num_groups (int): Number of groups to separate the channels into
  • dims (int): Number of features in the input
  • eps (float): Small constant for numerical stability. Default: 1e-5
  • affine (bool): If True, learn scale and shift. Default: True
gn = nn.GroupNorm(num_groups=8, dims=64)
x = mx.random.normal((4, 32, 32, 64))
y = gn(x)

Activation Functions

All activation functions are available as both functional APIs and Module classes.
nn.ReLU
class
Applies the Rectified Linear Unit: max(0, x)
relu = nn.ReLU()
x = mx.array([-1, 0, 1, 2])
y = relu(x)  # [0, 0, 1, 2]
nn.LeakyReLU
class
Applies Leaky ReLU: max(negative_slope * x, x)
Parameters:
  • negative_slope (float): Controls the angle of the negative slope. Default: 0.01
leaky_relu = nn.LeakyReLU(negative_slope=0.1)
y = leaky_relu(x)
nn.ReLU6
class
Applies ReLU6: min(max(0, x), 6)
relu6 = nn.ReLU6()
y = relu6(x)
nn.GELU
class
Applies the Gaussian Error Linear Unit.
Parameters:
  • approx (str): Approximation to use: ‘none’, ‘precise’, ‘tanh’, or ‘fast’. Default: ‘none’
gelu = nn.GELU(approx='none')
y = gelu(x)

# Fast approximation
gelu_fast = nn.GELU(approx='fast')
y = gelu_fast(x)
nn.SiLU
class
Applies the Sigmoid Linear Unit (Swish): x * σ(x)
silu = nn.SiLU()
y = silu(x)
nn.Sigmoid
class
Applies the Sigmoid function: 1 / (1 + exp(-x))
sigmoid = nn.Sigmoid()
y = sigmoid(x)
nn.Tanh
class
Applies the hyperbolic tangent function.
tanh = nn.Tanh()
y = tanh(x)
nn.Softmax
class
Applies the Softmax function: exp(x_i) / sum(exp(x_j))
softmax = nn.Softmax()
y = softmax(x)
nn.LogSoftmax
class
Applies the Log Softmax function.
log_softmax = nn.LogSoftmax()
y = log_softmax(x)
nn.ELU
class
Applies the Exponential Linear Unit.
Parameters:
  • alpha (float): The α value for ELU. Default: 1.0
elu = nn.ELU(alpha=1.0)
y = elu(x)
nn.SELU
class
Applies the Scaled Exponential Linear Unit.
selu = nn.SELU()
y = selu(x)
nn.PReLU
class
Applies the parametric ReLU: max(0, x) + a * min(0, x)
Parameters:
  • num_parameters (int): Number of α parameters to learn. Default: 1
  • init (float): Initial value of α. Default: 0.25
prelu = nn.PReLU(num_parameters=1, init=0.25)
y = prelu(x)
nn.Mish
class
Applies the Mish function: x * tanh(softplus(x))
mish = nn.Mish()
y = mish(x)
nn.Hardswish
class
Applies the Hardswish function: x * min(max(x + 3, 0), 6) / 6
hardswish = nn.Hardswish()
y = hardswish(x)
nn.GLU
class
Applies the Gated Linear Unit: a * σ(b), where a and b are the two halves of the input split along axis.
Parameters:
  • axis (int): The dimension to split along. Default: -1
glu = nn.GLU(axis=-1)
x = mx.random.normal((10, 20))
y = glu(x)  # Shape: (10, 10)

Dropout and Regularization

nn.Dropout
class
Randomly zero elements during training with probability p.
Parameters:
  • p (float): Probability of zeroing an element. Default: 0.5
dropout = nn.Dropout(p=0.5)
model.train()  # Dropout is active
y = dropout(x)

model.eval()   # Dropout is inactive
y = dropout(x)  # Returns x unchanged
nn.Dropout2d
class
Randomly zero entire channels during training.
Parameters:
  • p (float): Probability of zeroing a channel. Default: 0.5
dropout2d = nn.Dropout2d(p=0.5)
x = mx.random.normal((8, 224, 224, 64))
y = dropout2d(x)
nn.Dropout3d
class
Randomly zero entire channels during training (3D version).
Parameters:
  • p (float): Probability of zeroing a channel. Default: 0.5

Embedding Layers

nn.Embedding
class
Simple lookup table that maps indices to embeddings.
Parameters:
  • num_embeddings (int): Size of the dictionary of embeddings
  • dims (int): Dimension of each embedding vector
embedding = nn.Embedding(num_embeddings=1000, dims=128)
indices = mx.array([1, 5, 10])
embeddings = embedding(indices)  # Shape: (3, 128)

Recurrent Layers

nn.RNN
class
Applies an Elman recurrent unit with a tanh (or custom) activation.
Input shape: NLD or LD where N=batch, L=sequence length, D=input_size
Parameters:
  • input_size (int): Dimension of input features
  • hidden_size (int): Dimension of hidden state
  • bias (bool): If False, no bias is used. Default: True
  • nonlinearity (callable, optional): Non-linearity to apply. If None, tanh is used. Default: None
rnn = nn.RNN(input_size=128, hidden_size=256)
x = mx.random.normal((32, 10, 128))  # (batch, seq_len, input_size)
hidden = rnn(x)  # Hidden state at every step: (32, 10, 256)
nn.GRU
class
Applies a gated recurrent unit.
Input shape: NLD or LD where N=batch, L=sequence length, D=input_size
Parameters:
  • input_size (int): Dimension of input features
  • hidden_size (int): Dimension of hidden state
  • bias (bool): If False, no bias is used. Default: True
gru = nn.GRU(input_size=128, hidden_size=256)
x = mx.random.normal((32, 10, 128))  # (batch, seq_len, input_size)
hidden = gru(x)  # Hidden state at every step: (32, 10, 256)
nn.LSTM
class
Applies a long short-term memory (LSTM) recurrent unit.
Input shape: NLD or LD where N=batch, L=sequence length, D=input_size
Parameters:
  • input_size (int): Dimension of input features
  • hidden_size (int): Dimension of hidden state
  • bias (bool): If False, no bias is used. Default: True
lstm = nn.LSTM(input_size=128, hidden_size=256)
x = mx.random.normal((32, 10, 128))  # (batch, seq_len, input_size)
hidden, cell = lstm(x)  # Each has shape (32, 10, 256)

Container Layers

nn.Sequential
class
A container that runs modules sequentially.
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
x = mx.random.normal((32, 784))
y = model(x)

Utilities

nn.value_and_grad
function
Transform a function to compute gradients with respect to the model’s trainable parameters.
Parameters:
  • model (nn.Module): The model whose trainable parameters to compute gradients for
  • fn (Callable): The scalar function to compute gradients for
def loss_fn(x, y):
    y_hat = model(x)
    return mx.mean((y_hat - y) ** 2)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(x_batch, y_batch)

# Update model
optimizer.update(model, grads)
nn.average_gradients
function
Average gradients across distributed processes.
Parameters:
  • gradients: Python tree containing the gradients
  • group: Group of processes to average over. Default: None
  • all_reduce_size (int): Group arrays until size exceeds this. Default: 32MB
  • communication_type: Cast to this type before communication. Default: None
  • communication_stream: Stream to use for communication. Default: None
# In a distributed training loop
loss, grads = loss_and_grad_fn(x, y)
grads = nn.average_gradients(grads)
optimizer.update(model, grads)

Loss Functions

MLX provides common loss functions in mlx.nn.losses:
nn.losses.cross_entropy
function
Computes cross entropy loss.
from mlx.nn import losses

logits = model(x)
loss = losses.cross_entropy(logits, targets)
nn.losses.binary_cross_entropy
function
Computes binary cross entropy loss.
loss = losses.binary_cross_entropy(predictions, targets)
nn.losses.mse_loss
function
Computes mean squared error loss.
loss = losses.mse_loss(predictions, targets)
nn.losses.l1_loss
function
Computes L1 loss (mean absolute error).
loss = losses.l1_loss(predictions, targets)
nn.losses.smooth_l1_loss
function
Computes smooth L1 loss (Huber loss).
loss = losses.smooth_l1_loss(predictions, targets)
nn.losses.nll_loss
function
Computes negative log likelihood loss.
loss = losses.nll_loss(log_probs, targets)
nn.losses.kl_div_loss
function
Computes Kullback-Leibler divergence loss.
loss = losses.kl_div_loss(log_predictions, targets)

Training Example

Here’s a complete example of training a neural network with MLX:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Define model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        ]
    
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Create model and optimizer
model = MLP()
optimizer = optim.Adam(learning_rate=1e-3)

# Define loss function
def loss_fn(x, y):
    logits = model(x)
    return nn.losses.cross_entropy(logits, y)

# Create gradient function
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Training loop
model.train()
for epoch in range(num_epochs):
    for x_batch, y_batch in data_loader:
        # Compute loss and gradients
        loss, grads = loss_and_grad_fn(x_batch, y_batch)
        
        # Update parameters
        optimizer.update(model, grads)
        
        # Evaluate (forces computation)
        mx.eval(model.parameters(), optimizer.state)

# Evaluation
model.eval()
predictions = model(x_test)