MLX supports writing custom Metal kernels through both the Python and C++ APIs. This allows you to implement highly optimized GPU operations for Apple Silicon.
Quick Start
Here’s a simple custom kernel that computes exp element-wise:
```python
import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

def exp_elementwise(a: mx.array):
    outputs = kernel(
        inputs=[a],
        template=[("T", a.dtype)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    return outputs[0]

# Use it
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
```
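The source body runs once per thread, and each thread finds its element through its own thread_position_in_grid.x. The plain-Python simulation below makes that model concrete. It does not use a GPU and is not part of MLX. The names simulate_launch and exp_body are hypothetical, and the loop runs sequentially where the GPU would run in parallel.

```python
import math

def simulate_launch(body, grid_x, inp, out):
    # Call the kernel body once per thread; the GPU would run these in parallel.
    for thread_x in range(grid_x):
        body(thread_x, inp, out)

def exp_body(thread_x, inp, out):
    # Mirrors the Metal source: one element per thread, indexed by its grid position.
    elem = thread_x
    out[elem] = math.exp(inp[elem])

inp = [0.0, 1.0, -2.0, 0.5]
out = [0.0] * len(inp)
simulate_launch(exp_body, len(inp), inp, out)  # grid = number of elements, as in exp_elementwise
print(out)
```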
How It Works
Kernel Source
Pass only the body of the Metal kernel in source. MLX generates the function signature automatically from:
Input arrays: names from input_names, types from the arrays passed in inputs
Output arrays: names from output_names, types from output_dtypes
Template parameters: from the template argument
Metal attributes: any Metal attributes (such as thread_position_in_grid) used in source
For the example above, the generated kernel is:
```cpp
template <typename T>
[[kernel]] void custom_kernel_myexp(
    const device float16_t* inp [[buffer(0)]],
    device float16_t* out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]]) {
  uint elem = thread_position_in_grid.x;
  T tmp = inp[elem];
  out[elem] = metal::exp(tmp);
}
```
Grid and Threadgroups
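grid and threadgroup are passed to Metal's dispatchThreads function. MLX therefore launches prod(grid) threads in total, divided into threadgroups of size threadgroup:
grid: total number of threads to launch in each dimension (3D). This counts threads, not threadgroups.
threadgroup: number of threads in each threadgroup, per dimension (3D)
For best performance, each threadgroup dimension should be less than or equal to the corresponding grid dimension.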
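dispatchThreads takes a thread count rather than a threadgroup count. Metal therefore covers the grid with ceil(grid / threadgroup) threadgroups per dimension, and the last threadgroup along a dimension may be partial. Apple silicon supports these non-uniform threadgroup sizes. The plain-Python helper below works through that arithmetic. Its name, threadgroup_layout, is hypothetical and it is not an MLX function.

```python
import math

def threadgroup_layout(grid, threadgroup):
    """Return (threadgroups per dimension, total threads) for a dispatchThreads-style launch."""
    groups = tuple(math.ceil(g / t) for g, t in zip(grid, threadgroup))
    # Exactly prod(grid) threads run; a partial last threadgroup is trimmed, not padded.
    total_threads = math.prod(grid)
    return groups, total_threads

# Quick Start launch: 4 * 16 = 64 elements with a threadgroup of 256
print(threadgroup_layout((64, 1, 1), (256, 1, 1)))    # ((1, 1, 1), 64)
# 1000 threads in groups of 256: 4 groups, the last one holding 232 threads
print(threadgroup_layout((1000, 1, 1), (256, 1, 1)))  # ((4, 1, 1), 1000)
```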
Template Parameters
Template parameters can be:
mlx.core.Dtype - data types (mx.float32, mx.float16, etc.)
int - integer constants
bool - Boolean flags
```python
template = [
    ("T", mx.float32),   # Type parameter
    ("N", 256),          # Integer parameter
    ("USE_BIAS", True),  # Boolean parameter
]
```
Using Shapes and Strides
Row-Contiguous Arrays
By default (ensure_row_contiguous=True), MLX makes sure every input is row-contiguous and copies it first if necessary. This keeps indexing simple:
```python
source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = metal::exp(inp[elem]);  // Simple linear indexing
"""
```
Arbitrary Strides
To avoid copies and support arbitrary strides, set ensure_row_contiguous=False and use MLX indexing utilities:
```python
source = """
    uint elem = thread_position_in_grid.x;
    // elem_to_loc from mlx/backend/metal/kernels/utils.h
    uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
    T tmp = inp[loc];
    out[elem] = metal::exp(tmp);  // Output is always row-contiguous
"""

kernel = mx.fast.metal_kernel(
    name="myexp_strided",
    input_names=["inp"],
    output_names=["out"],
    source=source,
    ensure_row_contiguous=False,
)
```
For each input array, MLX adds {name}_shape, {name}_strides, and {name}_ndim to the generated signature when those names appear in source. For the input inp above, these are inp_shape, inp_strides, and inp_ndim.
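elem_to_loc converts a row-major linear index into a memory offset using the array's shape and strides. The plain-Python equivalent below is a sketch for illustration, not MLX's Metal code. It walks the dimensions from last to first, peeling off one coordinate at a time.

```python
def elem_to_loc(elem, shape, strides):
    """Map a row-major linear index to a strided memory offset."""
    loc = 0
    for dim in reversed(range(len(shape))):
        loc += (elem % shape[dim]) * strides[dim]  # coordinate along this dim times its stride
        elem //= shape[dim]
    return loc

# A 2x3 row-contiguous array has strides (3, 1), so offsets equal the linear index.
print([elem_to_loc(i, (2, 3), (3, 1)) for i in range(6)])  # [0, 1, 2, 3, 4, 5]
# Its transpose is a 3x2 view with strides (1, 3), read in place without a copy.
print([elem_to_loc(i, (3, 2), (1, 3)) for i in range(6)])  # [0, 3, 1, 4, 2, 5]
```

The second case is the kind of input ensure_row_contiguous=False lets the kernel read directly, instead of MLX first materializing a row-contiguous copy.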
Advanced Example: Grid Sample
Here’s a more complex example implementing bilinear grid sampling.
Reference Implementation
First, a reference implementation using standard MLX ops:
```python
def grid_sample_ref(x, grid):
    N, H_in, W_in, _ = x.shape
    ix = ((grid[..., 0] + 1) * W_in - 1) / 2
    iy = ((grid[..., 1] + 1) * H_in - 1) / 2

    ix_nw = mx.floor(ix).astype(mx.int32)
    iy_nw = mx.floor(iy).astype(mx.int32)

    ix_ne = ix_nw + 1
    iy_ne = iy_nw

    ix_sw = ix_nw
    iy_sw = iy_nw + 1

    ix_se = ix_nw + 1
    iy_se = iy_nw + 1

    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix) * (iy - iy_ne)
    se = (ix - ix_nw) * (iy - iy_nw)

    # Gather values from corners
    I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
    I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
    I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
    I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

    # Apply boundary masks
    mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
    # ... similar for ne, sw, se
    I_nw *= mask_nw[..., None]
    # ... similar for others

    output = (nw[..., None] * I_nw + ne[..., None] * I_ne
              + sw[..., None] * I_sw + se[..., None] * I_se)
    return output
```
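The reference uses the align_corners=False convention, the same one PyTorch's grid_sample uses. A normalized coordinate of -1 maps to -0.5 and +1 maps to W - 0.5, so the extremes land on the outer edges of the border pixels, and out-of-range corners contribute zero. The single-point, single-channel sketch below reproduces the same corner weights and masking in plain Python, which makes it easy to spot-check one output element by hand. The helper sample_point is hypothetical.

```python
import math

def sample_point(img, gx, gy):
    """Bilinearly sample a 2D list `img` (H x W) at normalized coords (gx, gy), zero padding."""
    H, W = len(img), len(img[0])
    ix = ((gx + 1) * W - 1) / 2  # same unnormalization as grid_sample_ref
    iy = ((gy + 1) * H - 1) / 2
    x0, y0 = math.floor(ix), math.floor(iy)  # north-west corner

    def pixel(y, x):
        # Out-of-bounds corners read as 0, like the boundary masks in the reference
        return img[y][x] if 0 <= y < H and 0 <= x < W else 0.0

    wx, wy = ix - x0, iy - y0  # fractional offsets from the north-west corner
    return ((1 - wx) * (1 - wy) * pixel(y0, x0)      # nw weight
            + wx * (1 - wy) * pixel(y0, x0 + 1)      # ne weight
            + (1 - wx) * wy * pixel(y0 + 1, x0)      # sw weight
            + wx * wy * pixel(y0 + 1, x0 + 1))       # se weight

img = [[1.0, 2.0],
       [3.0, 4.0]]
print(sample_point(img, 0.0, 0.0))    # centre of the image: 2.5
print(sample_point(img, -0.5, -0.5))  # centre of the top-left pixel: 1.0
```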
Now implement as a fused Metal kernel:
```python
import math

source = """
    uint elem = thread_position_in_grid.x;
    int H = x_shape[1];
    int W = x_shape[2];
    int C = x_shape[3];
    int gH = grid_shape[1];
    int gW = grid_shape[2];

    int w_stride = C;
    int h_stride = W * w_stride;
    int b_stride = H * h_stride;

    uint grid_idx = elem / C * 2;
    float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
    float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

    int ix_nw = floor(ix);
    int iy_nw = floor(iy);

    int ix_ne = ix_nw + 1;
    int iy_ne = iy_nw;

    int ix_sw = ix_nw;
    int iy_sw = iy_nw + 1;

    int ix_se = ix_nw + 1;
    int iy_se = iy_nw + 1;

    T nw = (ix_se - ix) * (iy_se - iy);
    T ne = (ix - ix_sw) * (iy_sw - iy);
    T sw = (ix_ne - ix) * (iy - iy_ne);
    T se = (ix - ix_nw) * (iy - iy_nw);

    int batch_idx = elem / C / gH / gW * b_stride;
    int channel_idx = elem % C;
    int base_idx = batch_idx + channel_idx;

    // Check the bounds before reading, so out-of-range corners never touch memory
    T I_nw = (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1)
        ? x[base_idx + iy_nw * h_stride + ix_nw * w_stride] : T(0);
    T I_ne = (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1)
        ? x[base_idx + iy_ne * h_stride + ix_ne * w_stride] : T(0);
    T I_sw = (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1)
        ? x[base_idx + iy_sw * h_stride + ix_sw * w_stride] : T(0);
    T I_se = (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1)
        ? x[base_idx + iy_se * h_stride + ix_se * w_stride] : T(0);

    out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""

kernel = mx.fast.metal_kernel(
    name="grid_sample",
    input_names=["x", "grid"],
    output_names=["out"],
    source=source,
)

@mx.custom_function
def grid_sample(x, grid):
    B, _, _, C = x.shape
    _, gN, gM, D = grid.shape
    out_shape = (B, gN, gM, C)
    outputs = kernel(
        inputs=[x, grid],
        template=[("T", x.dtype)],
        output_shapes=[out_shape],
        output_dtypes=[x.dtype],
        grid=(math.prod(out_shape), 1, 1),  # one thread per output element
        threadgroup=(256, 1, 1),
    )
    return outputs[0]
```
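The kernel launches one thread per output element and recovers every index from elem alone:
elem / C is the flat index of the output pixel. That is also the index of its (x, y) pair in grid, which is why the kernel multiplies it by 2.
elem % C is the channel.
elem / C / gH / gW is the batch.

The plain-Python check below compares that decomposition against explicit nested loops for a small shape. The helper decompose is hypothetical. Python's // matches Metal's integer / for the nonnegative values involved here.

```python
import itertools

def decompose(elem, C, gH, gW):
    """Return (batch, grid offset, channel) the way the Metal source computes them."""
    grid_idx = elem // C * 2       # offset of the (x, y) pair in a (B, gH, gW, 2) grid
    batch = elem // C // gH // gW  # batch index (the kernel then multiplies by b_stride)
    channel = elem % C
    return batch, grid_idx, channel

B, gH, gW, C = 2, 3, 4, 5
elem = 0
for b, i, j, c in itertools.product(range(B), range(gH), range(gW), range(C)):
    # Row-major order over (B, gH, gW, C) visits elem = 0, 1, 2, ...
    expected = (b, ((b * gH + i) * gW + j) * 2, c)
    assert decompose(elem, C, gH, gW) == expected
    elem += 1
print("decomposition matches for", elem, "elements")
```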
Performance: for x.shape = (8, 1024, 1024, 64) and grid.shape = (8, 256, 256, 2) on an M1 Max:
Reference: 55.7 ms
Fused kernel: 6.7 ms
Speedup: about 8x
Custom VJP with Atomics
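In the backward pass, each output element sends gradient back to four input pixels and to one grid coordinate pair. Many threads can target the same location, so the kernel accumulates into x_grad and grid_grad with atomic operations: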
```python
source = """
    uint elem = thread_position_in_grid.x;
    int H = x_shape[1];
    int W = x_shape[2];
    int C = x_shape[3];
    // Pad C to a multiple of the simdgroup size so each simdgroup covers one grid location
    int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

    // ... compute gradients ...

    if (channel_idx < C) {
        // Atomically update x_grad
        if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
            int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
            atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
        }
        // ... similar for other corners ...
    }

    // Reduce within simdgroup first (faster than pure atomics)
    gix = simd_sum(gix);
    giy = simd_sum(giy);

    if (thread_index_in_simdgroup == 0) {
        atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
        atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
    }
"""

kernel = mx.fast.metal_kernel(
    name="grid_sample_grad",
    input_names=["x", "grid", "cotangent"],
    output_names=["x_grad", "grid_grad"],
    source=source,
    atomic_outputs=True,  # Enable atomic operations on outputs
)
```
```python
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
    x, grid = primals
    B, _, _, C = x.shape
    _, gN, gM, D = grid.shape
    # Pad C so each simdgroup of 32 threads covers exactly one grid location;
    # simd_sum then never mixes gradients from different grid points
    simdgroup_size = 32
    C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
    grid_size = B * gN * gM * C_padded
    outputs = kernel(
        inputs=[x, grid, cotangent],
        template=[("T", x.dtype)],
        output_shapes=[x.shape, grid.shape],
        # Atomic float adds assume float32 outputs (Metal has no half-precision atomics)
        output_dtypes=[x.dtype, x.dtype],
        grid=(grid_size, 1, 1),
        threadgroup=(256, 1, 1),
        init_value=0,  # Initialize outputs to 0 before the kernel runs
    )
    return outputs[0], outputs[1]
```
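Padding C to C_padded means thread elem handles channel elem % C_padded of grid location elem // C_padded, and threads whose channel is C or higher are idle padding. With a 1D launch whose threadgroup size is a multiple of 32, as here, each simdgroup is 32 consecutive threads. Because C_padded is also a multiple of 32, every simdgroup then lies inside a single grid location. The plain-Python check below verifies this for C = 40 and shows that without padding some simdgroups would straddle two locations. The helper simdgroup_locations is hypothetical.

```python
def simdgroup_locations(C, num_locations, simd_width=32, pad=True):
    """For each simdgroup of consecutive threads, return the set of grid locations it touches."""
    C_padded = (C + simd_width - 1) // simd_width * simd_width if pad else C
    total_threads = num_locations * C_padded
    groups = []
    for start in range(0, total_threads, simd_width):
        threads = range(start, min(start + simd_width, total_threads))
        groups.append({t // C_padded for t in threads})  # grid location of each thread
    return groups

padded = simdgroup_locations(C=40, num_locations=3)
print(all(len(g) == 1 for g in padded))    # True: each simdgroup sees one location
unpadded = simdgroup_locations(C=40, num_locations=3, pad=False)
print(all(len(g) == 1 for g in unpadded))  # False: some simdgroups straddle two
```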
VJP performance, for the same input sizes:
Reference: 676.4 ms
Custom kernel: 16.7 ms
Speedup: about 40x
Kernel Features
Initialization
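```python
init_value=0  # Kernel call argument: fill every output with this value before the kernel runs
```
Use this when the kernel writes only part of the output (for example, a scatter) or accumulates into it, as the VJP kernel above does with atomics.
Atomic Outputs
```python
atomic_outputs=True  # mx.fast.metal_kernel argument: declare outputs as atomic in the generated signature
```
This enables Metal atomic operations such as atomic_fetch_add_explicit for thread-safe updates. See section 6.15 of the Metal Shading Language Specification.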
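With init_value=0 and atomic adds, each output slot ends up holding the sum of all contributions, whatever order the threads arrive in. Each add only needs to be atomic, not ordered relative to other memory operations, which is why memory_order_relaxed is enough in the VJP kernel. Floating-point addition is not associative, though, so the result can still differ in the last bits from run to run. The plain-Python sketch below simulates threads adding into a zero-initialized buffer in two different orders. The helper scatter_add is hypothetical, and the sequential loop stands in for the GPU's nondeterministic arrival order.

```python
def scatter_add(contributions, size, init_value=0.0):
    """Simulate threads atomically adding (index, value) pairs into a buffer."""
    out = [init_value] * size  # init_value fills the buffer before the kernel runs
    for index, value in contributions:  # loop order stands in for thread arrival order
        out[index] += value
    return out

# Small, exactly representable values: any arrival order gives the same result.
vals = [(0, 1.0), (1, 2.0), (0, 3.0), (1, 4.0)]
print(scatter_add(vals, 2), scatter_add(list(reversed(vals)), 2))  # [4.0, 6.0] [4.0, 6.0]

# Rounding makes floating-point accumulation order-dependent.
order_a = [(0, 1e16), (0, 1.0), (0, -1e16)]
order_b = [(0, 1e16), (0, -1e16), (0, 1.0)]
print(scatter_add(order_a, 1), scatter_add(order_b, 1))  # [0.0] [1.0]
```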
Verbose Mode
```python
outputs = kernel(
    ...,
    verbose=True,  # Print the generated Metal code for debugging
)
```
All attributes listed in Table 5.8 of the Metal Shading Language Specification are supported, including:
thread_position_in_grid - position of the thread in the whole grid (global index)
thread_position_in_threadgroup - position of the thread within its threadgroup (local index)
thread_index_in_simdgroup - lane index within the SIMD group
threads_per_simdgroup - number of threads in a SIMD group
threadgroup_position_in_grid - position of the threadgroup in the grid
Example:
```python
source = """
    uint gid = thread_position_in_grid.x;
    uint lid = thread_position_in_threadgroup.x;
    uint simd_lane = thread_index_in_simdgroup;

    // Assumes an input named "inp"; sum one value per lane across the SIMD group
    float local_value = inp[gid];
    float sum = simd_sum(local_value);
"""
```
Best Practices
Fuse operations: combine multiple operations into one kernel
Use simdgroup operations: simd_sum(), simd_max(), and similar functions are very fast
Minimize atomics: reduce within a simdgroup first, then issue one atomic per simdgroup
Pad to simdgroup size: make each simdgroup cover a single reduction target, so simd_sum() never mixes values from different outputs
Profile with Xcode: use Metal GPU capture for detailed profiling
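Counting operations shows why reducing within a simdgroup first pays off. When every thread issues its own atomic to a shared address, all of those atomics contend. With simd_sum() followed by a single atomic from lane 0, the count drops by the simdgroup width. In the plain-Python sketch below, simd_sum is simulated with Python's sum over each group of 32 consecutive lanes, and the function names are hypothetical. Both strategies produce the same total.

```python
def per_thread_atomics(values):
    """Every thread issues its own atomic add to one shared accumulator."""
    total, atomics = 0, 0
    for v in values:
        total += v
        atomics += 1
    return total, atomics

def reduce_then_atomic(values, simd_width=32):
    """Simulate simd_sum() per group of 32 lanes, then one atomic add from lane 0."""
    total, atomics = 0, 0
    for start in range(0, len(values), simd_width):
        group_sum = sum(values[start:start + simd_width])  # stands in for simd_sum()
        total += group_sum  # only lane 0 performs this atomic add
        atomics += 1
    return total, atomics

values = list(range(256))  # one value per thread in a 256-thread threadgroup
print(per_thread_atomics(values))  # (32640, 256)
print(reduce_then_atomic(values))  # (32640, 8)
```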
Memory Access
Coalesced reads: have adjacent threads access adjacent memory addresses
Bank conflicts: avoid them when using threadgroup memory
Output is contiguous: output arrays are always row-contiguous, so index them linearly
Debugging
Use verbose=True to see generated code
Start with simple kernels and add complexity incrementally
Test against reference implementation
Use Xcode GPU debugger for GPU-side debugging
Utilities
MLX provides helpers in mlx/backend/metal/kernels/utils.h, including the following (simplified declarations; see the header for the exact signatures):
```cpp
// Convert a row-major linear index to a strided memory location
uint elem_to_loc(uint elem, const int* shape, const int64_t* strides, int ndim);

// Ceiling division
int ceildiv(int a, int b);
```
These helpers are included automatically in your kernel source.
Next Steps
C++ Extensions: build complete C++ extensions with primitives
Operations Reference: browse the C++ API reference
Resources