The mlx.core.fast module provides highly optimized implementations of common operations used in deep learning models. These functions are specifically tuned for performance and should be used instead of manual implementations when available.

Overview

Fast operations include optimized implementations of normalization layers, attention mechanisms, and positional encodings that are commonly used in transformer models.

Functions

rms_norm

mlx.core.fast.rms_norm(
    x: array,
    weight: array,
    eps: float = 1e-5
) -> array
Root Mean Square Layer Normalization. RMS normalization normalizes the input by its root mean square, which is faster than standard layer normalization because it doesn't compute the mean.

Parameters:
  • x (array): Input array to normalize
  • weight (array): Learnable scale parameter
  • eps (float): Small constant for numerical stability. Default: 1e-5
Returns:
  • Normalized array with the same shape as input
Example:
import mlx.core as mx
import mlx.core.fast as fast

# Input tensor
x = mx.random.normal((2, 10, 512))
weight = mx.ones((512,))

# Apply RMS normalization
out = fast.rms_norm(x, weight, eps=1e-6)
print(out.shape)  # (2, 10, 512)
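For reference, the math behind `rms_norm` can be sketched in a few lines of NumPy. This is an illustrative reference implementation, not the MLX kernel itself; note the absence of mean subtraction compared to layer normalization.

```python
import numpy as np

def rms_norm_ref(x, weight, eps=1e-5):
    # Root mean square over the last axis; no mean is subtracted
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

x = np.random.default_rng(0).normal(size=(2, 10, 512)).astype(np.float32)
out = rms_norm_ref(x, np.ones(512, dtype=np.float32))
print(out.shape)  # (2, 10, 512)
```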

layer_norm

mlx.core.fast.layer_norm(
    x: array,
    weight: Optional[array] = None,
    bias: Optional[array] = None,
    eps: float = 1e-5
) -> array
Optimized layer normalization. Normalizes the input over the last axis using its mean and variance, with an optional learnable scale and shift. This is faster than composing the normalization from individual array operations.

Parameters:
  • x (array): Input array to normalize
  • weight (array, optional): Learnable scale parameter
  • bias (array, optional): Learnable shift parameter
  • eps (float): Small constant for numerical stability. Default: 1e-5
Returns:
  • Normalized array with the same shape as input
Example:
import mlx.core as mx
import mlx.core.fast as fast

# Input tensor
x = mx.random.normal((4, 128, 768))
weight = mx.ones((768,))
bias = mx.zeros((768,))

# Apply layer normalization
out = fast.layer_norm(x, weight, bias)
print(out.shape)  # (4, 128, 768)
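As a point of comparison with `rms_norm`, here is a NumPy reference sketch of what `layer_norm` computes (standardization over the last axis, then optional scale and shift); the MLX implementation fuses these steps.

```python
import numpy as np

def layer_norm_ref(x, weight=None, bias=None, eps=1e-5):
    # Standardize over the last axis, then optionally scale and shift
    mu = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    out = (x - mu) / np.sqrt(var + eps)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    return out

x = np.random.default_rng(0).normal(size=(4, 128, 768)).astype(np.float32)
out = layer_norm_ref(x, np.ones(768), np.zeros(768))
print(out.shape)  # (4, 128, 768)
```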

rope

mlx.core.fast.rope(
    x: array,
    dims: int,
    traditional: bool = False,
    base: float = 10000,
    scale: float = 1.0,
    offset: int = 0
) -> array
Rotary Positional Embedding (RoPE). RoPE encodes positional information by rotating pairs of dimensions in the input. It is used in transformer models such as LLaMA and GPT-NeoX.

Parameters:
  • x (array): Input array with shape (..., sequence_length, features)
  • dims (int): Number of dimensions to apply the rotation to
  • traditional (bool): Use traditional RoPE implementation. Default: False
  • base (float): Base for computing rotation frequencies. Default: 10000
  • scale (float): Scale factor for positions. Default: 1.0
  • offset (int): Offset to add to positions. Default: 0
Returns:
  • Array with rotary positional embeddings applied
Example:
import mlx.core as mx
import mlx.core.fast as fast

# Query or key tensor from attention
# Shape: (batch, num_heads, seq_len, head_dim)
q = mx.random.normal((2, 8, 512, 64))

# Apply RoPE
q_rotated = fast.rope(q, dims=64, traditional=True)
print(q_rotated.shape)  # (2, 8, 512, 64)

# With offset for cached keys
q_new = mx.random.normal((2, 8, 1, 64))
q_new_rotated = fast.rope(q_new, dims=64, offset=512)
print(q_new_rotated.shape)  # (2, 8, 1, 64)
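To make the "rotating pairs of dimensions" idea concrete, here is a NumPy sketch of traditional RoPE: consecutive feature pairs `(x[2i], x[2i+1])` in the first `dims` features are rotated by a position-dependent angle. This is a reference for the math only; the MLX kernel is the one to use in practice.

```python
import numpy as np

def rope_ref(x, dims, base=10000.0, scale=1.0, offset=0):
    # Traditional RoPE: rotate consecutive feature pairs by angles
    # pos * base^(-2i/dims), where pos is the (scaled, offset) position.
    seq_len = x.shape[-2]
    freqs = base ** (-np.arange(dims // 2) * 2.0 / dims)   # (dims/2,)
    pos = scale * np.arange(offset, offset + seq_len)      # (seq_len,)
    angles = pos[:, None] * freqs[None, :]                 # (seq_len, dims/2)
    cos, sin = np.cos(angles), np.sin(angles)
    out = x.copy()
    x1, x2 = x[..., 0:dims:2], x[..., 1:dims:2]
    out[..., 0:dims:2] = x1 * cos - x2 * sin
    out[..., 1:dims:2] = x1 * sin + x2 * cos
    return out

q = np.random.default_rng(1).normal(size=(2, 8, 16, 64))
q_rot = rope_ref(q, dims=64)
print(q_rot.shape)  # (2, 8, 16, 64)
```

Because each pair is rotated, RoPE preserves the norm of the rotated features, and position 0 (with `offset=0`) is left unchanged.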

scaled_dot_product_attention

mlx.core.fast.scaled_dot_product_attention(
    queries: array,
    keys: array,
    values: array,
    scale: Optional[float] = None,
    mask: Optional[array] = None
) -> array
Optimized scaled dot-product attention. Computes the attention output using the formula:

softmax(Q @ K^T / sqrt(d_k)) @ V

Parameters:
  • queries (array): Query tensor of shape (..., seq_len_q, d_k)
  • keys (array): Key tensor of shape (..., seq_len_k, d_k)
  • values (array): Value tensor of shape (..., seq_len_k, d_v)
  • scale (float, optional): Scale factor. If None, uses 1 / sqrt(d_k)
  • mask (array, optional): Attention mask to apply
Returns:
  • Attention output of shape (..., seq_len_q, d_v)
Example:
import mlx.core as mx
import mlx.core.fast as fast

batch_size = 2
num_heads = 8
seq_len = 128
head_dim = 64

# Create query, key, value tensors
q = mx.random.normal((batch_size, num_heads, seq_len, head_dim))
k = mx.random.normal((batch_size, num_heads, seq_len, head_dim))
v = mx.random.normal((batch_size, num_heads, seq_len, head_dim))

# Create an additive causal mask: 0 where attention is allowed, -inf elsewhere
mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
mask = mx.where(mask, 0.0, float('-inf'))

# Compute attention
out = fast.scaled_dot_product_attention(q, k, v, mask=mask)
print(out.shape)  # (2, 8, 128, 64)
The fast attention implementation may have different numerical precision compared to manual implementations. For most use cases, the difference is negligible and the performance benefits are significant.
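For checking results or understanding the formula, the computation can be written as a NumPy reference (a sketch of the math; the fused MLX kernel avoids materializing the full attention matrix on some paths and will differ slightly in precision):

```python
import numpy as np

def sdpa_ref(q, k, v, scale=None, mask=None):
    # softmax(Q @ K^T * scale + mask) @ V, with a numerically stable softmax
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    if mask is not None:
        scores = scores + mask          # additive mask: 0 or -inf
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(2)
q = rng.normal(size=(2, 3, 4, 5))
k = rng.normal(size=(2, 3, 4, 5))
v = rng.normal(size=(2, 3, 4, 7))
tril = np.tril(np.ones((4, 4), dtype=bool))
mask = np.where(tril, 0.0, -np.inf)
out = sdpa_ref(q, k, v, mask=mask)
print(out.shape)  # (2, 3, 4, 7)
```

With a causal mask, the first query position can only attend to the first key, so its output row equals the first value row exactly.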

metal_kernel

mlx.core.fast.metal_kernel(
    name: str,
    input_names: List[str],
    output_names: List[str],
    source: str,
    header: str = "",
    ensure_row_contiguous: bool = True,
    atomic_outputs: bool = False
)
Compile and register a custom Metal kernel. This function lets you write custom GPU kernels in the Metal Shading Language for operations not provided by MLX. The source string is the kernel body; MLX generates the surrounding function, and the input and output arrays are available under the given names.

Parameters:
  • name (str): Name of the kernel function
  • input_names (List[str]): Names of input arrays
  • output_names (List[str]): Names of output arrays
  • source (str): Metal kernel source code
  • header (str): Additional header code. Default: ""
  • ensure_row_contiguous (bool): Ensure inputs are row-contiguous. Default: True
  • atomic_outputs (bool): Use atomic operations for outputs. Default: False
Example:
import mlx.core as mx
import mlx.core.fast as fast

# Define a simple element-wise multiply kernel.
# The source is only the kernel body; MLX generates the function signature.
source = '''
uint elem = thread_position_in_grid.x;
out[elem] = in0[elem] * in1[elem];
'''

# Register the kernel
kernel = fast.metal_kernel(
    name="elem_multiply",
    input_names=["in0", "in1"],
    output_names=["out"],
    source=source
)

# Launch the kernel: the grid is sized in threads, and the output shapes
# and dtypes must be given explicitly. The call returns a list of outputs.
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
outputs = kernel(
    inputs=[a, b],
    grid=(a.size, 1, 1),
    threadgroup=(a.size, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
print(outputs[0])  # [4.0, 10.0, 18.0]
Custom Metal kernels require knowledge of Metal Shading Language and GPU programming. Ensure your kernels are properly tested and handle edge cases.

cuda_kernel

mlx.core.fast.cuda_kernel(
    name: str,
    input_names: List[str],
    output_names: List[str],
    source: str,
    header: str = "",
    ensure_row_contiguous: bool = True,
    atomic_outputs: bool = False
)
Compile and register a custom CUDA kernel. Similar to metal_kernel, but for CUDA devices: you write the kernel body in CUDA C for operations not provided by MLX.

Parameters:
  • name (str): Name of the kernel function
  • input_names (List[str]): Names of input arrays
  • output_names (List[str]): Names of output arrays
  • source (str): CUDA kernel source code
  • header (str): Additional header code. Default: ""
  • ensure_row_contiguous (bool): Ensure inputs are row-contiguous. Default: True
  • atomic_outputs (bool): Use atomic operations for outputs. Default: False
Example:
import mlx.core as mx
import mlx.core.fast as fast

# Define a simple element-wise add kernel.
# As with metal_kernel, the source is the kernel body only. Launching
# exactly one thread per element avoids the need for a bounds check.
source = '''
int idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx] + in1[idx];
'''

# Register the kernel
kernel = fast.cuda_kernel(
    name="elem_add",
    input_names=["in0", "in1"],
    output_names=["out"],
    source=source
)

# Launch it with the same calling convention as metal_kernel
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
outputs = kernel(
    inputs=[a, b],
    grid=(a.size, 1, 1),
    threadgroup=(a.size, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
print(outputs[0])  # [5.0, 7.0, 9.0]
Custom CUDA kernels are only available when MLX is built with CUDA support. The kernel will raise an error if CUDA is not available.

Performance Tips

  1. Use fast operations: Always prefer functions from mlx.core.fast over manual implementations for common patterns
  2. Batch operations: Process multiple sequences together for better GPU utilization
  3. Memory layout: Ensure your tensors are properly shaped for optimal memory access patterns
  4. Kernel fusion: MLX automatically fuses operations, but using fast functions ensures the best fusion opportunities
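Tip 2 can be made concrete: stacking sequences into one batched array produces the same results as a Python loop while letting the device parallelize across the batch. A NumPy sketch of the equivalence (MLX arrays behave the same way for batched matmuls):

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64)).astype(np.float32)
seqs = [rng.normal(size=(16, 64)).astype(np.float32) for _ in range(8)]

# Looped: one matmul per sequence (8 small device launches)
looped = [s @ w for s in seqs]

# Batched: stack into (8, 16, 64) and issue a single matmul
batched = np.stack(seqs) @ w

assert all(np.allclose(batched[i], looped[i], atol=1e-5) for i in range(8))
```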

See Also