MLX provides distributed communication operations that allow computational workloads to be shared across multiple physical machines or GPUs. The distributed module supports several backends including MPI, Ring, JACCL (Thunderbolt RDMA), and NCCL.

Overview

Distributed communication in MLX enables:
  • Data parallelism for training large models
  • Tensor parallelism for models too large for a single device
  • Efficient multi-node inference
  • Collective operations like all-reduce and all-gather

Getting Started

A basic distributed program in MLX:
import mlx.core as mx

# Initialize distributed backend
world = mx.distributed.init()

# Sum array across all processes
x = mx.distributed.all_sum(mx.ones(10))
print(f"Rank {world.rank()}: {x}")
Running the program:
# Run on 4 local processes
mlx.launch -n 4 my_script.py

# Run on remote hosts
mlx.launch --hosts host1,host2,host3 my_script.py

Classes

Group

mlx.core.distributed.Group
Represents a group of processes in distributed communication. Methods:

rank

rank() -> int
Returns the rank of the current process in the group. Example:
import mlx.core as mx

world = mx.distributed.init()
print(f"My rank is: {world.rank()}")

size

size() -> int
Returns the total number of processes in the group. Example:
import mlx.core as mx

world = mx.distributed.init()
print(f"Total processes: {world.size()}")
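rank() and size() are typically used together to partition work across processes. A minimal sketch of a contiguous sharding helper (shard_indices is a hypothetical name, not part of the MLX API):

```python
def shard_indices(n_items: int, rank: int, size: int) -> range:
    """Return the contiguous index range owned by `rank` out of `size` workers."""
    per_rank = n_items // size
    start = rank * per_rank
    # The last rank absorbs the remainder when n_items is not divisible by size
    end = n_items if rank == size - 1 else start + per_rank
    return range(start, end)

# Example: 10 items over 4 processes
# rank 0 owns range(0, 2), rank 3 owns range(6, 10)
```

Each process would then index its local slice of the dataset with `shard_indices(len(data), world.rank(), world.size())`.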

Functions

is_available

mlx.core.distributed.is_available(backend: str = "any") -> bool
Check if a distributed backend is available. Parameters:
  • backend (str): Backend name. Options: "any", "mpi", "ring", "jaccl", "nccl". Default: "any"
Returns:
  • True if the backend is available, False otherwise
Example:
import mlx.core as mx

if mx.distributed.is_available("mpi"):
    print("MPI backend is available")

if mx.distributed.is_available("jaccl"):
    print("JACCL (Thunderbolt RDMA) is available")

init

mlx.core.distributed.init(backend: str = "any") -> Group
Initialize the distributed backend and return the world group. Parameters:
  • backend (str): Backend to use. Options: "any", "mpi", "ring", "jaccl", "nccl". Default: "any"
Returns:
  • Group object representing all processes
Example:
import mlx.core as mx

# Initialize any available backend
world = mx.distributed.init()

# Initialize specific backend
mpi_world = mx.distributed.init(backend="mpi")
jaccl_world = mx.distributed.init(backend="jaccl")
After a distributed backend is successfully initialized, subsequent calls to init() with backend="any" will return the same backend, not initialize a new one.

all_sum

mlx.core.distributed.all_sum(x: array) -> array
Sum the input array across all processes. Each process receives the total sum of all input arrays. Parameters:
  • x (array): Input array to sum
Returns:
  • Array containing the sum across all processes
Example:
import mlx.core as mx

world = mx.distributed.init()

# Each process has value equal to its rank
x = mx.full((10,), float(world.rank()))

# Sum across all processes
total = mx.distributed.all_sum(x)

# If world.size() == 4, each process now has [0+1+2+3, ...] = [6, 6, ...]
print(f"Rank {world.rank()}: {total}")
Common use case - Gradient synchronization:
import mlx.core as mx

world = mx.distributed.init()

# Compute gradients on local batch
loss, grads = loss_fn(model, batch)

# Average gradients across all processes
for key in grads:
    grads[key] = mx.distributed.all_sum(grads[key]) / world.size()

# Update model with averaged gradients
optimizer.update(model, grads)

all_gather

mlx.core.distributed.all_gather(x: array) -> array
Gather arrays from all processes. Each process receives a concatenation of arrays from all processes along the first dimension. Parameters:
  • x (array): Input array to gather
Returns:
  • Array with first dimension equal to x.shape[0] * world.size()
Example:
import mlx.core as mx

world = mx.distributed.init()

# Each process has different data
x = mx.full((2, 3), float(world.rank()))
print(f"Rank {world.rank()} input shape: {x.shape}")  # (2, 3)

# Gather from all processes
gathered = mx.distributed.all_gather(x)
print(f"Rank {world.rank()} output shape: {gathered.shape}")  # (8, 3) if 4 processes

# gathered contains:
# [[0, 0, 0],
#  [0, 0, 0],  # from rank 0
#  [1, 1, 1],
#  [1, 1, 1],  # from rank 1
#  [2, 2, 2],
#  [2, 2, 2],  # from rank 2
#  [3, 3, 3],
#  [3, 3, 3]]  # from rank 3

send

mlx.core.distributed.send(x: array, dst: int, tag: int = 0)
Send an array to a specific process. Parameters:
  • x (array): Array to send
  • dst (int): Destination rank
  • tag (int): Message tag for matching with receive. Default: 0
Example:
import mlx.core as mx

world = mx.distributed.init(backend="mpi")  # MPI required for send/recv

if world.rank() == 0:
    # Rank 0 sends to rank 1
    data = mx.array([1.0, 2.0, 3.0])
    mx.distributed.send(data, dst=1)
    print("Rank 0 sent data")
Point-to-point communication (send and recv) is not supported by the Ring backend. Use MPI, JACCL, or NCCL for these operations.

recv

mlx.core.distributed.recv(
    shape: tuple,
    dtype: Dtype,
    src: int,
    tag: int = 0
) -> array
Receive an array from a specific process. You must know the shape and dtype of the incoming array. Parameters:
  • shape (tuple): Shape of the array to receive
  • dtype (Dtype): Data type of the array to receive
  • src (int): Source rank
  • tag (int): Message tag for matching with send. Default: 0
Returns:
  • Received array
Example:
import mlx.core as mx

world = mx.distributed.init(backend="mpi")

if world.rank() == 1:
    # Rank 1 receives from rank 0
    data = mx.distributed.recv(
        shape=(3,),
        dtype=mx.float32,
        src=0
    )
    print(f"Rank 1 received: {data}")

recv_like

mlx.core.distributed.recv_like(
    x: array,
    src: int,
    tag: int = 0
) -> array
Receive an array with the same shape and dtype as the template array. Parameters:
  • x (array): Template array (only shape and dtype are used)
  • src (int): Source rank
  • tag (int): Message tag. Default: 0
Returns:
  • Received array
Example:
import mlx.core as mx

world = mx.distributed.init(backend="mpi")

if world.rank() == 0:
    data = mx.random.normal((10, 20))
    mx.distributed.send(data, dst=1)
elif world.rank() == 1:
    # Receive with matching shape
    template = mx.zeros((10, 20))
    data = mx.distributed.recv_like(template, src=0)
    print(f"Received shape: {data.shape}")

Backends

Ring Backend

  • Always available, no dependencies
  • Uses TCP sockets
  • Nodes connected in a ring topology
  • Best for Ethernet or Thunderbolt connections
  • Does not support point-to-point send/recv
world = mx.distributed.init(backend="ring")

JACCL Backend

  • Low-latency RDMA over Thunderbolt 5 on macOS 26.2+
  • Requires fully connected mesh topology
  • Order of magnitude lower latency than Ring
  • Best for tensor parallelism
world = mx.distributed.init(backend="jaccl")
Setting up JACCL:
# Auto-configure thunderbolt mesh
mlx.distributed_config --verbose \
    --hosts m3-1,m3-2,m3-3,m3-4 \
    --over thunderbolt --backend jaccl \
    --auto-setup --output hostfile.json

# Launch with JACCL
mlx.launch --backend jaccl --hostfile hostfile.json \
    --env MLX_METAL_FAST_SYNCH=1 -- \
    python my_script.py

MPI Backend

  • Full-featured, mature library
  • Supports all operations
  • Requires MPI installation
world = mx.distributed.init(backend="mpi")
Installation:
# Via conda (recommended)
conda install conda-forge::openmpi

# Launch with MPI
mlx.launch --backend mpi -n 4 my_script.py

NCCL Backend

  • High-performance for CUDA environments
  • Default backend for CUDA in mlx.launch
  • Supports multi-GPU and multi-node
world = mx.distributed.init(backend="nccl")
# Launch on 8 GPUs
mlx.launch -n 8 my_script.py

Practical Examples

Data Parallel Training

import mlx.core as mx
import mlx.optimizers as optim
from mlx.utils import tree_map

# Initialize distributed
world = mx.distributed.init()

# model, loss_fn and get_data_loader are defined elsewhere
optimizer = optim.SGD(learning_rate=1e-2)

# Each process loads a different shard of the data
train_loader = get_data_loader(
    shard_id=world.rank(),
    num_shards=world.size()
)

# Training loop
for batch in train_loader:
    # Forward and backward pass on the local shard
    loss, grads = loss_fn(model, batch)

    # Average gradients across all processes
    grads = tree_map(
        lambda g: mx.distributed.all_sum(g) / world.size(),
        grads
    )

    # Update model with the synchronized gradients
    optimizer.update(model, grads)

Model Parallel Inference

import mlx.core as mx

world = mx.distributed.init(backend="jaccl")  # Low latency needed

# Split the model across two devices
if world.rank() == 0:
    hidden = first_half_of_model(inputs)
    mx.distributed.send(hidden, dst=1)
elif world.rank() == 1:
    # The template only supplies the shape and dtype of the incoming
    # activations; hidden_shape must match the sender's output shape
    template = mx.zeros(hidden_shape)
    hidden = mx.distributed.recv_like(template, src=0)
    output = second_half_of_model(hidden)

Distributed Evaluation

import mlx.core as mx

world = mx.distributed.init()

# Each process evaluates a different shard of the test set
local_correct, local_total = evaluate_shard(model, test_data[world.rank()])

# Sum correct predictions and sample counts across all processes in one call
counts = mx.distributed.all_sum(mx.array([local_correct, local_total]))

if world.rank() == 0:
    accuracy = counts[0] / counts[1]
    print(f"Accuracy: {accuracy.item():.2%}")

Performance Tips

  1. Batch communication: Combine small arrays into larger ones before calling all_sum or all_gather
  2. Use JACCL for tensor parallelism: The low latency is critical for frequent small communications
  3. Test locally first: Run with mlx.launch -n 2 on a single machine before scaling up
  4. Enable fast sync: Set MLX_METAL_FAST_SYNCH=1 when using JACCL
  5. Overlap computation and communication: Start the next computation before waiting for communication to complete

See Also