The mlx.optimizers module provides optimizers for training neural networks. All optimizers work with both mlx.nn modules and pure mlx.core functions.

Quick Start

Here’s a typical training loop with an optimizer:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Create model and optimizer
model = MLP(num_layers, input_dims, hidden_dim, output_dims)  # MLP and its hyperparameters are defined elsewhere
mx.eval(model.parameters())

# Create gradient function and optimizer
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.SGD(learning_rate=0.01)

for epoch in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss, grads = loss_and_grad_fn(model, X, y)
        
        # Update model with gradients
        optimizer.update(model, grads)
        
        # Evaluate parameters and optimizer state
        mx.eval(model.parameters(), optimizer.state)

Base Optimizer

optim.Optimizer
class
Base class for all optimizers. Allows implementing optimizers on a per-parameter basis and applying them to parameter trees. Key Methods:
  • update(model, gradients): Apply gradients to model parameters
  • init(parameters): Initialize optimizer state
  • apply_gradients(gradients, parameters): Apply gradients and return updated parameters

Optimizer Methods

update(model, gradients)
method
Apply the gradients to the parameters of the model and update the model. Parameters:
  • model (nn.Module): An MLX module to be updated
  • gradients (dict): Python tree of gradients, typically from nn.value_and_grad
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
init(parameters)
method
Initialize the optimizer’s state. Calling it is optional: the optimizer initializes itself on the first update if init is not called explicitly. Parameters:
  • parameters (dict): Python tree of parameters
optimizer = optim.SGD(learning_rate=0.1, momentum=0.9)
model = nn.Linear(2, 2)
optimizer.init(model.trainable_parameters())
print(optimizer.state.keys())
# dict_keys(['step', 'learning_rate', 'weight', 'bias'])
state
property
The optimizer’s state dictionary. Contains the step count, the learning rate, and optimizer-specific state (e.g., momentum).
print(optimizer.state)
print(f"Step: {optimizer.step}")
print(f"Learning rate: {optimizer.learning_rate}")

Common Optimizers

optim.SGD
class
Stochastic Gradient Descent optimizer. Update rule:
  v_{t+1} = μ v_t + (1 − τ) g_t
  w_{t+1} = w_t − λ v_{t+1}
Parameters:
  • learning_rate (float or callable): The learning rate λ
  • momentum (float): The momentum strength μ. Default: 0
  • weight_decay (float): The weight decay (L2 penalty). Default: 0
  • dampening (float): Dampening for momentum τ. Default: 0
  • nesterov (bool): Enables Nesterov momentum. Default: False
# Basic SGD
optimizer = optim.SGD(learning_rate=0.01)

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.01, momentum=0.9)

# SGD with Nesterov momentum
optimizer = optim.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
optim.Adam
class
Adam optimizer. Parameters:
  • learning_rate (float or callable): The learning rate λ
  • betas (Tuple[float, float]): Coefficients (β₁, β₂) for running averages. Default: (0.9, 0.999)
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-8
  • bias_correction (bool): If True, apply bias correction. Default: False
optimizer = optim.Adam(learning_rate=1e-3)

# With custom betas
optimizer = optim.Adam(learning_rate=1e-3, betas=(0.9, 0.999))
optim.AdamW
class
AdamW optimizer with decoupled weight decay. Parameters:
  • learning_rate (float or callable): The learning rate α
  • betas (Tuple[float, float]): Coefficients (β₁, β₂) for running averages. Default: (0.9, 0.999)
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-8
  • weight_decay (float): The weight decay λ. Default: 0.01
  • bias_correction (bool): If True, apply bias correction. Default: False
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)
optim.Adamax
class
Adamax optimizer, a variant of Adam based on the infinity norm. Parameters:
  • learning_rate (float or callable): The learning rate λ
  • betas (Tuple[float, float]): Coefficients (β₁, β₂). Default: (0.9, 0.999)
  • eps (float): Term ε added to denominator. Default: 1e-8
optimizer = optim.Adamax(learning_rate=1e-3)
optim.Lion
class
Lion optimizer. It is recommended to use a learning rate 3-10x smaller than for AdamW and a weight decay 3-10x larger. Parameters:
  • learning_rate (float or callable): The learning rate η
  • betas (Tuple[float, float]): Coefficients (β₁, β₂). Default: (0.9, 0.99)
  • weight_decay (float): The weight decay λ. Default: 0.0
