Quick Start
Here’s a simple custom kernel that computes `exp` element-wise:
How It Works
Kernel Source
Only pass the body of the Metal kernel in `source`. The function signature is generated automatically based on:

- Input arrays: From the `inputs` parameter
- Output arrays: From the `output_dtypes` parameter
- Template parameters: From the `template` parameter
- Metal attributes: Any Metal attributes used in `source`
Grid and Threadgroups
`grid` and `threadgroup` map to Metal’s `dispatchThreads` function:

- `grid`: Total number of threads to launch (3D)
- `threadgroup`: Size of each threadgroup (3D)
Template Parameters
Template parameters can be:

- `mx.core.Dtype`: Data types (`float32`, `float16`, etc.)
- `int`: Integer constants
- `bool`: Boolean flags
Using Shapes and Strides
Row-Contiguous Arrays
By default, `ensure_row_contiguous=True` copies any non-contiguous input arrays so that they are row-contiguous. This simplifies indexing:
Arbitrary Strides
To avoid copies and support arbitrary strides, set `ensure_row_contiguous=False` and use MLX indexing utilities. The kernel is automatically passed `{name}_shape`, `{name}_strides`, and `{name}_ndim` for each input array if they appear in `source`.
Advanced Example: Grid Sample
Here’s a more complex example implementing bilinear grid sampling.

Reference Implementation

First, a reference implementation using standard MLX ops:
Fused Metal Kernel

Now implement it as a fused Metal kernel. Benchmarked with x.shape = (8, 1024, 1024, 64) and grid.shape = (8, 256, 256, 2) on an M1 Max:
- Reference: 55.7ms
- Fused kernel: 6.7ms
- Speedup: 8x
Custom VJP with Atomics
Implement the backward pass using atomic operations:

- Reference: 676.4ms
- Custom kernel: 16.7ms
- Speedup: 40x
Kernel Features
Initialization
Atomic Outputs
Verbose Mode
Metal Attributes
All Metal attributes from Table 5.8 of the Metal Specification are supported, for example:

- `thread_position_in_grid`: Global thread index
- `thread_position_in_threadgroup`: Local thread index
- `thread_index_in_simdgroup`: Index within SIMD group
- `threads_per_simdgroup`: Size of SIMD group
- `threadgroup_position_in_grid`: Threadgroup index
Best Practices
Performance Tips
- Fuse operations: Combine multiple operations into one kernel
- Use simdgroup operations: `simd_sum()`, `simd_max()`, etc. are very fast
- Minimize atomics: Use simdgroup reductions first, then atomics
- Pad to simdgroup size: Avoid false sharing when using `simd_sum()`
- Profile with Xcode: Use Metal GPU capture for detailed profiling
Memory Access
- Coalesced reads: Access memory in a pattern that matches thread layout
- Bank conflicts: Avoid when using threadgroup memory
- Output is contiguous: Output arrays are always row-contiguous
Debugging
- Use `verbose=True` to see generated code
- Start with simple kernels and add complexity incrementally
- Test against reference implementation
- Use Xcode GPU debugger for GPU-side debugging
Utilities
MLX provides utilities in `mlx/backend/metal/kernels/utils.h`; helpers such as `elem_to_loc` are available to use in kernel source.
Next Steps
C++ Extensions
Build complete C++ extensions with primitives
Operations Reference
Browse the C++ API reference