MLX provides functions to monitor and manage GPU memory usage. Understanding memory management is crucial for training large models and optimizing performance.
Overview
MLX uses a caching allocator to efficiently manage GPU memory. The allocator:
- Caches freed memory for reuse
- Reduces allocation overhead
- Can be configured with limits
- Provides detailed usage statistics
Memory Functions
get_active_memory
mlx.core.get_active_memory() -> int
Get the current amount of active GPU memory in bytes.
Active memory includes all allocated arrays that are currently in use.
Returns:
int: The number of bytes of active memory.
Example:
import mlx.core as mx
# Check initial memory
print(f"Initial memory: {mx.get_active_memory() / 1024**2:.2f} MB")
# Allocate arrays
x = mx.random.normal((10000, 10000))
y = mx.random.normal((10000, 10000))
mx.eval(x, y)
print(f"After allocation: {mx.get_active_memory() / 1024**2:.2f} MB")
# Free arrays
del x, y
print(f"After deletion: {mx.get_active_memory() / 1024**2:.2f} MB")
Monitoring during training:
import mlx.core as mx
import mlx.nn as nn
class MemoryTracker:
def __init__(self):
self.history = []
def track(self, label=""):
mem_mb = mx.get_active_memory() / (1024**2)
self.history.append((label, mem_mb))
print(f"{label}: {mem_mb:.2f} MB")
def report(self):
print("\nMemory usage timeline:")
for label, mem_mb in self.history:
print(f" {label:30s} {mem_mb:8.2f} MB")
tracker = MemoryTracker()
tracker.track("Start")
model = create_model()
tracker.track("After model creation")
for epoch in range(5):
for batch in data_loader:
loss, grads = compute_loss(model, batch)
optimizer.update(model, grads)
tracker.track(f"After epoch {epoch + 1}")
tracker.report()
get_peak_memory
mlx.core.get_peak_memory() -> int
Get the peak GPU memory usage in bytes since the last reset.
This tracks the maximum active memory usage over time.
Returns:
int: The peak memory usage in bytes since the last reset.
Example:
import mlx.core as mx
# Reset peak counter
mx.reset_peak_memory()
# Run computation
for i in range(10):
size = (i + 1) * 1000
x = mx.random.normal((size, size))
y = x @ x.T
mx.eval(y)
current = mx.get_active_memory() / (1024**2)
peak = mx.get_peak_memory() / (1024**2)
print(f"Iteration {i}: Current={current:.1f} MB, Peak={peak:.1f} MB")
print(f"\nPeak memory usage: {mx.get_peak_memory() / (1024**2):.2f} MB")
Profiling a function:
import mlx.core as mx
def profile_memory(func, *args, **kwargs):
"""Profile peak memory usage of a function."""
# Reset and record initial state
mx.reset_peak_memory()
initial = mx.get_active_memory()
# Run function
result = func(*args, **kwargs)
mx.eval(result)
# Collect statistics
peak = mx.get_peak_memory()
final = mx.get_active_memory()
print(f"Memory profile for {func.__name__}:")
print(f" Initial: {initial / (1024**2):.2f} MB")
print(f" Peak: {peak / (1024**2):.2f} MB")
print(f" Final: {final / (1024**2):.2f} MB")
print(f" Leaked: {(final - initial) / (1024**2):.2f} MB")
return result
def my_computation(n):
x = mx.random.normal((n, n))
return x @ x.T
result = profile_memory(my_computation, 5000)
reset_peak_memory
mlx.core.reset_peak_memory() -> None
Reset the peak memory counter.
Use this before profiling specific code sections to get accurate peak usage measurements.
Example:
import mlx.core as mx
# Profile different operations
operations = {
"matmul": lambda: mx.random.normal((5000, 5000)) @ mx.random.normal((5000, 5000)),
"conv2d": lambda: mx.conv2d(mx.random.normal((8, 224, 224, 3)), mx.random.normal((64, 3, 3, 3))),
"softmax": lambda: mx.softmax(mx.random.normal((1000, 10000)))
}
for name, op in operations.items():
mx.reset_peak_memory()
result = op()
mx.eval(result)
peak = mx.get_peak_memory() / (1024**2)
print(f"{name:15s}: {peak:8.2f} MB peak")
get_cache_memory
mlx.core.get_cache_memory() -> int
Get the amount of cached GPU memory in bytes.
Cached memory is memory that has been freed but not returned to the system, available for reuse.
Returns:
int: The number of bytes of cached memory.
Example:
import mlx.core as mx
def memory_status():
active = mx.get_active_memory() / (1024**2)
cached = mx.get_cache_memory() / (1024**2)
total = active + cached
print(f"Active: {active:8.2f} MB")
print(f"Cached: {cached:8.2f} MB")
print(f"Total: {total:8.2f} MB")
print(f"Efficiency: {active/total*100:.1f}% utilized")
print("Before allocation:")
memory_status()
# Allocate and free
x = mx.random.normal((10000, 10000))
mx.eval(x)
print("\nAfter allocation:")
memory_status()
del x
print("\nAfter deletion (memory cached):")
memory_status()
mx.clear_cache()
print("\nAfter clearing cache:")
memory_status()
clear_cache
mlx.core.clear_cache() -> None
Clear the memory cache and return memory to the system.
Use this when you need to free up memory for other applications or when running multiple experiments sequentially.
Example:
import mlx.core as mx
def train_model(config):
# Train model
model = create_model(config)
train(model)
# Clear memory before next run
del model
mx.clear_cache()
print(f"Memory released: {mx.get_cache_memory() / (1024**2):.2f} MB")
# Train multiple models without memory accumulation
for config in configs:
train_model(config)
Memory cleanup in notebooks:
import mlx.core as mx
# After a large experiment
del model, optimizer, data_loader
mx.clear_cache()
print(f"Memory freed: {mx.get_cache_memory() / (1024**2):.2f} MB")
print(f"Active memory: {mx.get_active_memory() / (1024**2):.2f} MB")
set_memory_limit
mlx.core.set_memory_limit(limit: int, relaxed: bool = True) -> int
Set the maximum amount of GPU memory MLX can use. Returns the previous memory limit in bytes.
Parameters:
limit (int): Memory limit in bytes (0 means unlimited)
relaxed (bool): If True, allows temporary exceeding of limit. Default: True
Example:
import mlx.core as mx
import mlx.core.metal as metal
if metal.is_available():
info = metal.device_info()
total_memory = info['memory']
# Use at most 80% of available memory
limit = int(total_memory * 0.8)
mx.set_memory_limit(limit, relaxed=True)
print(f"Memory limit set to {limit / (1024**3):.2f} GB")
Sharing GPU with other processes:
import mlx.core as mx
# Reserve 16 GB for other applications
reserved_gb = 16
limit = (64 - reserved_gb) * 1024**3 # Assuming 64 GB total
mx.set_memory_limit(limit, relaxed=False)
print(f"MLX limited to {limit / (1024**3):.0f} GB")
When relaxed=False, allocations that would exceed the limit will fail immediately. Use relaxed=True for more flexibility.
set_cache_limit
mlx.core.set_cache_limit(limit: int) -> int
Set the maximum amount of cached memory. Returns the previous cache limit in bytes.
When the cache exceeds this limit, memory is returned to the system.
Parameters:
limit (int): Cache limit in bytes (0 means unlimited)
Example:
import mlx.core as mx
# Limit cache to 4 GB
cache_limit = 4 * 1024**3
mx.set_cache_limit(cache_limit)
print(f"Cache limit set to {cache_limit / (1024**3):.0f} GB")
# Cache will automatically be trimmed when it exceeds 4 GB
for i in range(100):
x = mx.random.normal((1000, 1000))
y = x @ x
mx.eval(y)
del x, y
if i % 10 == 0:
cached = mx.get_cache_memory() / (1024**2)
print(f"Iteration {i}: Cached memory = {cached:.2f} MB")
set_wired_limit
mlx.core.set_wired_limit(limit: int) -> int
Set the maximum amount of wired (pinned) memory. Returns the previous wired limit in bytes.
Wired memory is locked in physical RAM and cannot be paged out. On Apple Silicon with unified memory, this affects memory available to the GPU.
Parameters:
limit (int): Wired memory limit in bytes (0 means unlimited)
Example:
import mlx.core as mx
# Limit wired memory to 32 GB on a Mac with 64 GB RAM
wired_limit = 32 * 1024**3
mx.set_wired_limit(wired_limit)
print(f"Wired memory limit set to {wired_limit / (1024**3):.0f} GB")
Practical Examples
Memory-Efficient Training
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
class MemoryEfficientTrainer:
def __init__(self, model, memory_limit_gb=None):
self.model = model
if memory_limit_gb:
mx.set_memory_limit(memory_limit_gb * 1024**3, relaxed=True)
# Set cache limit to 4 GB
mx.set_cache_limit(4 * 1024**3)
def train_epoch(self, data_loader):
mx.reset_peak_memory()
for i, batch in enumerate(data_loader):
loss, grads = self.loss_fn(self.model, batch)
self.optimizer.update(self.model, grads)
# Periodic memory cleanup
if i % 100 == 0:
mx.clear_cache()
# Periodic memory report
if i % 50 == 0:
active = mx.get_active_memory() / (1024**2)
peak = mx.get_peak_memory() / (1024**2)
print(f"Batch {i}: Active={active:.0f}MB, Peak={peak:.0f}MB")
trainer = MemoryEfficientTrainer(model, memory_limit_gb=40)
trainer.train_epoch(train_loader)
Memory Leak Detection
import mlx.core as mx
class LeakDetector:
def __init__(self, tolerance_mb=10):
self.tolerance = tolerance_mb * 1024**2
self.baseline = None
def start(self):
mx.clear_cache()
self.baseline = mx.get_active_memory()
print(f"Baseline memory: {self.baseline / (1024**2):.2f} MB")
def check(self, label=""):
mx.clear_cache()
current = mx.get_active_memory()
leaked = current - self.baseline
if leaked > self.tolerance:
print(f"⚠️ Potential leak {label}: {leaked / (1024**2):.2f} MB")
return True
else:
print(f"✓ No leak {label}: {leaked / (1024**2):.2f} MB")
return False
detector = LeakDetector(tolerance_mb=5)
detector.start()
for i in range(10):
x = mx.random.normal((1000, 1000))
y = x @ x
mx.eval(y)
del x, y
detector.check(f"iteration {i}")
Dynamic Batch Size Selection
import mlx.core as mx
import mlx.core.metal as metal
def find_optimal_batch_size(model, input_shape, max_memory_gb=None):
"""Find largest batch size that fits in memory."""
if max_memory_gb is None and metal.is_available():
info = metal.device_info()
max_memory_gb = info['memory'] / (1024**3) * 0.8 # Use 80%
batch_size = 1
max_batch = 1
while True:
batch_size *= 2
mx.clear_cache()
mx.reset_peak_memory()
try:
# Try forward and backward pass
x = mx.random.normal((batch_size,) + input_shape)
output = model(x)
loss = mx.mean(output)
grads = mx.grad(lambda m, x: mx.mean(m(x)))(model, x)
mx.eval(loss, grads)
peak_gb = mx.get_peak_memory() / (1024**3)
if peak_gb > max_memory_gb:
break
max_batch = batch_size
print(f"Batch size {batch_size}: {peak_gb:.2f} GB - OK")
except Exception as e:
print(f"Batch size {batch_size} failed: {e}")
break
mx.clear_cache()
print(f"\nOptimal batch size: {max_batch}")
return max_batch
model = create_model()
batch_size = find_optimal_batch_size(model, input_shape=(3, 224, 224))
Memory Debugging Tips
- Use clear_cache() frequently: Especially in notebooks and between experiments
- Monitor peak memory: Track peak usage to find memory bottlenecks
- Set conservative limits: Leave headroom for memory spikes
- Profile incrementally: Add memory checks at key points in your code
- Delete large arrays: Use del explicitly and call clear_cache()
- Watch for leaks: Use the leak detector pattern shown above
See Also