# Lion typically needs smaller learning rate than AdamW
optimizer = optim.Lion(learning_rate=1e-4, weight_decay=0.1)
optim.Adagrad
class
Adagrad optimizer. Parameters:
  • learning_rate (float or callable): The learning rate λ
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-8
optimizer = optim.Adagrad(learning_rate=0.01)
optim.AdaDelta
class
AdaDelta optimizer with a learning rate. Parameters:
  • learning_rate (float or callable): The learning rate λ
  • rho (float): Coefficient ρ for computing running average of squared gradients. Default: 0.9
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-6
optimizer = optim.AdaDelta(learning_rate=1.0, rho=0.9)
optim.RMSprop
class
RMSprop optimizer. Parameters:
  • learning_rate (float or callable): The learning rate λ
  • alpha (float): The smoothing constant α. Default: 0.99
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-8
optimizer = optim.RMSprop(learning_rate=0.01, alpha=0.99)
optim.Adafactor
class
Adafactor optimizer with adaptive learning rates and sublinear memory cost. Parameters:
  • learning_rate (float or callable): The learning rate. Default: None
  • eps (tuple): (ε₁, ε₂) for numerical stability and parameter scaling. Default: (1e-30, 1e-3)
  • clip_threshold (float): Clips unscaled update at this threshold. Default: 1.0
  • decay_rate (float): Coefficient for running average of squared gradient. Default: -0.8
  • beta_1 (float): If set, use first moment. Default: None
  • weight_decay (float): The weight decay λ. Default: 0.0
  • scale_parameter (bool): If True, scale learning rate by RMS of parameters. Default: True
  • relative_step (bool): If True, use relative step size. Default: True
  • warmup_init (bool): If True, calculate step size by current step. Default: False
optimizer = optim.Adafactor(learning_rate=None, relative_step=True)
optim.Muon
class
Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer. Note: Muon may be sub-optimal for embedding layers, final fully connected layers, or 0D/1D parameters; use a different optimizer (e.g., AdamW) for those. Parameters:
  • learning_rate (float or callable): The learning rate
  • momentum (float): The momentum strength. Default: 0.95
  • weight_decay (float): The weight decay (L2 penalty). Default: 0.01
  • nesterov (bool): Enables Nesterov momentum. Default: True
  • ns_steps (int): Number of Newton-Schulz iteration steps. Default: 5
optimizer = optim.Muon(learning_rate=0.01, momentum=0.95, weight_decay=0.01)

Multi-Optimizer

optim.MultiOptimizer
class
Wraps multiple optimizers with weight predicates, so that different optimizers can be used for different parameters. Parameters:
  • optimizers (list[Optimizer]): List of optimizers to delegate to
  • filters (list[Callable]): List of predicates, one fewer than optimizers. Parameters that match no predicate fall through to the last optimizer.
# Use AdamW for most parameters, but SGD for biases
optimizer = optim.MultiOptimizer(
    optimizers=[
        optim.SGD(learning_rate=0.01),
        optim.AdamW(learning_rate=1e-3)
    ],
    filters=[
        lambda k, v: "bias" in k  # Use SGD for biases
        # AdamW is fallback for everything else
    ]
)

Learning Rate Schedulers

Learning rate schedulers can be passed directly to optimizers:
optim.exponential_decay
function
Make an exponential decay scheduler. Parameters:
  • init (float): Initial value
  • decay_rate (float): Multiplicative factor to decay by
lr_schedule = optim.exponential_decay(1e-1, 0.9)
optimizer = optim.SGD(learning_rate=lr_schedule)

# Learning rate decays exponentially with each step
print(optimizer.learning_rate)  # 0.1
for _ in range(5):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # 0.06561
optim.step_decay
function
Make a step decay scheduler. Parameters:
  • init (float): Initial value
  • decay_rate (float): Multiplicative factor to decay by
  • step_size (int): Decay every step_size steps
lr_schedule = optim.step_decay(1e-1, 0.9, step_size=10)
optimizer = optim.SGD(learning_rate=lr_schedule)

# Learning rate stays constant for 10 steps, then decays
for _ in range(21):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # 0.081
optim.cosine_decay
function
Make a cosine decay scheduler. Parameters:
  • init (float): Initial value
  • decay_steps (int): Number of steps to decay over
  • end (float): Final value to decay to. Default: 0.0
lr_schedule = optim.cosine_decay(1e-1, decay_steps=1000)
optimizer = optim.SGD(learning_rate=lr_schedule)

