The mlx.core.metal module provides access to Metal-specific functionality for Apple Silicon GPUs on macOS. This includes device information, GPU capture for debugging, and Metal-specific optimizations.

Overview

The Metal backend is MLX’s primary GPU backend on macOS, providing high-performance computation on:
  • Apple Silicon (M1, M2, M3, M4 chips)
Intel Macs, including those with AMD GPUs, are not supported.

Functions

is_available

mlx.core.metal.is_available() -> bool
Check whether Metal is available on the current system, that is, whether MLX can use the GPU for computation. Returns:
  • True if Metal is available, False otherwise
Example:
import mlx.core as mx
import mlx.core.metal as metal

if metal.is_available():
    print("Metal GPU acceleration is available")
    info = metal.device_info()
    print(f"Running on: {info['name']}")
else:
    print("Metal is not available, using CPU")
Cross-platform code:
import mlx.core as mx
import mlx.core.metal as metal
import mlx.core.cuda as cuda

if metal.is_available():
    backend = "Metal"
    device_info = metal.device_info()
elif cuda.is_available():
    backend = "CUDA"
    device_info = cuda.device_info()
else:
    backend = "CPU"
    device_info = {"name": "CPU"}

print(f"Using {backend} backend: {device_info['name']}")

device_info

mlx.core.metal.device_info() -> dict
Get information about the Metal device. Returns a dictionary containing details about the GPU, including its name, memory, and capabilities. Returns:
  • Dictionary with device information
Dictionary keys:
  • name (str): Device name (e.g., “Apple M3 Max”)
  • memory (int): Total device memory in bytes
  • max_buffer_length (int): Maximum size of a single buffer, in bytes
  • max_recommended_working_set_size (int): Approximate amount of memory, in bytes, the GPU can use without degrading performance
  • registry_id (int): Device registry ID
Example:
import mlx.core.metal as metal

if metal.is_available():
    info = metal.device_info()
    
    print(f"Device: {info['name']}")
    print(f"Total memory: {info['memory'] / (1024**3):.2f} GB")
    print(f"Max buffer size: {info['max_buffer_length'] / (1024**3):.2f} GB")
    print(f"Registry ID: {info['registry_id']}")
Output example:
Device: Apple M3 Max
Total memory: 128.00 GB
Max buffer size: 42.67 GB
Registry ID: 4294968320
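Because device_info() reports sizes in bytes, a workload can be checked against these limits before allocating it. The sketch below uses hypothetical helpers that are not part of MLX: it assumes each array occupies one buffer of (number of elements × bytes per element), which must not exceed max_buffer_length, and it treats 90% of max_recommended_working_set_size as the budget for all buffers together. The 90% headroom is an arbitrary choice.

```python
import math

def array_bytes(shape, itemsize=4):
    """Bytes needed for a dense array of `shape` (itemsize 4 = float32)."""
    return math.prod(shape) * itemsize

def fits_on_device(info, buffer_sizes, headroom=0.9):
    """Hypothetical check: every buffer fits the per-buffer limit, and the
    total stays within `headroom` of the recommended working set size."""
    if any(size > info["max_buffer_length"] for size in buffer_sizes):
        return False
    budget = headroom * info["max_recommended_working_set_size"]
    return sum(buffer_sizes) <= budget

# Usage on a Metal-capable machine:
#   import mlx.core as mx
#   sizes = [array_bytes((8192, 8192))] * 3
#   print(fits_on_device(mx.metal.device_info(), sizes))
```

Checking the per-buffer limit separately matters because a single large array can be rejected even when total memory is ample.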
System monitoring:
import mlx.core as mx
import mlx.core.metal as metal

def print_gpu_info():
    if not metal.is_available():
        print("Metal not available")
        return
    
    info = metal.device_info()
    active_mem = mx.metal.get_active_memory() / (1024**3)
    peak_mem = mx.metal.get_peak_memory() / (1024**3)
    total_mem = info['memory'] / (1024**3)
    
    print(f"Device: {info['name']}")
    print(f"Memory usage: {active_mem:.2f} GB / {total_mem:.2f} GB")
    print(f"Peak usage: {peak_mem:.2f} GB")
    print(f"Memory utilization: {active_mem/total_mem*100:.1f}%")

print_gpu_info()

start_capture

mlx.core.metal.start_capture(path: str) -> None
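Start capturing Metal GPU commands to a trace file. This is useful for debugging and profiling GPU operations in Xcode’s Metal Debugger. The capture includes all Metal commands executed between start_capture() and stop_capture(). GPU capture must be enabled when the process starts, for example MTL_CAPTURE_ENABLED=1 python my_script.py; otherwise the capture cannot be started. Parameters: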
  • path (str): Path where the GPU trace file will be saved (.gputrace extension)
Example:
import mlx.core as mx
import mlx.core.metal as metal

# Start capturing GPU commands
metal.start_capture("debug_trace.gputrace")

# Run operations to debug
x = mx.random.normal((1000, 1000))
y = mx.random.normal((1000, 1000))
z = x @ y  # Matrix multiplication
mx.eval(z)

# Stop capture
metal.stop_capture()

print("GPU trace saved to debug_trace.gputrace")
print("Open with: open debug_trace.gputrace")
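When capturing repeatedly, for example once per profiled operation, generating a fresh path for each run keeps traces from different runs apart. A minimal sketch; unique_trace_path is a hypothetical helper, and the timestamp-plus-counter naming scheme is an arbitrary choice. Since a .gputrace is saved as a bundle, the existence check covers directories as well as files.

```python
from datetime import datetime
from pathlib import Path

def unique_trace_path(name, directory="."):
    """Return a path like ./matmul-20240101-120000.gputrace that does not exist yet."""
    Path(directory).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    path = Path(directory) / f"{name}-{stamp}.gputrace"
    counter = 1
    # Append a counter if a trace with this name was already written this second
    while path.exists():
        path = Path(directory) / f"{name}-{stamp}-{counter}.gputrace"
        counter += 1
    return str(path)

# Usage:
#   metal.start_capture(unique_trace_path("matmul", "traces"))
```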
Debugging custom kernels:
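import mlx.core as mx
import mlx.core.metal as metal
import mlx.core.fast as fast

# Kernel body only; MLX generates the function signature from the
# input and output names. The inputs are float32, so use a float literal.
source = '''
uint elem = thread_position_in_grid.x;
out[elem] = in0[elem] * 2.0f + in1[elem];
'''

kernel = fast.metal_kernel(
    name="custom_op",
    input_names=["in0", "in1"],
    output_names=["out"],
    source=source
)

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])

# Capture execution for debugging
metal.start_capture("custom_kernel.gputrace")

# The grid launches exactly a.size threads, so no bounds check is needed
outputs = kernel(
    inputs=[a, b],
    grid=(a.size, 1, 1),
    threadgroup=(a.size, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
result = outputs[0]  # [6.0, 9.0, 12.0]
mx.eval(result)

metal.stop_capture()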
Profiling performance:
import mlx.core as mx
import mlx.core.metal as metal
import time

def profile_operation(name, operation, *args):
    # Warm up so one-time setup costs are excluded from the timing
    for _ in range(3):
        mx.eval(operation(*args))

    # Time without capture: capture overhead would inflate the measurement
    start = time.perf_counter()
    result = operation(*args)
    mx.eval(result)
    elapsed = time.perf_counter() - start
    print(f"{name}: {elapsed*1000:.2f} ms")

    # Capture a separate run for inspection in Xcode
    trace_path = f"{name}.gputrace"
    metal.start_capture(trace_path)
    mx.eval(operation(*args))
    metal.stop_capture()

    print(f"Trace saved to {trace_path}")
    return result

# Profile different operations
A = mx.random.normal((2048, 2048))
B = mx.random.normal((2048, 2048))
mx.eval(A, B)  # Materialize the inputs so their generation is not timed

profile_operation("matmul", lambda a, b: a @ b, A, B)
profile_operation("elementwise", lambda a, b: a * b + a, A, B)
GPU capture can significantly slow down execution. Only use it for debugging and profiling, not in production code.
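Since capture inflates run times, timings are more trustworthy when taken without capture and summarized over several runs. A minimal sketch, assuming the median of repeated runs is an adequate summary; time_ms is a hypothetical helper. Its `sync` hook must force the work to finish: with MLX, pass mx.eval, because MLX builds computations lazily and only runs them when they are evaluated.

```python
import statistics
import time

def time_ms(fn, *args, warmup=3, repeats=10, sync=lambda result: None):
    """Median wall-clock time of fn(*args) in milliseconds.

    `sync` is called on each result so lazy work is forced before the
    clock stops; with MLX use sync=mx.eval.
    """
    for _ in range(warmup):
        sync(fn(*args))
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn(*args))
        samples.append((time.perf_counter() - start) * 1000)
    return statistics.median(samples)

# Usage with MLX:
#   import mlx.core as mx
#   A = mx.random.normal((2048, 2048)); B = mx.random.normal((2048, 2048))
#   mx.eval(A, B)
#   print(f"matmul: {time_ms(lambda a, b: a @ b, A, B, sync=mx.eval):.2f} ms")
```

The median is less sensitive than the mean to an occasional slow run caused by other activity on the machine.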

stop_capture

mlx.core.metal.stop_capture() -> None
Stop capturing Metal GPU commands. Must be called after start_capture() to finalize the trace file. Example:
import mlx.core as mx
import mlx.core.metal as metal

x = mx.random.normal((512, 512))
mx.eval(x)  # Create the input before capturing

metal.start_capture("trace.gputrace")
try:
    # Your GPU operations here
    result = mx.exp(x) @ x
    mx.eval(result)
finally:
    # Always stop capture, even if there's an error
    metal.stop_capture()
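The start/stop pairing can also be wrapped in a context manager so the capture is always closed. This is a sketch, not an MLX API: gpu_capture is a hypothetical name, and its up-front checks (a .gputrace path and MTL_CAPTURE_ENABLED=1 in the environment) are only there to surface common setup mistakes before Metal is involved.

```python
import contextlib
import os

@contextlib.contextmanager
def gpu_capture(path: str):
    """Capture the Metal commands issued inside the `with` block to `path`."""
    if not path.endswith(".gputrace"):
        raise ValueError("capture path should end with .gputrace")
    if os.environ.get("MTL_CAPTURE_ENABLED") != "1":
        raise RuntimeError("launch Python with MTL_CAPTURE_ENABLED=1 to allow GPU capture")
    import mlx.core as mx  # imported lazily so the checks above work without MLX
    mx.metal.start_capture(path)
    try:
        yield path
    finally:
        mx.metal.stop_capture()

# Usage:
#   with gpu_capture("trace.gputrace"):
#       mx.eval(x @ x)
```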

Usage Patterns

Device Selection

import mlx.core as mx
import mlx.core.metal as metal

def setup_device():
    """Make the best available device the default for MLX operations."""
    if metal.is_available():
        info = metal.device_info()
        print(f"Using Metal on {info['name']}")
        device = mx.gpu
    else:
        print("Using CPU")
        device = mx.cpu
    mx.set_default_device(device)
    return device

device = setup_device()

Memory Monitoring

import mlx.core as mx
import mlx.core.metal as metal

class GPUMemoryMonitor:
    def __init__(self):
        if metal.is_available():
            self.total_memory = metal.device_info()['memory']
        else:
            self.total_memory = None
    
    def check_memory(self):
        if self.total_memory is None:
            return "Metal not available"
        
        active = mx.metal.get_active_memory()
        peak = mx.metal.get_peak_memory()
        
        return {
            "active_gb": active / (1024**3),
            "peak_gb": peak / (1024**3),
            "total_gb": self.total_memory / (1024**3),
            "utilization": active / self.total_memory * 100
        }
    
    def print_stats(self):
        stats = self.check_memory()
        if isinstance(stats, dict):
            print(f"GPU Memory: {stats['active_gb']:.2f} GB / {stats['total_gb']:.2f} GB")
            print(f"Peak: {stats['peak_gb']:.2f} GB")
            print(f"Memory utilization: {stats['utilization']:.1f}%")
        else:
            print(stats)

monitor = GPUMemoryMonitor()
monitor.print_stats()

Conditional GPU Code

import mlx.core as mx
import mlx.core.metal as metal

def efficient_matmul(A, B):
    """Choose where to run the matmul based on the available device."""
    if metal.is_available():
        # Run on the GPU
        return mx.matmul(A, B, stream=mx.gpu)
    else:
        # Fall back to MLX's CPU implementation
        return mx.matmul(A, B, stream=mx.cpu)

Debugging with Metal Debugger

Opening GPU Traces in Xcode

  1. Capture a trace:
    metal.start_capture("debug.gputrace")
    # ... your code ...
    metal.stop_capture()
    
  2. Open the trace:
    open debug.gputrace
    
  3. In Xcode’s Metal Debugger:
    • View all Metal commands
    • Inspect shader performance
    • Analyze memory usage
    • Debug shader code

What to Look For

  • Command buffer bottlenecks: Long-running operations
  • Memory transfers: Excessive CPU-GPU data movement
  • Kernel efficiency: GPU occupancy and utilization
  • Shader warnings: Potential optimization opportunities

Environment Variables

MLX_METAL_FAST_SYNCH

Enables faster GPU-CPU synchronization.
export MLX_METAL_FAST_SYNCH=1
python my_script.py
Benefits:
  • Lower latency for CPU-GPU collaboration
  • Critical for distributed communication with JACCL
  • Improves performance when frequently reading GPU results
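When a script is started from another Python program rather than from a shell, the variable can be passed through the child's environment. A minimal sketch using only the standard library; run_with_fast_synch is a hypothetical helper. Because the variable is present before the child process starts, this avoids any question of when MLX reads it.

```python
import os
import subprocess
import sys

def run_with_fast_synch(script_args, **run_kwargs):
    """Run `python <script_args...>` with MLX_METAL_FAST_SYNCH=1 in its environment."""
    env = dict(os.environ, MLX_METAL_FAST_SYNCH="1")
    # check=True raises CalledProcessError if the script exits with an error
    return subprocess.run([sys.executable, *script_args], env=env, check=True, **run_kwargs)

# Usage:
#   run_with_fast_synch(["my_script.py"])
```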

Performance Tips

  1. Check availability once: Cache the result of is_available() instead of calling it repeatedly
  2. Monitor memory: Compare your workload’s memory needs against max_recommended_working_set_size from device_info()
  3. Batch operations: Larger batches utilize GPU more efficiently
  4. Minimize captures: GPU capture has significant overhead
  5. Use fast sync: Set MLX_METAL_FAST_SYNCH=1 for latency-sensitive workloads
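Tip 1 can be implemented by memoizing the check. A sketch assuming the function lives in your own code (metal_available is a hypothetical name): functools.lru_cache stores the first result, and the ImportError branch makes the function return False on machines where MLX is not installed.

```python
import functools

@functools.lru_cache(maxsize=None)
def metal_available() -> bool:
    """Check for Metal once per process and reuse the answer afterwards."""
    try:
        import mlx.core as mx
    except ImportError:
        # MLX is not installed, so Metal cannot be used through it
        return False
    return bool(mx.metal.is_available())

# Usage:
#   device = mx.gpu if metal_available() else mx.cpu
```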

See Also