Overview
Random sampling functions in MLX use an implicit global PRNG state by default. However, every sampling function also accepts an optional key keyword argument for when finer-grained control or explicit state management is needed.
MLX follows JAX’s PRNG design using a splittable version of Threefry, which is a counter-based PRNG.
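To make "counter-based" concrete, the plain-Python sketch below derives each random number as a pure function of a key and a counter. Nothing hidden is advanced between calls. This is not Threefry and not MLX's implementation: SHA-256 from the standard library stands in for Threefry's block function, and `random_bits` and `to_unit_float` are hypothetical helpers.

```python
import hashlib
import struct

# Illustrative sketch only: SHA-256 is an assumed stand-in for Threefry's
# block function; MLX does not use this code.

def random_bits(key: int, counter: int) -> int:
    """Return 32 pseudo-random bits as a pure function of (key, counter)."""
    data = struct.pack("<QQ", key, counter)
    return int.from_bytes(hashlib.sha256(data).digest()[:4], "little")

def to_unit_float(bits: int) -> float:
    # Map 32 bits to a float in the half-open interval [0, 1).
    return bits / 2**32

key = 0
# The same (key, counter) pair always yields the same bits.
print(random_bits(key, 0) == random_bits(key, 0))  # True
# A sequence of values comes from incrementing the counter, not from mutable state.
stream = [to_unit_float(random_bits(key, i)) for i in range(3)]
print(stream)
```

Because each output depends only on (key, counter), values can be generated in parallel and reproduced exactly, and new independent keys can be derived from an existing one, which is what makes such generators convenient to split.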
Basic Usage
Generate random numbers using the implicit global state:
import mlx.core as mx
for _ in range(3):
    print(mx.random.uniform())
Or use explicit key management:
import mlx.core as mx
key = mx.random.key(0)
for _ in range(3):
    print(mx.random.uniform(key=key))  # Same number each time
Key Management
key
mx.random.key(seed: int) -> array
Get a PRNG key from a seed.
seed
int
Seed value from which the key is derived
Returns: Array representing the PRNG key
Example:
key = mx.random.key(0)
seed
mx.random.seed(seed: int) -> None
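Seed the default global PRNG state.
seed
int
Seed value for the global PRNG
Example:
mx.random.seed(42)
x = mx.random.uniform()  # Reproducible across runs that use the same seed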
split
mx.random.split(key: array, num: int = 2, stream: StreamOrDevice = None) -> array
Split a PRNG key into multiple subkeys.
key
array
The PRNG key to split
num
int
default:"2"
Number of keys to generate
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of shape (num, 2) holding the new keys. Arrays iterate over their first axis, so with the default num=2 the result can be unpacked directly into two keys.
Example:
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
keys = mx.random.split(key, 10) # Shape: (10, 2)
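The usual pattern with explicit keys is to split before every use: carry one subkey forward and consume the other, so no key is ever reused. The plain-Python sketch below mimics that pattern so it runs anywhere. `split` and `uniform` here are hypothetical SHA-256-based stand-ins, not MLX functions, and a key is modelled as a bytes object rather than an MLX array.

```python
import hashlib

# Hypothetical stand-ins for illustration; they only mimic the calling pattern
# of mx.random.split and mx.random.uniform(key=...).

def split(key: bytes, num: int = 2) -> list[bytes]:
    # Derive each subkey by hashing the parent key together with an index.
    return [hashlib.sha256(key + i.to_bytes(4, "little")).digest()[:8] for i in range(num)]

def uniform(key: bytes) -> float:
    # A pure function of the key: the same key always gives the same value.
    bits = int.from_bytes(hashlib.sha256(key + b"uniform").digest()[:4], "little")
    return bits / 2**32

key = (0).to_bytes(8, "little")  # plays the role of a seed-0 key

# Reusing one key repeats the same value.
repeated = [uniform(key) for _ in range(3)]

# Splitting before each use yields a fresh subkey every iteration.
fresh = []
for _ in range(3):
    key, subkey = split(key)  # carry `key` forward, consume `subkey`
    fresh.append(uniform(subkey))

print(repeated)
print(fresh)
```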
Continuous Distributions
uniform
mx.random.uniform(
low: float | array = 0.0,
high: float | array = 1.0,
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate random numbers sampled uniformly from the half-open interval [low, high).
low
float | array
default:"0.0"
Lower bound of the distribution
high
float | array
default:"1.0"
Upper bound of the distribution
shape
tuple
default:"()"
Shape of the output array
dtype
Dtype
default:"float32"
Data type of the output
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of uniformly distributed random values
Example:
# Scalar value between 0 and 1
x = mx.random.uniform()
# Array of shape (2, 3)
x = mx.random.uniform(shape=(2, 3))
# Values between -1 and 5
x = mx.random.uniform(low=-1, high=5, shape=(1000,))
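A uniform sample on [low, high) is typically obtained by rescaling a standard uniform variate u in [0, 1) as low + (high - low) * u. The plain-Python sketch below illustrates that mapping with the standard `random` module; it is not MLX's implementation.

```python
import random

# Illustrative sketch (not MLX's code): rescale standard uniforms to [low, high).

def uniform_in(low: float, high: float, n: int, rng: random.Random) -> list[float]:
    return [low + (high - low) * rng.random() for _ in range(n)]

rng = random.Random(0)
xs = uniform_in(-1.0, 5.0, 1000, rng)
print(min(xs), max(xs), sum(xs) / len(xs))  # the mean should be near (low + high) / 2 = 2.0
```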
normal
mx.random.normal(
loc: float = 0.0,
scale: float = 1.0,
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate samples from a normal (Gaussian) distribution.
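loc
float
default:"0.0"
Mean of the distribution
scale
float
default:"1.0"
Standard deviation of the distribution
shape
tuple
default:"()"
Shape of the output array
dtype
Dtype
default:"float32"
Data type of the output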
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of normally distributed random values
Example:
# Standard normal
x = mx.random.normal(shape=(2, 3))
# Mean 1.0, std 2.0
x = mx.random.normal(loc=1.0, scale=2.0, shape=(100,))
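Sampling with loc and scale amounts to shifting and scaling a standard normal variate: x = loc + scale * z. The plain-Python sketch below makes that explicit, drawing z with the Box-Muller transform. Box-Muller is an assumption chosen for illustration, not a statement about how MLX generates normals.

```python
import math
import random

# Illustrative sketch (not MLX's code): Box-Muller for z, then x = loc + scale * z.

def box_muller(rng: random.Random) -> tuple[float, float]:
    # Turn two independent uniforms into two independent standard normals.
    u1 = 1.0 - rng.random()  # in (0, 1], so log(u1) is finite
    u2 = rng.random()
    r = math.sqrt(-2.0 * math.log(u1))
    return r * math.cos(2.0 * math.pi * u2), r * math.sin(2.0 * math.pi * u2)

def normal(loc: float, scale: float, n: int, rng: random.Random) -> list[float]:
    zs: list[float] = []
    while len(zs) < n:
        zs.extend(box_muller(rng))
    return [loc + scale * z for z in zs[:n]]

rng = random.Random(0)
xs = normal(1.0, 2.0, 10_000, rng)
mean = sum(xs) / len(xs)
std = math.sqrt(sum((x - mean) ** 2 for x in xs) / len(xs))
print(mean, std)  # close to loc=1.0 and scale=2.0
```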
multivariate_normal
mx.random.multivariate_normal(
mean: array,
cov: array,
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate samples from a multivariate normal distribution.
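mean
array
Mean vector of the distribution, with shape (..., n)
cov
array
Covariance matrix of the distribution, with shape (..., n, n)
shape
tuple
default:"()"
Batch shape of the samples; the event dimension n is appended, so the output has shape shape + (n,). If empty, it is determined by broadcasting the batch shapes of mean and cov.
dtype
Dtype
default:"float32"
Data type of the output (must be float32)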
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of samples from the multivariate normal distribution
Example:
mean = mx.array([0.0, 0.0])
cov = mx.array([[1.0, 0.0], [0.0, 1.0]])
x = mx.random.multivariate_normal(mean, cov)
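One standard way to draw from a multivariate normal is to factor the covariance as cov = L L^T and return mean + L z, where z is a vector of independent standard normals. The plain-Python sketch below uses a Cholesky factor; that choice is an illustrative assumption, and MLX's implementation may factor the covariance differently (for example via an SVD).

```python
import math
import random

# Illustrative sketch (not MLX's code): x = mean + L z with cov = L L^T (Cholesky).

def cholesky(a: list[list[float]]) -> list[list[float]]:
    # Lower-triangular L with L L^T == a; a must be symmetric positive definite.
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def mvn_sample(mean: list[float], L: list[list[float]], rng: random.Random) -> list[float]:
    z = [rng.gauss(0.0, 1.0) for _ in mean]  # independent standard normals
    return [m + sum(L[i][k] * z[k] for k in range(len(z))) for i, m in enumerate(mean)]

rng = random.Random(0)
mean = [1.0, -1.0]
cov = [[2.0, 0.8], [0.8, 1.0]]
L = cholesky(cov)  # factor once, reuse for every sample
samples = [mvn_sample(mean, L, rng) for _ in range(20_000)]

# Empirical means and cross-covariance should be close to mean and cov[0][1].
m0 = sum(s[0] for s in samples) / len(samples)
m1 = sum(s[1] for s in samples) / len(samples)
c01 = sum((s[0] - m0) * (s[1] - m1) for s in samples) / len(samples)
print(m0, m1, c01)
```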
laplace
mx.random.laplace(
loc: float = 0.0,
scale: float = 1.0,
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate samples from the Laplace distribution.
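loc
float
default:"0.0"
Location parameter (mean) of the distribution
scale
float
default:"1.0"
Scale parameter of the distribution
shape
tuple
default:"()"
Shape of the output array
dtype
Dtype
default:"float32"
Data type of the output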
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of samples from the Laplace distribution
Example:
x = mx.random.laplace(shape=(100,))
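The Laplace distribution has a closed-form inverse CDF, so it can be sampled by inverse-transform sampling: with u uniform on (-1/2, 1/2), x = loc - scale * sign(u) * ln(1 - 2|u|). The plain-Python sketch below illustrates this; it is not taken from MLX's implementation.

```python
import math
import random

# Illustrative sketch (not MLX's code): inverse-CDF sampling of Laplace(loc, scale).

def laplace(loc: float, scale: float, n: int, rng: random.Random) -> list[float]:
    out = []
    while len(out) < n:
        u = rng.random() - 0.5  # u in [-0.5, 0.5)
        if u == -0.5:
            continue  # log(0) would be undefined; redraw
        out.append(loc - scale * math.copysign(1.0, u) * math.log(1.0 - 2.0 * abs(u)))
    return out

rng = random.Random(0)
xs = laplace(0.0, 1.0, 20_000, rng)
mean_abs = sum(abs(x) for x in xs) / len(xs)
print(mean_abs)  # E|X - loc| equals scale, so this should be close to 1.0
```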
truncated_normal
mx.random.truncated_normal(
lower: array | float,
upper: array | float,
shape: tuple = None,
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
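Generate samples from a standard normal distribution truncated to the interval between lower and upper.
lower
array | float
Lower bound for truncation
upper
array | float
Upper bound for truncation
shape
tuple
default:"None"
Shape of the output array. If None, inferred by broadcasting lower and upper.
dtype
Dtype
default:"float32"
Data type of the output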
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of samples from the truncated normal distribution
Example:
x = mx.random.truncated_normal(lower=-2, upper=2, shape=(100,))
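A truncated normal can be sampled by inverse-transform sampling: draw u uniformly between Phi(lower) and Phi(upper), where Phi is the standard normal CDF, and return Phi^-1(u). The plain-Python sketch below uses the standard library's `statistics.NormalDist`; it illustrates the technique and is not MLX's code.

```python
import random
from statistics import NormalDist

# Illustrative sketch (not MLX's code): inverse-CDF sampling of a truncated standard normal.

def truncated_normal(lower: float, upper: float, n: int, rng: random.Random) -> list[float]:
    std = NormalDist(0.0, 1.0)
    # Map the bounds into CDF space, sample uniformly there, then invert the CDF.
    a, b = std.cdf(lower), std.cdf(upper)
    out = []
    while len(out) < n:
        u = a + (b - a) * rng.random()
        if 0.0 < u < 1.0:  # inv_cdf is only defined on the open interval (0, 1)
            out.append(std.inv_cdf(u))
    return out

rng = random.Random(0)
xs = truncated_normal(-2.0, 2.0, 1000, rng)
print(min(xs), max(xs))  # expected to lie within [-2, 2], up to floating-point rounding
```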
gumbel
mx.random.gumbel(
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
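Generate samples from the standard Gumbel distribution, whose CDF is exp(-exp(-x)).
shape
tuple
default:"()"
Shape of the output array
dtype
Dtype
default:"float32"
Data type of the output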
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of samples from the Gumbel distribution
Example:
x = mx.random.gumbel(shape=(100,))
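Inverting the standard Gumbel CDF exp(-exp(-x)) gives x = -ln(-ln(u)) for u uniform on (0, 1). The plain-Python sketch below applies that formula; it illustrates the math and is not MLX's implementation.

```python
import math
import random

# Illustrative sketch (not MLX's code): inverse-CDF sampling of the standard Gumbel.

def gumbel(n: int, rng: random.Random) -> list[float]:
    out = []
    while len(out) < n:
        u = rng.random()
        if u == 0.0:
            continue  # -log(-log(0)) is undefined; redraw
        out.append(-math.log(-math.log(u)))
    return out

rng = random.Random(0)
xs = gumbel(20_000, rng)
print(sum(xs) / len(xs))  # close to the Euler-Mascheroni constant, about 0.577
```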
Discrete Distributions
randint
mx.random.randint(
low: int | array,
high: int | array,
shape: tuple = (),
dtype: Dtype = int32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate random integers uniformly from the range [low, high).
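low
int | array
Lower bound (inclusive)
high
int | array
Upper bound (exclusive)
shape
tuple
default:"()"
Shape of the output array
dtype
Dtype
default:"int32"
Data type of the output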
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of random integers
Example:
# Integers from 0 to 9
x = mx.random.randint(0, 10, shape=(100,))
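Uniform integers on [low, high) can be derived from a uniform float u in [0, 1) as low + floor((high - low) * u). The plain-Python sketch below illustrates that mapping; it is an assumption for illustration, and MLX may map random bits to integers differently.

```python
import math
import random
from collections import Counter

# Illustrative sketch (not MLX's code): map uniform floats to integers in [low, high).

def randint(low: int, high: int, n: int, rng: random.Random) -> list[int]:
    # floor((high - low) * u) lies in {0, ..., high - low - 1} for u in [0, 1).
    return [low + math.floor((high - low) * rng.random()) for _ in range(n)]

rng = random.Random(0)
xs = randint(0, 10, 10_000, rng)
counts = Counter(xs)
print(sorted(counts))  # the values 0 through 9; high itself is never produced
```

For very large ranges, scaling a floating-point variate introduces a small bias; methods that work directly on random integer bits (for example, rejection sampling) avoid it.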
bernoulli
mx.random.bernoulli(
p: float | array = 0.5,
shape: tuple = None,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate binary random variables with probability p of being True.
p
float | array
default:"0.5"
Probability of generating True (1)
shape
tuple
default:"None"
Shape of the output array. If None, inferred from p.
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of boolean values
Example:
# Coin flips
x = mx.random.bernoulli(p=0.5, shape=(100,))
# 30% chance of True
x = mx.random.bernoulli(p=0.3, shape=(100,))
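Each Bernoulli draw can be obtained by comparing a uniform variate u in [0, 1) with p: the comparison u < p is true with probability p. A minimal plain-Python sketch of that idea (illustrative only, not MLX's implementation):

```python
import random

# Illustrative sketch (not MLX's code): Bernoulli(p) as the comparison u < p.

def bernoulli(p: float, n: int, rng: random.Random) -> list[bool]:
    return [rng.random() < p for _ in range(n)]

rng = random.Random(0)
flips = bernoulli(0.3, 10_000, rng)
print(sum(flips) / len(flips))  # fraction of True, close to 0.3
```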
categorical
mx.random.categorical(
logits: array,
axis: int = -1,
shape: tuple = None,
num_samples: int = None,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Sample from a categorical distribution.
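logits
array
Unnormalized log probabilities for each category
axis
int
default:"-1"
Axis along which to sample
shape
tuple
default:"None"
Shape of the output array; must be broadcast-compatible with logits.shape with axis removed
num_samples
int
default:"None"
Number of samples to draw from each distribution. An alternative to shape; at most one of the two may be given.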
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of categorical samples (indices)
Example:
# Sample from 3 categories with equal probability
logits = mx.array([0.0, 0.0, 0.0])
x = mx.random.categorical(logits, num_samples=100)
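Sampling from unnormalized log probabilities is commonly done with the Gumbel-max trick: add independent standard Gumbel noise to each logit and take the argmax, which selects index i with probability softmax(logits)[i]. The plain-Python sketch below shows the trick; MLX's categorical may work along similar lines, but this is not its code.

```python
import math
import random
from collections import Counter

# Illustrative sketch (not MLX's code): the Gumbel-max trick for categorical sampling.

def gumbel_noise(rng: random.Random) -> float:
    u = rng.random()
    while u == 0.0:
        u = rng.random()  # avoid log(0)
    return -math.log(-math.log(u))

def categorical(logits: list[float], rng: random.Random) -> int:
    # argmax(logits + Gumbel noise) is distributed as softmax(logits).
    perturbed = [x + gumbel_noise(rng) for x in logits]
    return max(range(len(logits)), key=perturbed.__getitem__)

rng = random.Random(0)
logits = [0.0, math.log(2.0), math.log(3.0)]  # probabilities 1/6, 2/6, 3/6
counts = Counter(categorical(logits, rng) for _ in range(30_000))
print([counts[i] / 30_000 for i in range(3)])  # close to [0.167, 0.333, 0.5]
```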
Permutations
permutation
mx.random.permutation(
x: array | int,
axis: int = 0,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Randomly permute elements.
x
array | int
If an array, permute its elements along the given axis. If an integer, equivalent to permutation(mx.arange(x)).
axis
int
default:"0"
Axis along which to permute. Only used if x is an array.
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Permuted array
Example:
# Random permutation of 0 to 9
x = mx.random.permutation(10)
# Shuffle array rows
arr = mx.array([[1, 2], [3, 4], [5, 6]])
shuffled = mx.random.permutation(arr, axis=0)
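A uniformly random permutation of n elements can be generated with the Fisher-Yates shuffle, and shuffling the rows of an array reduces to indexing it with such a permutation. The plain-Python sketch below shows both; Fisher-Yates is chosen for illustration, and MLX may generate permutations differently (for example by sorting random keys).

```python
import random

# Illustrative sketch (not MLX's code): Fisher-Yates shuffle and row shuffling by indexing.

def permutation(n: int, rng: random.Random) -> list[int]:
    perm = list(range(n))
    # Walk from the end, swapping each slot with a uniformly chosen slot at or before it.
    for i in range(n - 1, 0, -1):
        j = rng.randrange(i + 1)  # uniform in [0, i]
        perm[i], perm[j] = perm[j], perm[i]
    return perm

rng = random.Random(0)
perm = permutation(10, rng)
print(perm)  # some ordering of 0..9

# Shuffle rows by indexing them with a random permutation of their positions.
rows = [[1, 2], [3, 4], [5, 6]]
shuffled = [rows[i] for i in permutation(len(rows), rng)]
print(shuffled)
```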