The mlx.nn module provides an intuitive way to compose neural network layers, initialize their parameters, and manage trainable weights for training and fine-tuning.
Quick Start
Module

Base class for building neural networks with MLX. All layers in mlx.nn.layers subclass this class, and your custom models should also inherit from Module.

A Module can contain other Module instances or mlx.core.array instances, arbitrarily nested in Python lists or dicts.

Key features:
- Recursively extracts all parameters using parameters()
- Supports trainable and frozen parameters
- Provides parameter state management

Here’s a simple example of building a multi-layer perceptron:
Module Methods
parameters()

Recursively return all the mlx.core.array members of this Module as a dict of dicts and lists.

trainable_parameters()

Recursively return all the non-frozen mlx.core.array members of this Module.

update(parameters)

Replace the parameters of this Module with the provided ones. Commonly used by optimizers to update model parameters.
freeze(recurse=True, keys=None, strict=False)

Freeze the Module’s parameters, or some of them. Gradients are not computed for frozen parameters.

Parameters:
- recurse (bool): If True, also freeze the parameters of submodules. Default: True
- keys (str or list[str]): If provided, only freeze these parameters. Default: None
- strict (bool): If True, validate that the passed keys exist. Default: False
unfreeze()

Unfreeze the Module’s parameters, or some of them.
train(mode=True)

Set the model in or out of training mode. Training mode applies to layers like Dropout, which behave differently during training and evaluation.

load_weights(file_or_weights, strict=True)

Update the model’s weights from a .npz or .safetensors file, or from a list.

Parameters:
- file_or_weights: Path to a weights file, or a list of (name, array) pairs
- strict (bool): If True, check that the weights exactly match the parameters. Default: True
save_weights(file)

Save the model’s weights to a file. The .npz and .safetensors formats are supported.

Linear Layers
Linear

Applies an affine transformation to the input:

y = x W^T + b

Parameters:
- input_dims (int): The dimensionality of the input features
- output_dims (int): The dimensionality of the output features
- bias (bool): If False, the layer will not use a bias. Default: True
Bilinear

Applies a bilinear transformation to the inputs:

y_i = x1^T W_i x2 + b_i

Parameters:
- input1_dims (int): The dimensionality of the first input
- input2_dims (int): The dimensionality of the second input
- output_dims (int): The dimensionality of the output
- bias (bool): If False, no bias is used. Default: True
Identity

A placeholder identity operator that returns the input unchanged.
Convolutional Layers
Conv1d

Applies a 1-dimensional convolution over the multi-channel input sequence.

Input shape: NLC, where N = batch, L = sequence length, C = channels

Parameters:
- in_channels (int): Number of input channels
- out_channels (int): Number of output channels
- kernel_size (int): Size of the convolution filters
- stride (int): Stride when applying the filter. Default: 1
- padding (int): Positions to 0-pad the input. Default: 0
- dilation (int): Dilation of the convolution. Default: 1
- groups (int): Number of groups for the convolution. Default: 1
- bias (bool): If True, add a learnable bias. Default: True
Conv2d

Applies a 2-dimensional convolution over the multi-channel input image.

Input shape: NHWC, where N = batch, H = height, W = width, C = channels

Parameters:
- in_channels (int): Number of input channels
- out_channels (int): Number of output channels
- kernel_size (int or tuple): Size of the convolution filters
- stride (int or tuple): Stride when applying the filter. Default: 1
- padding (int or tuple): Positions to 0-pad the input. Default: 0
- dilation (int or tuple): Dilation of the convolution. Default: 1
- groups (int): Number of groups for the convolution. Default: 1
- bias (bool): If True, add a learnable bias. Default: True
Conv3d

Applies a 3-dimensional convolution over the multi-channel input volume. Takes the same parameters as Conv2d, but for 3D inputs.
Normalization Layers
LayerNorm

Applies layer normalization on the inputs. Computes:

y = (x - E[x]) / sqrt(Var[x] + eps) * γ + β

Parameters:
- dims (int): The feature dimension to normalize over
- eps (float): Small constant for numerical stability. Default: 1e-5
- affine (bool): If True, learn an affine transform. Default: True
- bias (bool): If True, include a bias in the affine transform. Default: True
RMSNorm

Applies Root Mean Square normalization. Computes:

y = x / sqrt(mean(x^2) + eps) * γ

Parameters:
- dims (int): The feature dimension to normalize over
- eps (float): Small constant for numerical stability. Default: 1e-5
BatchNorm

Applies batch normalization over the inputs.

Parameters:
- num_features (int): Number of features in the input
- eps (float): Small constant for numerical stability. Default: 1e-5
- momentum (float): Momentum for the running statistics. Default: 0.1
- affine (bool): If True, learn scale and shift. Default: True
- track_running_stats (bool): Track the running mean and variance. Default: True
InstanceNorm

Applies instance normalization on the inputs.