# Learning rate follows cosine curve from init to end
optim.linear_schedule
function
Make a linear scheduler. Parameters:
  • init (float): Initial value
  • end (float): Final value
  • steps (int): Number of steps to apply schedule over
lr_schedule = optim.linear_schedule(0, 1e-1, steps=100)
optimizer = optim.Adam(learning_rate=lr_schedule)

# Learning rate linearly increases from 0 to 0.1 over 100 steps
print(optimizer.learning_rate)  # 0.0
for _ in range(101):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # 0.1
optim.join_schedules
function
Join multiple schedules to create a new schedule. Parameters:
  • schedules (list[Callable]): List of schedules
  • boundaries (list[int]): Boundaries indicating when to transition between schedules
# Warmup with linear schedule, then cosine decay
linear = optim.linear_schedule(0, 1e-1, steps=10)
cosine = optim.cosine_decay(1e-1, decay_steps=200)
lr_schedule = optim.join_schedules([linear, cosine], boundaries=[10])

optimizer = optim.Adam(learning_rate=lr_schedule)
print(optimizer.learning_rate)  # 0.0 (linear warmup)

for _ in range(12):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # ~0.0999 (cosine decay)

Gradient Clipping

optim.clip_grad_norm
function
Clips the global norm of the gradients. Ensures that the global norm of the gradients does not exceed max_norm, scaling the gradients down proportionally if needed. Parameters:
  • grads (dict): Dictionary containing gradient arrays
  • max_norm (float): Maximum allowed global norm of gradients
Returns:
  • (dict, float): Clipped gradients and original gradient norm
loss, grads = loss_and_grad_fn(model, x, y)

# Clip gradients to max norm of 1.0
clipped_grads, total_norm = optim.clip_grad_norm(grads, max_norm=1.0)

optimizer.update(model, clipped_grads)
mx.eval(model.parameters(), optimizer.state)

print(f"Gradient norm: {total_norm}")

Saving and Loading

To serialize an optimizer, save its state. To load an optimizer, load and set the saved state.
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim

# Create and use optimizer
optimizer = optim.Adam(learning_rate=1e-2)
model = {"w": mx.zeros((5, 5))}
grads = {"w": mx.ones((5, 5))}
optimizer.update(model, grads)

# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))

# Later: recreate optimizer and load state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note: Not every optimizer configuration parameter is saved in the state. For example, for Adam the learning rate is saved but betas and eps are not. As a rule of thumb, if a parameter can be scheduled, it will be included in the optimizer state.

Complete Training Example

Here’s a complete example showing optimizer usage in a training loop:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten

# Define model
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        ]
    
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Initialize model
model = MLP(784, 128, 10)
mx.eval(model.parameters())

# Define loss function
def loss_fn(model, x, y):
    logits = model(x)
    return nn.losses.cross_entropy(logits, y)

# Create optimizer with learning rate schedule
lr_schedule = optim.join_schedules(
    [
        optim.linear_schedule(0, 1e-3, steps=100),  # Warmup
        optim.cosine_decay(1e-3, decay_steps=1000)  # Decay
    ],
    boundaries=[100]
)
optimizer = optim.AdamW(learning_rate=lr_schedule, weight_decay=0.01)

# Create gradient function
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for x_batch, y_batch in dataloader:  # dataloader yields (x, y) batches; defined elsewhere
        # Compute loss and gradients
        loss, grads = loss_and_grad_fn(model, x_batch, y_batch)
        
        # Clip gradients
        grads, grad_norm = optim.clip_grad_norm(grads, max_norm=1.0)
        
        # Update model
        optimizer.update(model, grads)
        
        # Evaluate
        mx.eval(model.parameters(), optimizer.state, loss)
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch}: Loss = {epoch_loss / len(dataloader):.4f}, "
          f"LR = {optimizer.learning_rate.item():.6f}")

# Save optimizer state
state = dict(tree_flatten(optimizer.state))
mx.save_safetensors("optimizer_checkpoint.safetensors", state)

Using Multiple Optimizers

For advanced use cases, you can use different optimizers for different parts of your model:
# Use different optimizers for different parameter groups
def is_embedding(key, value):
    return "embedding" in key

def is_output_layer(key, value):
    return "output" in key

optimizer = optim.MultiOptimizer(
    optimizers=[
        optim.SGD(learning_rate=0.001),      # For embeddings
        optim.SGD(learning_rate=0.01),       # For output layer  
        optim.AdamW(learning_rate=1e-3)      # For everything else (fallback)
    ],
    filters=[
        is_embedding,
        is_output_layer
    ]
)