The mlx.core.fast module provides highly optimized implementations of common operations used in deep learning models. These functions are specifically tuned for performance and should be used instead of manual implementations when available.
Overview
Fast operations include optimized implementations of normalization layers, attention mechanisms, and positional encodings that are commonly used in transformer models.
Functions
rms_norm
mlx.core.fast.rms_norm(
x: array,
weight: array,
eps: float = 1e-5
) -> array
Root Mean Square Layer Normalization.
RMS normalization scales the input by its root mean square. It is faster than standard layer normalization because it skips the mean subtraction and variance computation.
Parameters:
x (array): Input array to normalize
weight (array): Learnable scale parameter
eps (float): Small constant for numerical stability. Default: 1e-5
Returns:
- Normalized array with the same shape as input
Example:
import mlx.core as mx
import mlx.core.fast as fast
# Input tensor
x = mx.random.normal((2, 10, 512))
weight = mx.ones((512,))
# Apply RMS normalization
out = fast.rms_norm(x, weight, eps=1e-6)
print(out.shape) # (2, 10, 512)
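As a reference for the computation `fast.rms_norm` performs, here is a minimal NumPy sketch (the helper name `rms_norm_ref` is ours, not part of MLX):

```python
import numpy as np

def rms_norm_ref(x, weight, eps=1e-5):
    # Root mean square over the last axis; note: no mean subtraction.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

x = np.array([[1.0, 2.0, 3.0, 4.0]])
weight = np.ones(4)
out = rms_norm_ref(x, weight)
print(out.shape)  # (1, 4)
```

Because only the root mean square is divided out, the result has (approximately) unit RMS along the last axis but its mean is not zero, which is the key difference from layer normalization.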
layer_norm
mlx.core.fast.layer_norm(
x: array,
weight: Optional[array] = None,
bias: Optional[array] = None,
eps: float = 1e-5
) -> array
Optimized layer normalization.
This function provides a faster implementation of layer normalization compared to the standard approach.
Parameters:
x (array): Input array to normalize
weight (array, optional): Learnable scale parameter
bias (array, optional): Learnable shift parameter
eps (float): Small constant for numerical stability. Default: 1e-5
Returns:
- Normalized array with the same shape as input
Example:
import mlx.core as mx
import mlx.core.fast as fast
# Input tensor
x = mx.random.normal((4, 128, 768))
weight = mx.ones((768,))
bias = mx.zeros((768,))
# Apply layer normalization
out = fast.layer_norm(x, weight, bias)
print(out.shape) # (4, 128, 768)
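The underlying computation can be sketched in NumPy for comparison (the helper name `layer_norm_ref` is ours, not part of MLX):

```python
import numpy as np

def layer_norm_ref(x, weight=None, bias=None, eps=1e-5):
    # Normalize over the last axis: subtract the mean, divide by the std.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mean) / np.sqrt(var + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y

x = np.random.randn(4, 128, 768).astype(np.float32)
out = layer_norm_ref(x, np.ones(768), np.zeros(768))
print(out.shape)  # (4, 128, 768)
```

With unit weight and zero bias, each feature vector of the output has zero mean and (up to `eps`) unit standard deviation.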
rope
mlx.core.fast.rope(
x: array,
dims: int,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
offset: int = 0
) -> array
Rotary Positional Embedding (RoPE).
RoPE encodes positional information by rotating pairs of dimensions in the input. This is commonly used in transformer models like LLaMA and GPT-NeoX.
Parameters:
x (array): Input array with shape (..., sequence_length, features)
dims (int): Number of dimensions to apply the rotation to
traditional (bool): Use traditional RoPE implementation. Default: False
base (float): Base for computing rotation frequencies. Default: 10000
scale (float): Scale factor for positions. Default: 1.0
offset (int): Offset to add to positions. Default: 0
Returns:
- Array with rotary positional embeddings applied
Example:
import mlx.core as mx
import mlx.core.fast as fast
# Query or key tensor from attention
# Shape: (batch, num_heads, seq_len, head_dim)
q = mx.random.normal((2, 8, 512, 64))
# Apply RoPE
q_rotated = fast.rope(q, dims=64, traditional=True)
print(q_rotated.shape) # (2, 8, 512, 64)
# With offset for cached keys
q_new = mx.random.normal((2, 8, 1, 64))
q_new_rotated = fast.rope(q_new, dims=64, offset=512)
print(q_new_rotated.shape) # (2, 8, 1, 64)
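The rotation itself can be sketched in NumPy. This reference implements the traditional variant, which rotates consecutive dimension pairs; the helper name `rope_ref` is ours, not part of MLX:

```python
import numpy as np

def rope_ref(x, dims, base=10000.0, scale=1.0, offset=0):
    # Traditional RoPE: rotate pairs (x[..., 2i], x[..., 2i+1]) by
    # position-dependent angles theta_i = pos * base^(-2i / dims).
    seq_len = x.shape[-2]
    pos = (np.arange(offset, offset + seq_len) * scale)[:, None]  # (seq, 1)
    freqs = base ** (-np.arange(0, dims, 2) / dims)               # (dims/2,)
    angles = pos * freqs                                          # (seq, dims/2)
    cos, sin = np.cos(angles), np.sin(angles)
    out = x.copy()
    x1, x2 = x[..., 0:dims:2], x[..., 1:dims:2]
    out[..., 0:dims:2] = x1 * cos - x2 * sin
    out[..., 1:dims:2] = x1 * sin + x2 * cos
    return out

q = np.random.randn(2, 8, 512, 64)
print(rope_ref(q, dims=64).shape)  # (2, 8, 512, 64)
```

Two properties worth noting: position 0 (with `offset=0`) is left unchanged, and because each pair is rotated, vector norms are preserved. The `offset` parameter shifts the starting position, which is what makes the cached-key pattern above work.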
scaled_dot_product_attention
mlx.core.fast.scaled_dot_product_attention(
queries: array,
keys: array,
values: array,
scale: Optional[float] = None,
mask: Optional[array] = None
) -> array
Optimized scaled dot-product attention.
Computes the attention output using the formula:
softmax(Q @ K^T / sqrt(d_k)) @ V
Parameters:
queries (array): Query tensor of shape (..., seq_len_q, d_k)
keys (array): Key tensor of shape (..., seq_len_k, d_k)
values (array): Value tensor of shape (..., seq_len_k, d_v)
scale (float, optional): Scale factor. If None, uses 1 / sqrt(d_k)
mask (array, optional): Attention mask to apply
Returns:
- Attention output of shape (..., seq_len_q, d_v)
Example:
import mlx.core as mx
import mlx.core.fast as fast
batch_size = 2
num_heads = 8
seq_len = 128
head_dim = 64
# Create query, key, value tensors
q = mx.random.normal((batch_size, num_heads, seq_len, head_dim))
k = mx.random.normal((batch_size, num_heads, seq_len, head_dim))
v = mx.random.normal((batch_size, num_heads, seq_len, head_dim))
# Create causal mask
mask = mx.tril(mx.ones((seq_len, seq_len)))
mask = mx.where(mask, 0.0, float('-inf'))
# Compute attention (scale = 1 / sqrt(head_dim))
out = fast.scaled_dot_product_attention(q, k, v, scale=1.0 / head_dim**0.5, mask=mask)
print(out.shape) # (2, 8, 128, 64)
The fast attention implementation may have different numerical precision compared to manual implementations. For most use cases, the difference is negligible and the performance benefits are significant.
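For reference, the computation can be sketched manually in NumPy (the helper name `sdpa_ref` is ours, not part of MLX):

```python
import numpy as np

def sdpa_ref(q, k, v, scale=None, mask=None):
    # softmax(Q @ K^T * scale + mask) @ V
    d_k = q.shape[-1]
    if scale is None:
        scale = 1.0 / np.sqrt(d_k)
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    if mask is not None:
        scores = scores + mask  # additive mask: -inf blocks attention
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

q = np.random.randn(2, 8, 16, 64)
k = np.random.randn(2, 8, 16, 64)
v = np.random.randn(2, 8, 16, 64)
mask = np.where(np.tril(np.ones((16, 16), dtype=bool)), 0.0, -np.inf)
out = sdpa_ref(q, k, v, mask=mask)
print(out.shape)  # (2, 8, 16, 64)
```

With the causal mask, the first query position can only attend to the first key, so its output row equals the first value vector exactly.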
metal_kernel
mlx.core.fast.metal_kernel(
name: str,
input_names: List[str],
output_names: List[str],
source: str,
header: str = "",
ensure_row_contiguous: bool = True,
atomic_outputs: bool = False
)
Compile and register a custom Metal kernel.
This function allows you to write custom GPU kernels in Metal Shading Language for operations not provided by MLX.
Parameters:
name (str): Name of the kernel function
input_names (List[str]): Names of input arrays
output_names (List[str]): Names of output arrays
source (str): Metal kernel source code
header (str): Additional header code. Default: ""
ensure_row_contiguous (bool): Ensure inputs are row-contiguous. Default: True
atomic_outputs (bool): Use atomic operations for outputs. Default: False
Example:
import mlx.core as mx
import mlx.core.fast as fast
# Define a simple element-wise multiply kernel body.
# The grid is sized to the number of elements, so no bounds check is needed.
source = '''
    uint elem = thread_position_in_grid.x;
    out[elem] = in0[elem] * in1[elem];
'''
# Register the kernel
kernel = fast.metal_kernel(
    name="elem_multiply",
    input_names=["in0", "in1"],
    output_names=["out"],
    source=source
)
# Use the kernel: the grid, threadgroup, and output shapes/dtypes are
# supplied at call time, and a list of output arrays is returned
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
outputs = kernel(
    inputs=[a, b],
    grid=(a.size, 1, 1),
    threadgroup=(a.size, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
print(outputs[0])  # array([4, 10, 18], dtype=float32)
Custom Metal kernels require knowledge of Metal Shading Language and GPU programming. Ensure your kernels are properly tested and handle edge cases.
cuda_kernel
mlx.core.fast.cuda_kernel(
name: str,
input_names: List[str],
output_names: List[str],
source: str,
header: str = "",
ensure_row_contiguous: bool = True,
atomic_outputs: bool = False
)
Compile and register a custom CUDA kernel.
Similar to metal_kernel, but for CUDA devices. This allows you to write custom GPU kernels in CUDA C for operations not provided by MLX.
Parameters:
name (str): Name of the kernel function
input_names (List[str]): Names of input arrays
output_names (List[str]): Names of output arrays
source (str): CUDA kernel source code
header (str): Additional header code. Default: ""
ensure_row_contiguous (bool): Ensure inputs are row-contiguous. Default: True
atomic_outputs (bool): Use atomic operations for outputs. Default: False
Example:
import mlx.core as mx
import mlx.core.fast as fast
# Define a simple element-wise add kernel body.
# The launch grid is sized to the number of elements, so no bounds check is needed.
source = '''
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    out[idx] = in0[idx] + in1[idx];
'''
# Register the kernel
kernel = fast.cuda_kernel(
    name="elem_add",
    input_names=["in0", "in1"],
    output_names=["out"],
    source=source
)
# Use the kernel: as with metal_kernel, the grid, threadgroup, and
# output shapes/dtypes are supplied at call time
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
outputs = kernel(
    inputs=[a, b],
    grid=(a.size, 1, 1),
    threadgroup=(a.size, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
print(outputs[0])  # array([5, 7, 9], dtype=float32)
Custom CUDA kernels are only available when MLX is built with CUDA support. The kernel will raise an error if CUDA is not available.
Best Practices
- Use fast operations: Always prefer functions from mlx.core.fast over manual implementations for common patterns
- Batch operations: Process multiple sequences together for better GPU utilization
- Memory layout: Ensure your tensors are properly shaped for optimal memory access patterns
- Kernel fusion: MLX automatically fuses operations, but using fast functions ensures the best fusion opportunities
See Also