The mlx.core.metal module provides access to Metal-specific functionality on Apple Silicon Macs. This includes device information, GPU capture for debugging, and Metal-specific optimizations.
Overview
The Metal backend is MLX’s primary GPU backend on macOS, providing high-performance computation on:
- Apple Silicon (M1, M2, M3, and M4 family chips)
Note that MLX requires a Mac with Apple Silicon; Intel Macs, including those with AMD GPUs, are not supported.
Functions
is_available
mlx.core.metal.is_available() -> bool
Check if Metal is available on the current system.
Returns:
True if Metal can be used for GPU computation, False otherwise
Example:
import mlx.core as mx
import mlx.core.metal as metal

if metal.is_available():
    print("Metal GPU acceleration is available")
    info = metal.device_info()
    print(f"Running on: {info['name']}")
else:
    print("Metal is not available, using CPU")
Cross-platform code:
import mlx.core as mx
import mlx.core.metal as metal
import mlx.core.cuda as cuda

if metal.is_available():
    backend = "Metal"
    device_info = metal.device_info()
elif cuda.is_available():
    backend = "CUDA"
    device_info = cuda.device_info()
else:
    backend = "CPU"
    device_info = {"name": "CPU"}

print(f"Using {backend} backend: {device_info['name']}")
device_info
mlx.core.metal.device_info() -> dict
Get information about the Metal device.
Returns a dictionary containing details about the GPU, including its name, memory, and capabilities.
Returns:
Dictionary with device information
Dictionary keys:
name (str): Device name (e.g., “Apple M3 Max”)
memory (int): Total device memory in bytes
max_buffer_length (int): Maximum buffer size
max_recommended_working_set_size (int): Recommended working set size
registry_id (int): Device registry ID
Example:
import mlx.core.metal as metal

if metal.is_available():
    info = metal.device_info()
    print(f"Device: {info['name']}")
    print(f"Total memory: {info['memory'] / (1024**3):.2f} GB")
    print(f"Max buffer size: {info['max_buffer_length'] / (1024**3):.2f} GB")
    print(f"Registry ID: {info['registry_id']}")
Output example:
Device: Apple M3 Max
Total memory: 128.00 GB
Max buffer size: 42.67 GB
Registry ID: 4294968320
System monitoring:
import mlx.core as mx
import mlx.core.metal as metal

def print_gpu_info():
    if not metal.is_available():
        print("Metal not available")
        return
    info = metal.device_info()
    active_mem = mx.metal.get_active_memory() / (1024**3)
    peak_mem = mx.metal.get_peak_memory() / (1024**3)
    total_mem = info['memory'] / (1024**3)
    print(f"Device: {info['name']}")
    print(f"Memory usage: {active_mem:.2f} GB / {total_mem:.2f} GB")
    print(f"Peak usage: {peak_mem:.2f} GB")
    print(f"Utilization: {active_mem/total_mem*100:.1f}%")

print_gpu_info()
start_capture
mlx.core.metal.start_capture(path: str) -> None
Start capturing Metal GPU commands to a trace file.
This is useful for debugging and profiling GPU operations using Xcode’s Metal Debugger. The capture includes all Metal commands executed between start_capture() and stop_capture(). When running outside Xcode, set the environment variable MTL_CAPTURE_ENABLED=1 before launching Python, or Metal will refuse to start the capture.
Parameters:
path (str): Path where the GPU trace file will be saved (.gputrace extension)
Example:
import mlx.core as mx
import mlx.core.metal as metal
# Start capturing GPU commands
metal.start_capture("debug_trace.gputrace")
# Run operations to debug
x = mx.random.normal((1000, 1000))
y = mx.random.normal((1000, 1000))
z = x @ y # Matrix multiplication
mx.eval(z)
# Stop capture
metal.stop_capture()
print("GPU trace saved to debug_trace.gputrace")
print("Open with: open debug_trace.gputrace")
Debugging custom kernels:
import mlx.core as mx
import mlx.core.metal as metal
import mlx.core.fast as fast
# Define a custom kernel
source = '''
    uint elem = thread_position_in_grid.x;
    out[elem] = in0[elem] * 2.0 + in1[elem];
'''

kernel = fast.metal_kernel(
    name="custom_op",
    input_names=["in0", "in1"],
    output_names=["out"],
    source=source,
)

# Capture execution for debugging
metal.start_capture("custom_kernel.gputrace")
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
# metal_kernel objects are called with explicit launch geometry and
# output specs, and return a list of output arrays.
result = kernel(
    inputs=[a, b],
    grid=(a.size, 1, 1),
    threadgroup=(a.size, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)[0]
mx.eval(result)
metal.stop_capture()
Profiling performance:
import mlx.core as mx
import mlx.core.metal as metal
import time

def profile_operation(name, operation, *args):
    # Warm up
    for _ in range(3):
        result = operation(*args)
        mx.eval(result)
    # Capture and time (perf_counter is monotonic, unlike time.time)
    trace_path = f"{name}.gputrace"
    metal.start_capture(trace_path)
    start = time.perf_counter()
    result = operation(*args)
    mx.eval(result)
    elapsed = time.perf_counter() - start
    metal.stop_capture()
    print(f"{name}: {elapsed*1000:.2f} ms")
    print(f"Trace saved to {trace_path}")
    return result

# Profile different operations
A = mx.random.normal((2048, 2048))
B = mx.random.normal((2048, 2048))
profile_operation("matmul", lambda a, b: a @ b, A, B)
profile_operation("elementwise", lambda a, b: a * b + a, A, B)
GPU capture can significantly slow down execution. Only use it for debugging and profiling, not in production code.
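Because capture is expensive, it can help to gate it behind an opt-in flag so instrumented code can ship with capture disabled. A minimal sketch; the MLX_DEBUG_CAPTURE variable here is our own illustrative flag, not an official MLX environment variable:

```python
import contextlib
import os

@contextlib.contextmanager
def maybe_capture(path):
    # MLX_DEBUG_CAPTURE is a hypothetical opt-in flag for this sketch,
    # not an official MLX environment variable.
    enabled = os.environ.get("MLX_DEBUG_CAPTURE") == "1"
    if enabled:
        import mlx.core.metal as metal
        metal.start_capture(path)
    try:
        yield enabled
    finally:
        if enabled:
            metal.stop_capture()

# With the flag unset this is a no-op, so it is safe to leave in place.
with maybe_capture("debug.gputrace") as capturing:
    pass  # ... GPU operations here ...
```

Run with MLX_DEBUG_CAPTURE=1 (and MTL_CAPTURE_ENABLED=1) to produce a trace; otherwise the wrapped code runs at full speed.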
stop_capture
mlx.core.metal.stop_capture() -> None
Stop capturing Metal GPU commands.
Must be called after start_capture() to finalize the trace file.
Example:
import mlx.core as mx
import mlx.core.metal as metal

try:
    metal.start_capture("trace.gputrace")
    # Your GPU operations here
    result = my_gpu_function()
    mx.eval(result)
finally:
    # Always stop capture, even if there's an error
    metal.stop_capture()
Usage Patterns
Device Selection
import mlx.core as mx
import mlx.core.metal as metal

def setup_device():
    """Configure the best available device."""
    if metal.is_available():
        info = metal.device_info()
        print(f"Using Metal on {info['name']}")
        return "gpu"
    else:
        print("Using CPU")
        return "cpu"

device = setup_device()
Memory Monitoring
import mlx.core as mx
import mlx.core.metal as metal

class GPUMemoryMonitor:
    def __init__(self):
        if metal.is_available():
            self.total_memory = metal.device_info()['memory']
        else:
            self.total_memory = None

    def check_memory(self):
        if self.total_memory is None:
            return "Metal not available"
        active = mx.metal.get_active_memory()
        peak = mx.metal.get_peak_memory()
        return {
            "active_gb": active / (1024**3),
            "peak_gb": peak / (1024**3),
            "total_gb": self.total_memory / (1024**3),
            "utilization": active / self.total_memory * 100,
        }

    def print_stats(self):
        stats = self.check_memory()
        if isinstance(stats, dict):
            print(f"GPU Memory: {stats['active_gb']:.2f} GB / {stats['total_gb']:.2f} GB")
            print(f"Peak: {stats['peak_gb']:.2f} GB")
            print(f"Utilization: {stats['utilization']:.1f}%")
        else:
            print(stats)

monitor = GPUMemoryMonitor()
monitor.print_stats()
Conditional GPU Code
import mlx.core as mx
import mlx.core.metal as metal

def efficient_matmul(A, B):
    """Choose best implementation based on device."""
    if metal.is_available():
        # Use GPU-optimized approach
        return A @ B
    else:
        # Use CPU-friendly approach (e.g., blocking)
        return cpu_blocked_matmul(A, B)
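cpu_blocked_matmul is left undefined above; one possible sketch follows. It operates on plain nested lists (mx arrays can be converted with .tolist()) so it carries no GPU dependency; it is a hypothetical illustration of the blocking idea, not part of MLX:

```python
def cpu_blocked_matmul(A, B, block=64):
    """Hypothetical CPU fallback: cache-blocked matrix multiply.

    Accepts nested lists, or anything with a .tolist() method
    (such as mx arrays). Shown only to make the example above
    self-contained.
    """
    a = A.tolist() if hasattr(A, "tolist") else A
    b = B.tolist() if hasattr(B, "tolist") else B
    n, k, m = len(a), len(b), len(b[0])
    out = [[0.0] * m for _ in range(n)]
    # Walk the matrices one block at a time to improve cache locality.
    for i0 in range(0, n, block):
        for k0 in range(0, k, block):
            for j0 in range(0, m, block):
                for i in range(i0, min(i0 + block, n)):
                    for kk in range(k0, min(k0 + block, k)):
                        aik = a[i][kk]
                        for j in range(j0, min(j0 + block, m)):
                            out[i][j] += aik * b[kk][j]
    return out
```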
Opening GPU Traces in Xcode
1. Capture a trace:
   metal.start_capture("debug.gputrace")
   # ... your code ...
   metal.stop_capture()
2. Open the trace:
   open debug.gputrace
3. Inspect it in Xcode’s Metal Debugger:
- View all Metal commands
- Inspect shader performance
- Analyze memory usage
- Debug shader code
What to Look For
- Command buffer bottlenecks: Long-running operations
- Memory transfers: Excessive CPU-GPU data movement
- Kernel efficiency: GPU occupancy and utilization
- Shader warnings: Potential optimization opportunities
Environment Variables
MLX_METAL_FAST_SYNCH
Enables faster GPU-CPU synchronization.
export MLX_METAL_FAST_SYNCH=1
python my_script.py
Benefits:
- Lower latency for CPU-GPU collaboration
- Critical for distributed communication with JACCL
- Improves performance when frequently reading GPU results
Best Practices
- Check availability once: Cache the result of is_available() instead of calling it repeatedly
- Monitor memory: Use device_info() to ensure your workload fits in GPU memory
- Batch operations: Larger batches utilize the GPU more efficiently
- Minimize captures: GPU capture has significant overhead
- Use fast sync: Set MLX_METAL_FAST_SYNCH=1 for latency-sensitive workloads
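The first practice can be sketched with functools.lru_cache. The wrapper below is generic (the cached_probe name is ours) so it works with any zero-argument availability check and can be exercised without a GPU:

```python
import functools

def cached_probe(probe):
    """Wrap an availability check so the underlying probe runs at most once.

    Subsequent calls return the cached boolean without touching the driver.
    """
    return functools.lru_cache(maxsize=None)(probe)

# Usage (assuming mlx is installed):
#   import mlx.core.metal as metal
#   metal_available = cached_probe(metal.is_available)
#   metal_available()  # probes the system once, then serves the cached value
```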
See Also