Overview
Random sampling functions in MLX use an implicit global PRNG state by default. However, all functions take an optional key keyword argument for when more fine-grained control or explicit state management is needed.
MLX follows JAX’s PRNG design using a splittable version of Threefry, which is a counter-based PRNG.
Basic Usage
Generate random numbers using the implicit global state:
```python
import mlx.core as mx

for _ in range(3):
    print(mx.random.uniform())  # the global state advances: a new number each time
```
Or use explicit key management:
```python
import mlx.core as mx

key = mx.random.key(0)
for _ in range(3):
    print(mx.random.uniform(key=key))  # same number each time: the key is reused
```
Key Management
key
mx.random.key(seed: int) -> array
Get a PRNG key from a seed.
Returns: Array representing the PRNG key
Example:
seed
mx.random.seed(seed: int) -> None
Seed the default global PRNG state.
seed
int
Seed value for the global PRNG
Example:
split
mx.random.split(key: array, num: int = 2, stream: StreamOrDevice = None) -> array
Split a PRNG key into multiple sub-keys.
num
int
default:"2"
Number of sub-keys to generate.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of shape (num, 2) with one sub-key per row. Because arrays can be unpacked along their first axis, the default num=2 result can be assigned directly to two keys.
Example:
```python
import mlx.core as mx

key = mx.random.key(0)
k1, k2 = mx.random.split(key)  # unpack the default (2, 2) result into two keys
keys = mx.random.split(key, 10)  # Shape: (10, 2)
```
Continuous Distributions
uniform
mx.random.uniform(
low: float | array = 0.0,
high: float | array = 1.0,
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate uniformly distributed random numbers in the half-open interval [low, high). The bounds low and high must be broadcastable to shape.
low
float | array
default:"0.0"
Lower bound of the distribution
high
float | array
default:"1.0"
Upper bound of the distribution
shape
tuple
default:"()"
Shape of the output array
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of uniformly distributed random values
Example:
```python
import mlx.core as mx

# Scalar value in [0, 1)
x = mx.random.uniform()
# Array of shape (2, 3)
x = mx.random.uniform(shape=(2, 3))
# 1000 values in [-1, 5)
x = mx.random.uniform(low=-1, high=5, shape=(1000,))
```
normal
mx.random.normal(
loc: float = 0.0,
scale: float = 1.0,
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate samples from a normal (Gaussian) distribution.
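loc
float
default:"0.0"
Mean of the distribution
scale
float
default:"1.0"
Standard deviation of the distribution
shape
tuple
default:"()"
Shape of the output array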
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of normally distributed random values
Example:
```python
import mlx.core as mx

# Standard normal
x = mx.random.normal(shape=(2, 3))
# Mean 1.0, std 2.0
x = mx.random.normal(loc=1.0, scale=2.0, shape=(100,))
```
multivariate_normal
mx.random.multivariate_normal(
mean: array,
cov: array,
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate samples from a multivariate normal distribution.
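mean
array
Mean vector of the distribution
cov
array
Covariance matrix of the distribution; must be positive semi-definite
shape
tuple
default:"()"
Batch shape of the samples. The output has shape shape + (n,), where n is the length of the last axis of mean.
dtype
Dtype
default:"float32"
Data type of the output (only float32 is supported)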
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of samples from the multivariate normal distribution
Example:
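```python
import mlx.core as mx

# 2-D standard normal: zero mean, identity covariance (float inputs)
mean = mx.array([0.0, 0.0])
cov = mx.array([[1.0, 0.0], [0.0, 1.0]])
x = mx.random.multivariate_normal(mean, cov)  # shape (2,)
```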
laplace
mx.random.laplace(
loc: float = 0.0,
scale: float = 1.0,
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate samples from the Laplace distribution.
loc
float
default:"0.0"
Location parameter (mean) of the distribution
scale
float
default:"1.0"
Scale parameter of the distribution
shape
tuple
default:"()"
Shape of the output array
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of samples from the Laplace distribution
Example:
```python
import mlx.core as mx

x = mx.random.laplace(shape=(100,))
```
truncated_normal
mx.random.truncated_normal(
lower: array | float,
upper: array | float,
shape: tuple = None,
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
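Generate samples from a standard normal distribution truncated to the interval (lower, upper).
lower
array | float
Lower bound for truncation
upper
array | float
Upper bound for truncation
shape
tuple
default:"None"
Shape of the output array. If None, inferred from the shapes of lower and upper.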
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of samples from the truncated normal distribution
Example:
```python
import mlx.core as mx

x = mx.random.truncated_normal(lower=-2, upper=2, shape=(100,))
```
gumbel
mx.random.gumbel(
shape: tuple = (),
dtype: Dtype = float32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate samples from the standard Gumbel distribution, whose CDF is exp(-exp(-x)).
shape
tuple
default:"()"
Shape of the output array
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of samples from the Gumbel distribution
Example:
```python
import mlx.core as mx

x = mx.random.gumbel(shape=(100,))
```
Discrete Distributions
randint
mx.random.randint(
low: int | array,
high: int | array,
shape: tuple = (),
dtype: Dtype = int32,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate random integers uniformly from the range [low, high).
shape
tuple
default:"()"
Shape of the output array
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of random integers
Example:
```python
import mlx.core as mx

# Integers from 0 to 9 (high is exclusive)
x = mx.random.randint(0, 10, shape=(100,))
```
bernoulli
mx.random.bernoulli(
p: float | array = 0.5,
shape: tuple = None,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Generate binary random variables with probability p of being True.
p
float | array
default:"0.5"
Probability of generating True (1)
shape
tuple
default:"None"
Shape of the output array. If None, inferred from the shape of p.
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of boolean values
Example:
```python
import mlx.core as mx

# Coin flips
x = mx.random.bernoulli(p=0.5, shape=(100,))
# 30% chance of True
x = mx.random.bernoulli(p=0.3, shape=(100,))
```
categorical
mx.random.categorical(
logits: array,
axis: int = -1,
shape: tuple = None,
num_samples: int = None,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Sample from a categorical distribution.
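logits
array
Unnormalized log probabilities for each category
axis
int
default:"-1"
Axis of logits that indexes the categories
shape
tuple
default:"None"
Shape of the output array. If neither shape nor num_samples is given, the output has the shape of logits with axis removed.
num_samples
int
default:"None"
Number of samples to draw from each distribution, added as the last output dimension. Alternative to shape: at most one of the two may be specified.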
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Array of categorical samples (indices)
Example:
```python
import mlx.core as mx

# Sample from 3 categories with equal probability
logits = mx.array([0.0, 0.0, 0.0])
x = mx.random.categorical(logits, num_samples=100)
```
Permutations
permutation
mx.random.permutation(
x: array | int,
axis: int = 0,
key: array | None = None,
stream: StreamOrDevice = None
) -> array
Randomly permute elements.
x
array | int
If an array, permute its elements along the given axis. If an integer, equivalent to permutation(mx.arange(x)).
axis
int
default:"0"
Axis along which to permute. Only used if x is an array.
key
array | None
default:"None"
PRNG key. If None, uses the global state.
stream
StreamOrDevice
default:"None"
Stream or device to use
Returns: Permuted array
Example:
```python
import mlx.core as mx

# Random permutation of 0 to 9
x = mx.random.permutation(10)
# Shuffle array rows
arr = mx.array([[1, 2], [3, 4], [5, 6]])
shuffled = mx.random.permutation(arr, axis=0)
```