The mlx.core.cuda module provides access to CUDA-specific functionality for running MLX on NVIDIA GPUs. This enables MLX to run on Linux and Windows systems with NVIDIA hardware.

Overview

The CUDA backend allows MLX to leverage NVIDIA GPUs for computation. When available, it provides:
  • High-performance computation on NVIDIA GPUs
  • Multi-GPU support via the NCCL distributed backend
  • Cross-platform compatibility (Linux, Windows)

Functions

is_available

mlx.core.cuda.is_available() -> bool
Check whether CUDA is available on the current system.
Returns:
  • True if MLX was compiled with CUDA support and a compatible NVIDIA GPU is detected, False otherwise
Example:
import mlx.core as mx
import mlx.core.cuda as cuda

if cuda.is_available():
    print("CUDA GPU acceleration is available")
    print("MLX will use NVIDIA GPU for computation")
else:
    print("CUDA is not available")
    print("Possible reasons:")
    print("  - MLX not compiled with CUDA support")
    print("  - No NVIDIA GPU detected")
    print("  - CUDA drivers not installed")
Cross-platform device detection:
import mlx.core as mx
import mlx.core.metal as metal
import mlx.core.cuda as cuda

def get_device_info():
    """Detect and return information about available compute devices."""
    if metal.is_available():
        info = metal.device_info()
        return {
            "backend": "Metal",
            "device": info["device_name"],
            "memory_gb": info["memory_size"] / (1024**3)
        }
    elif cuda.is_available():
        return {
            "backend": "CUDA",
            "device": "NVIDIA GPU",
            "memory_gb": None  # Use nvidia-smi to get this
        }
    else:
        return {
            "backend": "CPU",
            "device": "CPU",
            "memory_gb": None
        }

device = get_device_info()
print(f"Running on {device['backend']}: {device['device']}")
if device['memory_gb']:
    print(f"Device memory: {device['memory_gb']:.2f} GB")

Usage

Basic CUDA Check

import mlx.core as mx
import mlx.core.cuda as cuda

# Check if CUDA is available before running GPU-intensive code
if not cuda.is_available():
    raise RuntimeError("This script requires CUDA support")

# Proceed with GPU computation
x = mx.random.normal((10000, 10000))
y = mx.random.normal((10000, 10000))
z = x @ y
mx.eval(z)

print("Computation completed on NVIDIA GPU")

Conditional Code Paths

import mlx.core as mx
import mlx.core.metal as metal
import mlx.core.cuda as cuda

def create_large_model():
    """Create a model with size based on available GPU."""
    if metal.is_available():
        # Apple Silicon typically has unified memory
        info = metal.device_info()
        max_memory = info["memory_size"] / (1024**3)
        print(f"Using Metal with {max_memory:.0f} GB unified memory")
        model_size = "large"
    elif cuda.is_available():
        # NVIDIA GPU - check with nvidia-smi for actual memory
        print("Using CUDA")
        model_size = "medium"  # Conservative default
    else:
        print("Using CPU")
        model_size = "small"
    
    return create_model(size=model_size)

Distributed Training Setup

import mlx.core as mx
import mlx.core.cuda as cuda
import mlx.core.distributed as dist

def setup_distributed():
    """Initialize distributed backend based on available hardware."""
    if cuda.is_available():
        # Use NCCL for NVIDIA GPUs
        print("Initializing NCCL backend for CUDA")
        world = dist.init(backend="nccl")
    else:
        # Fall back to other backends
        print("CUDA not available, using alternative backend")
        world = dist.init()
    
    return world

world = setup_distributed()
print(f"Rank {world.rank()} of {world.size()}")

Installation

Building MLX with CUDA Support

To use the CUDA backend, MLX must be built with CUDA support:
# Clone MLX repository
git clone https://github.com/ml-explore/mlx.git
cd mlx

# Build and install the Python package with CUDA support
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e .

Requirements

  • NVIDIA GPU with compute capability 7.0 or higher
  • CUDA Toolkit 11.0 or later
  • Compatible NVIDIA drivers
  • Linux or Windows operating system
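The compute-capability requirement can be checked by scripting around `nvidia-smi --query-gpu=compute_cap` (supported by recent NVIDIA drivers). The helper below is a hypothetical sketch, not part of MLX, and assumes numeric "major.minor" strings:

```python
import subprocess

def parse_compute_cap(s):
    """Parse an nvidia-smi compute_cap string such as '8.6' into (8, 6)."""
    major, minor = s.strip().split(".")
    return (int(major), int(minor))

def meets_min_compute_cap(min_cap=(7, 0)):
    """Return True if every visible GPU reports compute capability >= min_cap."""
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
            capture_output=True, text=True, check=True,
        ).stdout
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False  # driver missing, or the query flag is unsupported
    caps = [parse_compute_cap(line) for line in out.splitlines() if line.strip()]
    return bool(caps) and all(cap >= min_cap for cap in caps)
```

Tuple comparison makes the version check correct even across major versions (e.g. (7, 0) <= (8, 6)).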

Verifying Installation

# Check CUDA availability
python -c "import mlx.core.cuda as cuda; print('CUDA available:', cuda.is_available())"

# Check NVIDIA GPU
nvidia-smi

Environment Variables

CUDA_VISIBLE_DEVICES

Control which GPUs are visible to MLX:
# Use only GPU 0
export CUDA_VISIBLE_DEVICES=0
python my_script.py

# Use GPUs 0 and 1
export CUDA_VISIBLE_DEVICES=0,1
python my_script.py

# Hide all GPUs (force CPU)
export CUDA_VISIBLE_DEVICES=""
python my_script.py
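Inside a script, the effect of CUDA_VISIBLE_DEVICES can be inspected through os.environ. The helper below is a hypothetical sketch (not an MLX API) and assumes numeric device indices rather than GPU UUIDs:

```python
import os

def visible_cuda_devices(env=None):
    """Return the GPU indices exposed by CUDA_VISIBLE_DEVICES.

    Returns None when the variable is unset (all GPUs visible) and an
    empty list when it is set but empty (all GPUs hidden).
    """
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    return [int(tok) for tok in value.split(",") if tok.strip()]
```

For example, with `CUDA_VISIBLE_DEVICES=0,1` the helper returns `[0, 1]`, and with `CUDA_VISIBLE_DEVICES=""` it returns `[]`.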

MLX_DISABLE_CUDA

Disable CUDA even if available:
export MLX_DISABLE_CUDA=1
python my_script.py

Distributed Training with NCCL

When CUDA is available, MLX can use NCCL for efficient multi-GPU communication:
import mlx.core as mx
import mlx.core.cuda as cuda
import mlx.core.distributed as dist

if not cuda.is_available():
    raise RuntimeError("CUDA required for this example")

# Initialize NCCL backend
world = dist.init(backend="nccl")

print(f"Process {world.rank()} of {world.size()}")

# Data parallel training
x = mx.random.normal((32, 1000))
x_sum = dist.all_sum(x)

print(f"All-reduce completed on rank {world.rank()}")
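The all_sum collective is the building block for data-parallel gradient averaging: each rank contributes its local gradients, every rank receives the elementwise sum, and dividing by the world size yields the average. A pure-Python sketch of that arithmetic (no MLX or NCCL required; the function names are illustrative only):

```python
def all_sum(per_rank_values):
    """Simulate an all-reduce sum: every rank receives the elementwise sum."""
    summed = [sum(vals) for vals in zip(*per_rank_values)]
    return [summed for _ in per_rank_values]

def averaged_gradients(per_rank_grads):
    """Data-parallel averaging: all_sum followed by division by world size."""
    world_size = len(per_rank_grads)
    return [[g / world_size for g in rank_out]
            for rank_out in all_sum(per_rank_grads)]

# Two ranks with local gradients [1, 2] and [3, 4]
print(averaged_gradients([[1.0, 2.0], [3.0, 4.0]]))  # both ranks see [2.0, 3.0]
```

In real training the division would be applied to the result of dist.all_sum on each rank.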
Launching with mlx.launch:
# Launch on 8 GPUs
mlx.launch -n 8 train.py

# Launch on specific GPUs
CUDA_VISIBLE_DEVICES=0,1,2,3 mlx.launch -n 4 train.py

# Launch on multiple nodes
mlx.launch --hosts node1,node2 -n 8 train.py

Comparing Metal and CUDA

Feature         Metal                    CUDA
Platform        macOS only               Linux, Windows
Hardware        Apple Silicon, AMD GPUs  NVIDIA GPUs
Unified Memory  Yes (Apple Silicon)      No
Multi-GPU       JACCL (Thunderbolt)      NCCL (NVLink, PCIe)
Debugging       Xcode Metal Debugger     NVIDIA Nsight
Performance     Optimized for Apple      Optimized for NVIDIA

Troubleshooting

CUDA Not Detected

import mlx.core.cuda as cuda

if not cuda.is_available():
    print("CUDA not available. Troubleshooting:")
    
    # Check if MLX was built with CUDA
    import mlx.core as mx
    if not hasattr(mx, 'cuda'):
        print("❌ MLX not built with CUDA support")
        print("   Rebuild MLX with -DMLX_BUILD_CUDA=ON")
    
    # Check NVIDIA drivers
    import subprocess
    try:
        subprocess.run(["nvidia-smi"], check=True)
        print("✓ NVIDIA drivers installed")
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("❌ NVIDIA drivers not found")
        print("   Install CUDA drivers from nvidia.com")
    
    # Check CUDA_VISIBLE_DEVICES
    import os
    if os.environ.get("CUDA_VISIBLE_DEVICES") == "":
        print("❌ CUDA_VISIBLE_DEVICES is empty")
        print("   Remove or modify CUDA_VISIBLE_DEVICES")

Performance Issues

import mlx.core as mx
import mlx.core.cuda as cuda
import time

if cuda.is_available():
    # Warm up GPU
    x = mx.random.normal((100, 100))
    mx.eval(x @ x)
    
    # Benchmark
    sizes = [1000, 2000, 4000]
    for size in sizes:
        x = mx.random.normal((size, size))
        
        start = time.time()
        y = x @ x
        mx.eval(y)
        elapsed = time.time() - start
        
        ops = 2 * size**3  # Matrix multiply operations
        gflops = ops / elapsed / 1e9
        
        print(f"Size {size}x{size}: {gflops:.1f} GFLOPS")

Performance Tips

  1. Use batch processing: Larger batches better utilize GPU parallelism
  2. Enable cuDNN: Ensure cuDNN is installed for optimized convolutions
  3. Monitor GPU utilization: Use nvidia-smi to check if GPU is fully utilized
  4. Use NCCL for multi-GPU: Much faster than other backends for NVIDIA GPUs
  5. Pin memory: Reduce CPU-GPU transfer overhead (future MLX feature)
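Tip 1 in practice: submit one large kernel per batch instead of many tiny launches. The chunking helper below is a hypothetical plain-Python sketch (`batched` is not an MLX API):

```python
def batched(items, batch_size):
    """Split a sequence into consecutive batches of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

# One forward pass per batch of 32 keeps the GPU busy,
# versus 100 single-item launches:
for batch in batched(list(range(100)), 32):
    pass  # e.g. mx.eval(model(mx.array(batch)))
```

The last batch may be smaller than batch_size; pad it if the model requires a fixed batch dimension.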

See Also