MLX enables efficient inference of large-ish transformers on Apple silicon without compromising on ease of use. In this example, we’ll create an inference script for the Llama family of transformer models, with the model itself defined in fewer than 200 lines of Python.
Implementing the Model
We’ll use the neural network building blocks defined in the mlx.nn module to concisely define the model architecture.
Attention Layer
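We’ll start with the Llama attention layer, which uses rotary positional encoding (RoPE). Our implementation includes an optional key/value cache: during generation, the keys and values computed for earlier tokens are stored and reused, so each new token only needs its own projections. We use mlx.nn.Linear for all projections and mlx.nn.RoPE for positional encoding.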
Encoder Layer
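The encoder layer normalizes the input of each sublayer (attention, then a feed-forward network) with RMS normalization, adds each sublayer’s output back through a residual connection, and uses the SwiGLU activation in the feed-forward network. For the normalization we use the mlx.nn.RMSNorm layer that’s already provided.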
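Written out per token, and assuming the usual Llama pre-norm arrangement, the layer computes the following, where $d$ is the model width, $g$ the learned RMSNorm scale, $\epsilon$ a small constant, $\sigma$ the logistic sigmoid, and $W_1, W_2$ (from $d$ to the MLP width) and $W_3$ (back to $d$) bias-free linear projections:

$$
\begin{aligned}
\operatorname{RMSNorm}(x) &= g \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon}} \\
h &= x + \operatorname{Attention}\big(\operatorname{RMSNorm}_1(x)\big) \\
u &= \operatorname{RMSNorm}_2(h) \\
\text{output} &= h + W_3\big(\operatorname{SiLU}(W_1 u) \odot W_2 u\big), \qquad \operatorname{SiLU}(z) = z\,\sigma(z)
\end{aligned}
$$

The two RMSNorm instances have separate scales $g$; the gating product $\operatorname{SiLU}(W_1 u) \odot W_2 u$ is what makes the feed-forward network a SwiGLU.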
Full Model
To implement any Llama model, we simply combine LlamaEncoderLayer instances with an mlx.nn.Embedding that embeds the input tokens, followed by a final RMSNorm and a linear projection to vocabulary logits.
We use a plain Python list to hold the encoder layers, but model.parameters() still includes the parameters of every layer in it.
Generation
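The model’s __call__ method processes the full sequence on every call and discards the per-layer caches, which suits training but not inference. For generation we add a method that processes the prompt once, keeps each layer’s key/value cache, and then samples tokens autoregressively, feeding each sampled token back into the model together with the cache.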
Using the Model
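We now have everything needed to create a Llama model and sample tokens from it. Because MLX evaluates lazily, nothing is actually computed, not even the randomly initialized parameters, until the results are requested, for example with mx.eval(model.parameters()).
Converting Weights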
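The heart of a conversion is renaming each PyTorch parameter to the matching MLX attribute path. A minimal sketch of that renaming, assuming Meta’s reference checkpoint names on the PyTorch side (tok_embeddings, attention.wq to attention.wo, attention_norm, ffn_norm, feed_forward.w1 to w3, norm, output, rope.freqs) and MLX attributes named embedding, attention.query_proj and so on, norm1, norm2, linear1 to linear3 and out_proj; both naming schemes are assumptions to adjust for your checkpoint and model:

```python
# Hypothetical renaming table: Meta-style PyTorch names -> MLX attribute paths.
RENAMES = [
    ("tok_embeddings", "embedding"),
    ("attention.wq", "attention.query_proj"),
    ("attention.wk", "attention.key_proj"),
    ("attention.wv", "attention.value_proj"),
    ("attention.wo", "attention.out_proj"),
    ("attention_norm", "norm1"),
    ("ffn_norm", "norm2"),
    ("feed_forward.w1", "linear1"),  # gate projection
    ("feed_forward.w3", "linear2"),  # up projection
    ("feed_forward.w2", "linear3"),  # down projection
    ("output", "out_proj"),
]


def map_torch_to_mlx(key):
    """Return the MLX name for a PyTorch parameter name, or None to drop it."""
    if key.startswith("rope."):
        return None  # precomputed RoPE frequencies; nn.RoPE computes its own
    for old, new in RENAMES:
        key = key.replace(old, new)
    return key


for name in ["tok_embeddings.weight", "layers.2.attention.wq.weight",
             "layers.2.feed_forward.w2.weight", "rope.freqs"]:
    print(name, "->", map_torch_to_mlx(name))
```

Applied to a real checkpoint, each kept tensor would also be converted to a NumPy array (for example with tensor.numpy(); bfloat16 tensors need a cast first, since NumPy has no bfloat16) and the results written with numpy.savez, a format mx.load can read.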
To use actual Llama weights, you need to convert the PyTorch weights to a format MLX can load, mapping each PyTorch parameter name to the corresponding MLX parameter name.
Loading Weights and Benchmarking
Load the converted weights with mx.load and pass them through mlx.utils.tree_unflatten before updating the model with its update method.
The tree_unflatten function transforms flat, dot-separated keys such as layers.2.attention.query_proj.weight into nested containers (dictionaries, with lists for integer components such as the 2) that can update the model.