MLX supports executing operations on different devices (CPU, GPU) and coordinating execution through streams. Streams allow for asynchronous execution and fine-grained control over when operations run.

Devices

Device

Represents a computation device in MLX.
import mlx.core as mx

# Create devices
cpu = mx.Device(mx.cpu)
gpu = mx.Device(mx.gpu)

print(cpu)  # Device(cpu, 0)
print(gpu)  # Device(gpu, 0)
Devices have two components:
  • Device type: Either cpu or gpu
  • Device index: Integer identifying which device (useful for multi-GPU systems)
Common device types:
import mlx.core as mx

mx.cpu  # CPU device type
mx.gpu  # GPU device type

default_device

Get the default device for operations.
import mlx.core as mx

device = mx.default_device()
print(device)  # Device(gpu, 0) or Device(cpu, 0)
Returns:
  • device (Device): The current default device.

set_default_device

Set the default device for operations.
import mlx.core as mx

# Set CPU as default
mx.set_default_device(mx.cpu)

# Now arrays are created on CPU by default
a = mx.array([1, 2, 3])

# Set GPU as default
mx.set_default_device(mx.gpu)

# Now arrays are created on GPU by default
b = mx.array([4, 5, 6])
Parameters:
  • device (Device, required): The device to set as default.

device_count

Get the number of devices available.
import mlx.core as mx

num_gpus = mx.device_count(mx.gpu)
print(f"Number of GPUs: {num_gpus}")

num_cpus = mx.device_count(mx.cpu)
print(f"Number of CPUs: {num_cpus}")  # Usually 1
Parameters:
  • device_type (DeviceType, required): The type of device to count (e.g., mx.cpu or mx.gpu).
Returns:
  • count (int): The number of available devices of the specified type.

device_info

Get information about a device.
import mlx.core as mx

info = mx.device_info(mx.gpu)
print(info)
# Returns device-specific information like memory, architecture, etc.
Parameters:
  • device (Device, required): The device to query.
Returns:
  • info (dict): Dictionary containing device information.

Streams

Stream

Represents an execution stream for scheduling operations.
import mlx.core as mx

# Create a new stream on the GPU
stream = mx.Stream(mx.gpu)

# Schedule operations on this stream
a = mx.array([1, 2, 3])
b = mx.add(a, 1, stream=stream)
Streams allow operations to execute asynchronously: operations on different streams can run in parallel, while operations on the same stream execute sequentially.

Stream properties:
  • device (Device): The device associated with the stream.

default_stream

Get the default stream for a device.
import mlx.core as mx

# Get default stream for GPU
stream = mx.default_stream(mx.gpu)

# Get default stream for CPU
cpu_stream = mx.default_stream(mx.cpu)
Parameters:
  • device (Device, required): The device to get the default stream for.
Returns:
  • stream (Stream): The default stream for the specified device.

set_default_stream

Set the default stream for a device.
import mlx.core as mx

# Create a custom stream
custom_stream = mx.Stream(mx.gpu)

# Set it as the default for GPU operations
mx.set_default_stream(custom_stream)

# Now operations use this stream by default
a = mx.array([1, 2, 3])
Parameters:
  • stream (Stream, required): The stream to set as default.

new_stream

Create a new stream on a device.
import mlx.core as mx

# Create new stream on GPU
stream1 = mx.new_stream(mx.gpu)
stream2 = mx.new_stream(mx.gpu)

# Operations on different streams can run in parallel
a = mx.array([1, 2, 3])
b = mx.add(a, 1, stream=stream1)
c = mx.multiply(a, 2, stream=stream2)
Parameters:
  • device (Device, required): The device to create the stream on.
Returns:
  • stream (Stream): A new stream for the specified device.

stream

Context manager for temporarily setting the default stream.
import mlx.core as mx

# Create a custom stream
custom_stream = mx.new_stream(mx.gpu)

# Use it as default within a context
with mx.stream(custom_stream):
    # Operations in this block use custom_stream
    a = mx.array([1, 2, 3])
    b = mx.add(a, 1)

# Outside the context, the original default stream is restored
c = mx.array([4, 5, 6])
Parameters:
  • stream (Stream, required): The stream to use as default within the context.

synchronize

Wait for operations on a stream to complete.
import mlx.core as mx

# Schedule operations
a = mx.array([1, 2, 3])
b = mx.add(a, 1)

# Wait for all operations on default stream to complete
mx.synchronize()

# Now it's safe to access the result
print(b)  # array([2, 3, 4], dtype=int32)
Parameters:
  • stream (Stream, optional): The stream to synchronize. If not provided, synchronizes the default stream.

Using Streams for Parallelism

Streams enable parallelism by allowing operations to overlap:
import mlx.core as mx

# Create two streams
stream1 = mx.new_stream(mx.gpu)
stream2 = mx.new_stream(mx.gpu)

# These operations can run in parallel
a = mx.array([1, 2, 3])
b = mx.add(a, 1, stream=stream1)      # Runs on stream1
c = mx.multiply(a, 2, stream=stream2)  # Runs on stream2 (in parallel)

# Wait for both to complete
mx.synchronize(stream1)
mx.synchronize(stream2)

print(b)  # array([2, 3, 4], dtype=int32)
print(c)  # array([2, 4, 6], dtype=int32)

Device-Specific Operations

You can explicitly specify which device an operation should run on:
import mlx.core as mx

# Create array on CPU
a = mx.array([1, 2, 3])

# Perform operation on GPU
b = mx.add(a, 1, stream=mx.default_stream(mx.gpu))

# Perform operation on CPU
c = mx.multiply(a, 2, stream=mx.default_stream(mx.cpu))

Best Practices

Lazy Evaluation

MLX uses lazy evaluation: operations are not executed immediately but recorded and scheduled on a stream. To ensure an operation completes before accessing its results:
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.add(a, 1)

# Force evaluation (mx.synchronize() alone only waits for
# work that has already been scheduled; it does not trigger
# evaluation itself)
mx.eval(b)

# Now safe to use b
print(b)

Multi-GPU Usage

For systems with multiple GPUs:
import mlx.core as mx

num_gpus = mx.device_count(mx.gpu)
print(f"Available GPUs: {num_gpus}")

# Use specific GPU
if num_gpus > 1:
    gpu0 = mx.Device(mx.gpu, 0)
    gpu1 = mx.Device(mx.gpu, 1)
    
    stream0 = mx.default_stream(gpu0)
    stream1 = mx.default_stream(gpu1)
    
    # Schedule operations on different GPUs
    a = mx.array([1, 2, 3])
    b = mx.add(a, 1, stream=stream0)  # GPU 0
    c = mx.add(a, 2, stream=stream1)  # GPU 1

Stream Safety

Operations on the same stream execute in order:
import mlx.core as mx

stream = mx.new_stream(mx.gpu)

a = mx.array([1, 2, 3])
b = mx.add(a, 1, stream=stream)
c = mx.multiply(b, 2, stream=stream)  # Waits for b to complete

mx.synchronize(stream)
print(c)  # array([4, 6, 8], dtype=int32)