Parameters:
- dims (int): Number of features in the input
- eps (float): Small constant for numerical stability. Default: 1e-5
- affine (bool): If True, learn scale and shift. Default: False
GroupNorm

Applies group normalization on the inputs.

Parameters:
- num_groups (int): Number of groups to separate the channels into
- dims (int): Number of features in the input
- eps (float): Small constant for numerical stability. Default: 1e-5
- affine (bool): If True, learn scale and shift. Default: True
Activation Functions
All activation functions are available both as functional APIs and as Module classes.

ReLU

Applies the Rectified Linear Unit: max(0, x)

LeakyReLU

Applies Leaky ReLU: max(negative_slope * x, x)

Parameters:
- negative_slope (float): Controls the angle of the negative slope. Default: 0.01
ReLU6

Applies ReLU6: min(max(0, x), 6)

GELU

Applies the Gaussian Error Linear Unit.

Parameters:
- approx (str): Approximation to use: 'none', 'precise', 'tanh', or 'fast'. Default: 'none'
SiLU

Applies the Sigmoid Linear Unit (Swish): x * σ(x)

Sigmoid

Applies the Sigmoid function: 1 / (1 + exp(-x))

Tanh

Applies the hyperbolic tangent function.

Softmax

Applies the Softmax function: exp(x_i) / sum_j exp(x_j)

LogSoftmax

Applies the Log Softmax function.
ELU

Applies the Exponential Linear Unit.

Parameters:
- alpha (float): The α value for ELU. Default: 1.0

SELU

Applies the Scaled Exponential Linear Unit.

PReLU

Applies the parametric ReLU: max(0, x) + a * min(0, x)

Parameters:
- num_parameters (int): Number of α parameters to learn. Default: 1
- init (float): Initial value of α. Default: 0.25
Mish

Applies the Mish function: x * tanh(softplus(x))

Hardswish

Applies the Hardswish function: x * min(max(x + 3, 0), 6) / 6

GLU

Applies the Gated Linear Unit: a * σ(b), where a and b are split from the input.

Parameters:
- axis (int): The dimension to split along. Default: -1
Dropout and Regularization
Dropout

Randomly zero elements during training with probability p.

Parameters:
- p (float): Probability of zeroing an element. Default: 0.5
Dropout2d

Randomly zero entire channels during training.

Parameters:
- p (float): Probability of zeroing a channel. Default: 0.5
Dropout3d

Randomly zero entire channels during training (3D version of Dropout2d).

Parameters:
- p (float): Probability of zeroing a channel. Default: 0.5
Embedding Layers
Embedding

A simple lookup table that maps indices to embedding vectors.

Parameters:
- num_embeddings (int): Size of the dictionary of embeddings
- dims (int): Dimension of each embedding vector
Recurrent Layers
RNN

Applies a multi-layer Elman RNN with tanh or ReLU activation.

Parameters:
- input_size (int): Dimension of the input features
- hidden_size (int): Dimension of the hidden state
- num_layers (int): Number of recurrent layers. Default: 1
- nonlinearity (str): Activation function ('tanh' or 'relu'). Default: 'tanh'
- bias (bool): If False, no bias is used. Default: True
GRU

Applies a multi-layer Gated Recurrent Unit.

Parameters:
- input_size (int): Dimension of the input features
- hidden_size (int): Dimension of the hidden state
- num_layers (int): Number of recurrent layers. Default: 1
- bias (bool): If False, no bias is used. Default: True
LSTM

Applies a multi-layer Long Short-Term Memory.

Parameters:
- input_size (int): Dimension of the input features
- hidden_size (int): Dimension of the hidden state
- num_layers (int): Number of recurrent layers. Default: 1
- bias (bool): If False, no bias is used. Default: True
Container Layers
Sequential

A container that runs the given modules in sequence.
Utilities
value_and_grad

Transform a function to compute both its value and its gradients with respect to a model's trainable parameters.

Parameters:
- model (nn.Module): The model whose trainable parameters to compute gradients for
- fn (Callable): The scalar function to compute gradients for
average_gradients

Average gradients across distributed processes.

Parameters:
- gradients: Python tree containing the gradients
- group: Group of processes to average over. Default: None
- all_reduce_size (int): Group arrays until their combined size exceeds this threshold. Default: 32 MB
- communication_type: Cast gradients to this type before communication. Default: None
- communication_stream: Stream to use for communication. Default: None
Loss Functions
MLX provides common loss functions in mlx.nn.losses:
- cross_entropy: Computes the cross entropy loss.
- binary_cross_entropy: Computes the binary cross entropy loss.
- mse_loss: Computes the mean squared error loss.
- l1_loss: Computes the L1 loss (mean absolute error).
- smooth_l1_loss: Computes the smooth L1 loss (Huber loss).
- nll_loss: Computes the negative log likelihood loss.
- kl_div_loss: Computes the Kullback-Leibler divergence loss.