The array class is the fundamental data structure in MLX. It represents a multi-dimensional array with a specific data type and shape.

Array Construction

array

Construct an MLX array from Python values or iterables.
import mlx.core as mx

# Scalar array
a = mx.array(5)

# 1D array
b = mx.array([1, 2, 3, 4])

# 2D array
c = mx.array([[1, 2], [3, 4]])

# With explicit dtype
d = mx.array([1.0, 2.0, 3.0], dtype=mx.float32)
data
scalar | list | tuple | ndarray
The data to construct the array from. Can be a Python scalar, list, tuple, or NumPy array.
dtype
Dtype
default:"inferred"
The data type of the array. If not specified, the dtype is inferred from the input data.
array
array
The constructed MLX array.

Array Attributes

array.dtype

The data type of the array elements.
import mlx.core as mx

a = mx.array([1, 2, 3])
print(a.dtype)  # mlx.core.int32

b = mx.array([1.0, 2.0, 3.0])
print(b.dtype)  # mlx.core.float32
dtype
Dtype
The data type of the array.

array.shape

The dimensions of the array as a tuple.
import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6]])
print(a.shape)  # (2, 3)

b = mx.array(5)
print(b.shape)  # ()
shape
tuple[int, ...]
A tuple containing the size of each dimension.

array.ndim

The number of dimensions of the array.
import mlx.core as mx

a = mx.array([1, 2, 3])
print(a.ndim)  # 1

b = mx.array([[1, 2], [3, 4]])
print(b.ndim)  # 2
ndim
int
The number of dimensions.

array.size

The total number of elements in the array.
import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6]])
print(a.size)  # 6
size
int
The total number of elements.

array.itemsize

The size of one array element in bytes.
import mlx.core as mx

a = mx.array([1, 2, 3], dtype=mx.float32)
print(a.itemsize)  # 4

b = mx.array([1, 2, 3], dtype=mx.int8)
print(b.itemsize)  # 1
itemsize
int
The size in bytes of one element.

array.nbytes

The total number of bytes used by the array data.
import mlx.core as mx

a = mx.array([1, 2, 3, 4], dtype=mx.float32)
print(a.nbytes)  # 16 (4 elements * 4 bytes each)
nbytes
int
The total size in bytes.
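The relationship between shape, size, itemsize, and nbytes can be checked with plain arithmetic. The sketch below does not call MLX; `expected_nbytes` is a hypothetical helper written only for illustration, and it assumes a dense array with no padding.

```python
import math

def expected_nbytes(shape, itemsize):
    """Bytes for a dense array: element count times bytes per element."""
    size = math.prod(shape)  # the product of an empty shape is 1 (a scalar)
    return size * itemsize

print(expected_nbytes((4,), 4))    # 16, four 4-byte elements
print(expected_nbytes((2, 3), 1))  # 6, six 1-byte elements
print(expected_nbytes((), 4))      # 4, a single 4-byte scalar
```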

array.T

The transposed array, with the order of all dimensions reversed.
import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6]])
print(a.T)
# array([[1, 4],
#        [2, 5],
#        [3, 6]], dtype=int32)
T
array
The transposed array.

Array Methods

array.astype

Cast the array to a different data type.
import mlx.core as mx

a = mx.array([1, 2, 3], dtype=mx.int32)
b = a.astype(mx.float32)
print(b)  # array([1, 2, 3], dtype=float32)
dtype
Dtype
required
The target data type.
stream
Stream
The stream on which to schedule the operation.
result
array
A new array with the specified dtype.

array.item

Extract the scalar value from a single-element array.
import mlx.core as mx

a = mx.array(42)
print(a.item())  # 42

b = mx.array([5])
print(b.item())  # 5
value
scalar
The scalar value. Raises an error if the array has more than one element.

array.tolist

Convert the array to a Python list.
import mlx.core as mx

a = mx.array([1, 2, 3])
print(a.tolist())  # [1, 2, 3]

b = mx.array([[1, 2], [3, 4]])
print(b.tolist())  # [[1, 2], [3, 4]]
list
list
The array as a nested Python list.

array.reshape

Reshape the array to a new shape without changing its data.
import mlx.core as mx

a = mx.array([1, 2, 3, 4, 5, 6])
b = a.reshape(2, 3)
print(b)
# array([[1, 2, 3],
#        [4, 5, 6]], dtype=int32)

# Use -1 to infer a dimension
c = a.reshape(3, -1)
print(c.shape)  # (3, 2)
*shape
int
required
The new shape. One dimension can be -1, which will be inferred from the array size.
stream
Stream
The stream on which to schedule the operation.
result
array
The reshaped array.
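How a -1 entry is resolved can be sketched in plain Python. `infer_shape` is a hypothetical helper, not part of MLX, and the error conditions it raises are this sketch's own choices rather than documented MLX behavior.

```python
import math

def infer_shape(size, shape):
    """Resolve at most one -1 in `shape` so the element count equals `size`."""
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    # Product of the dimensions that were given explicitly.
    known = math.prod(d for d in shape if d != -1)
    if -1 in shape:
        if known == 0 or size % known != 0:
            raise ValueError(f"cannot reshape {size} elements into {shape}")
        return tuple(size // known if d == -1 else d for d in shape)
    if known != size:
        raise ValueError(f"cannot reshape {size} elements into {shape}")
    return tuple(shape)

print(infer_shape(6, (3, -1)))  # (3, 2)
print(infer_shape(6, (-1,)))    # (6,)
print(infer_shape(6, (2, 3)))   # (2, 3)
```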

array.flatten

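Return a flattened copy of the array. By default all axes are collapsed into a single dimension (1D); start_axis and end_axis restrict the flattening to a contiguous range of axes.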
import mlx.core as mx

a = mx.array([[1, 2], [3, 4], [5, 6]])
b = a.flatten()
print(b)  # array([1, 2, 3, 4, 5, 6], dtype=int32)
start_axis
int
default:"0"
The first axis to flatten.
end_axis
int
default:"-1"
The last axis to flatten.
stream
Stream
The stream on which to schedule the operation.
result
array
The flattened array.
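The effect of start_axis and end_axis on the result shape can be sketched without MLX. `flattened_shape` is a hypothetical helper; it assumes negative axes count from the end, that start_axis does not come after end_axis, and that the array has at least one dimension.

```python
import math

def flattened_shape(shape, start_axis=0, end_axis=-1):
    """Shape after merging axes start_axis..end_axis (inclusive) into one."""
    ndim = len(shape)
    start = start_axis % ndim  # normalize negative axes
    end = end_axis % ndim
    merged = math.prod(shape[start:end + 1])
    return shape[:start] + (merged,) + shape[end + 1:]

print(flattened_shape((3, 2)))                   # (6,)
print(flattened_shape((2, 3, 4), start_axis=1))  # (2, 12)
print(flattened_shape((2, 3, 4), end_axis=1))    # (6, 4)
```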

array.squeeze

Remove singleton dimensions from the array.
import mlx.core as mx

a = mx.array([[[1, 2, 3]]])
print(a.shape)  # (1, 1, 3)

b = a.squeeze()
print(b.shape)  # (3,)

# Squeeze specific axis
c = a.squeeze(0)
print(c.shape)  # (1, 3)
axis
int | tuple[int, ...]
Axis or axes to squeeze. If not specified, all singleton dimensions are removed.
stream
Stream
The stream on which to schedule the operation.
result
array
The squeezed array.

array.transpose

Permute the dimensions of the array.
import mlx.core as mx

# Use distinct axis lengths so the permutation is visible in the shape
a = mx.zeros((2, 3, 4))
print(a.shape)  # (2, 3, 4)

# Reverse all dimensions
b = a.transpose()
print(b.shape)  # (4, 3, 2)

# Specific permutation
c = a.transpose(2, 0, 1)
print(c.shape)  # (4, 2, 3)
*axes
int
The new order of dimensions. If not specified, reverses all dimensions.
stream
Stream
The stream on which to schedule the operation.
result
array
The transposed array.

array.swapaxes

Swap two axes of the array.
import mlx.core as mx

# Use distinct axis lengths so the swap is visible in the shape
a = mx.zeros((2, 3, 4))
print(a.shape)  # (2, 3, 4)

b = a.swapaxes(0, 2)
print(b.shape)  # (4, 3, 2)
axis1
int
required
The first axis to swap.
axis2
int
required
The second axis to swap.
stream
Stream
The stream on which to schedule the operation.
result
array
The array with swapped axes.

array.moveaxis

Move an axis to a new position.
import mlx.core as mx

# Use distinct axis lengths so the move is visible in the shape
a = mx.zeros((2, 3, 4))
print(a.shape)  # (2, 3, 4)

b = a.moveaxis(0, -1)
print(b.shape)  # (3, 4, 2)
source
int
required
The original position of the axis to move.
destination
int
required
The destination position for the axis.
stream
Stream
The stream on which to schedule the operation.
result
array
The array with moved axis.
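The shape bookkeeping behind transpose, swapaxes, and moveaxis can be traced with a plain-Python sketch on a (2, 3, 4) shape. The three helpers are hypothetical and only compute result shapes; they are not how MLX implements these methods.

```python
def transpose_shape(shape, axes=None):
    """Result shape of a permutation; no axes means reverse all dimensions."""
    if axes is None:
        axes = range(len(shape) - 1, -1, -1)
    return tuple(shape[ax] for ax in axes)

def swapaxes_shape(shape, axis1, axis2):
    """Exchange two axes and leave the others in place."""
    order = list(range(len(shape)))
    order[axis1], order[axis2] = order[axis2], order[axis1]
    return transpose_shape(shape, order)

def moveaxis_shape(shape, source, destination):
    """Remove one axis and reinsert it at a new position."""
    ndim = len(shape)
    src, dst = source % ndim, destination % ndim
    order = [ax for ax in range(ndim) if ax != src]
    order.insert(dst, src)
    return transpose_shape(shape, order)

shape = (2, 3, 4)
print(transpose_shape(shape))             # (4, 3, 2)
print(transpose_shape(shape, (2, 0, 1)))  # (4, 2, 3)
print(swapaxes_shape(shape, 0, 2))        # (4, 3, 2)
print(moveaxis_shape(shape, 0, -1))       # (3, 4, 2)
```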

Reduction Methods

array.sum

Sum of array elements along specified axes.
import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6]])
print(a.sum())  # array(21, dtype=int32)
print(a.sum(axis=0))  # array([5, 7, 9], dtype=int32)
print(a.sum(axis=1))  # array([6, 15], dtype=int32)
axis
int | tuple[int, ...]
The axis or axes along which to sum. If not specified, sums all elements.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
The sum along the specified axes.
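The axis and keepdims parameters shared by the reduction methods determine the result shape as sketched below. `reduced_shape` is a hypothetical plain-Python helper that only models the shape, not the reduction itself.

```python
def reduced_shape(shape, axis=None, keepdims=False):
    """Shape produced by reducing `shape` over `axis`."""
    ndim = len(shape)
    if axis is None:
        axes = set(range(ndim))            # reduce over every axis
    elif isinstance(axis, int):
        axes = {axis % ndim}
    else:
        axes = {ax % ndim for ax in axis}
    if keepdims:
        # Reduced axes stay in place with length 1.
        return tuple(1 if i in axes else d for i, d in enumerate(shape))
    return tuple(d for i, d in enumerate(shape) if i not in axes)

print(reduced_shape((2, 3)))                         # ()
print(reduced_shape((2, 3), axis=0))                 # (3,)
print(reduced_shape((2, 3), axis=1, keepdims=True))  # (2, 1)
```

Keeping the reduced axis with length 1 is what lets a result such as a row sum broadcast back against the original array, for example when normalizing each row.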

array.mean

Compute the mean of array elements along specified axes.
import mlx.core as mx

a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(a.mean())  # array(3.5, dtype=float32)
print(a.mean(axis=0))  # array([2.5, 3.5, 4.5], dtype=float32)
axis
int | tuple[int, ...]
The axis or axes along which to compute the mean.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
The mean along the specified axes.

array.var

Compute the variance of array elements along specified axes.
import mlx.core as mx

a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(a.var())  # array(2.9167, dtype=float32)
print(a.var(axis=0))  # array([2.25, 2.25, 2.25], dtype=float32)
axis
int | tuple[int, ...]
The axis or axes along which to compute the variance.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
ddof
int
default:"0"
Delta degrees of freedom for the divisor (N - ddof).
stream
Stream
The stream on which to schedule the operation.
result
array
The variance along the specified axes.
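The role of ddof in the divisor N - ddof can be worked through in plain Python. `variance` is a hypothetical reference function for the formula, not MLX's implementation.

```python
def variance(values, ddof=0):
    """Sum of squared deviations from the mean, divided by N - ddof."""
    n = len(values)
    mean = sum(values) / n
    squared_deviations = sum((v - mean) ** 2 for v in values)
    return squared_deviations / (n - ddof)

data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(round(variance(data), 4))          # 2.9167 (divisor 6)
print(round(variance(data, ddof=1), 4))  # 3.5 (divisor 5)
print(round(variance(data) ** 0.5, 4))   # 1.7078, the standard deviation
```

With ddof=0 this gives the population variance; ddof=1 gives the usual unbiased sample estimate.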

array.std

Compute the standard deviation of array elements along specified axes.
import mlx.core as mx

a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(a.std())  # array(1.7078, dtype=float32)
print(a.std(axis=0))  # array([1.5, 1.5, 1.5], dtype=float32)
axis
int | tuple[int, ...]
The axis or axes along which to compute the standard deviation.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
ddof
int
default:"0"
Delta degrees of freedom for the divisor (N - ddof).
stream
Stream
The stream on which to schedule the operation.
result
array
The standard deviation along the specified axes.

array.max

Return the maximum value along specified axes.
import mlx.core as mx

a = mx.array([[1, 5, 3], [4, 2, 6]])
print(a.max())  # array(6, dtype=int32)
print(a.max(axis=0))  # array([4, 5, 6], dtype=int32)
print(a.max(axis=1))  # array([5, 6], dtype=int32)
axis
int | tuple[int, ...]
The axis or axes along which to find the maximum.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
The maximum values along the specified axes.

array.min

Return the minimum value along specified axes.
import mlx.core as mx

a = mx.array([[1, 5, 3], [4, 2, 6]])
print(a.min())  # array(1, dtype=int32)
print(a.min(axis=0))  # array([1, 2, 3], dtype=int32)
axis
int | tuple[int, ...]
The axis or axes along which to find the minimum.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
The minimum values along the specified axes.

array.prod

Return the product of array elements along specified axes.
import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6]])
print(a.prod())  # array(720, dtype=int32)
print(a.prod(axis=0))  # array([4, 10, 18], dtype=int32)
axis
int | tuple[int, ...]
The axis or axes along which to compute the product.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
The product along the specified axes.

array.all

Test whether all array elements along specified axes evaluate to True.
import mlx.core as mx

a = mx.array([True, True, True])
print(a.all())  # array(True, dtype=bool)

b = mx.array([True, False, True])
print(b.all())  # array(False, dtype=bool)
axis
int | tuple[int, ...]
The axis or axes along which to test.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
Boolean array indicating if all elements are True.

array.any

Test whether any array element along specified axes evaluates to True.
import mlx.core as mx

a = mx.array([False, True, False])
print(a.any())  # array(True, dtype=bool)

b = mx.array([False, False, False])
print(b.any())  # array(False, dtype=bool)
axis
int | tuple[int, ...]
The axis or axes along which to test.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
Boolean array indicating if any elements are True.

array.argmax

Return the indices of the maximum values along specified axes.
import mlx.core as mx

a = mx.array([[1, 5, 3], [4, 2, 6]])
print(a.argmax())  # array(5, dtype=uint32)
print(a.argmax(axis=0))  # array([1, 0, 1], dtype=uint32)
print(a.argmax(axis=1))  # array([1, 2], dtype=uint32)
axis
int
The axis along which to find the indices.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
The indices of the maximum values.

array.argmin

Return the indices of the minimum values along specified axes.
import mlx.core as mx

a = mx.array([[1, 5, 3], [4, 2, 6]])
print(a.argmin())  # array(0, dtype=uint32)
print(a.argmin(axis=0))  # array([0, 1, 0], dtype=uint32)
axis
int
The axis along which to find the indices.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
The indices of the minimum values.
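When no axis is given, argmax and argmin return a single index into the flattened array, as in the examples above (5 for the value 6 at row 1, column 2). A plain-Python sketch of converting such a flat index back to coordinates follows; `unravel` is a hypothetical helper and assumes row-major flattening, which is consistent with those examples.

```python
def unravel(flat_index, shape):
    """Convert a row-major flat index into per-axis coordinates."""
    coords = []
    for dim in reversed(shape):  # the last axis varies fastest
        flat_index, position = divmod(flat_index, dim)
        coords.append(position)
    return tuple(reversed(coords))

print(unravel(5, (2, 3)))  # (1, 2)
print(unravel(0, (2, 3)))  # (0, 0)
```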

Element-wise Mathematical Methods

array.abs

Compute the absolute value element-wise.
import mlx.core as mx

a = mx.array([-1, -2, 3, -4])
print(a.abs())  # array([1, 2, 3, 4], dtype=int32)
stream
Stream
The stream on which to schedule the operation.
result
array
The absolute values.

array.sqrt

Compute the square root element-wise.
import mlx.core as mx

a = mx.array([1.0, 4.0, 9.0, 16.0])
print(a.sqrt())  # array([1, 2, 3, 4], dtype=float32)
stream
Stream
The stream on which to schedule the operation.
result
array
The square roots.

array.rsqrt

Compute the reciprocal square root element-wise (1/sqrt(x)).
import mlx.core as mx

a = mx.array([1.0, 4.0, 9.0])
print(a.rsqrt())  # array([1, 0.5, 0.333...], dtype=float32)
stream
Stream
The stream on which to schedule the operation.
result
array
The reciprocal square roots.

array.square

Compute the square element-wise.
import mlx.core as mx

a = mx.array([1, 2, 3, 4])
print(a.square())  # array([1, 4, 9, 16], dtype=int32)
stream
Stream
The stream on which to schedule the operation.
result
array
The squares.

array.exp

Compute the exponential (e^x) element-wise.
import mlx.core as mx

a = mx.array([0.0, 1.0, 2.0])
print(a.exp())  # array([1, 2.71828, 7.38906], dtype=float32)
stream
Stream
The stream on which to schedule the operation.
result
array
The exponentials.

array.log

Compute the natural logarithm element-wise.
import mlx.core as mx

a = mx.array([1.0, 2.71828, 7.38906])
print(a.log())  # approximately array([0, 1, 2], dtype=float32), since the inputs are rounded
stream
Stream
The stream on which to schedule the operation.
result
array
The natural logarithms.

array.log2

Compute the base-2 logarithm element-wise.
import mlx.core as mx

a = mx.array([1.0, 2.0, 4.0, 8.0])
print(a.log2())  # array([0, 1, 2, 3], dtype=float32)
stream
Stream
The stream on which to schedule the operation.
result
array
The base-2 logarithms.

array.log10

Compute the base-10 logarithm element-wise.
import mlx.core as mx

a = mx.array([1.0, 10.0, 100.0])
print(a.log10())  # array([0, 1, 2], dtype=float32)
stream
Stream
The stream on which to schedule the operation.
result
array
The base-10 logarithms.

array.log1p

Compute log(1 + x) element-wise.
import mlx.core as mx

a = mx.array([0.0, 1.0, 2.0])
print(a.log1p())  # array([0, 0.693147, 1.09861], dtype=float32)
stream
Stream
The stream on which to schedule the operation.
result
array
The log(1 + x) values.

array.sin

Compute the sine element-wise.
import mlx.core as mx
import math

a = mx.array([0.0, math.pi/2, math.pi])
print(a.sin())  # approximately array([0, 1, 0], dtype=float32); the last entry may print as a tiny nonzero value
stream
Stream
The stream on which to schedule the operation.
result
array
The sines.

array.cos

Compute the cosine element-wise.
import mlx.core as mx
import math

a = mx.array([0.0, math.pi/2, math.pi])
print(a.cos())  # approximately array([1, 0, -1], dtype=float32); the middle entry may print as a tiny nonzero value
stream
Stream
The stream on which to schedule the operation.
result
array
The cosines.

array.round

Round element-wise to the given number of decimal places (the nearest integer by default).
import mlx.core as mx

a = mx.array([1.2, 2.5, 3.7, 4.5])
print(a.round())  # array([1, 2, 4, 4], dtype=float32)
decimals
int
default:"0"
Number of decimal places to round to.
stream
Stream
The stream on which to schedule the operation.
result
array
The rounded values.
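The example output above (2.5 becomes 2 and 4.5 becomes 4) is consistent with round-half-to-even, where ties go to the nearest even integer. Python's built-in `round` follows the same rule, so it can serve as a quick reference. That MLX guarantees this tie-breaking on every backend is an assumption here, not something this page states.

```python
values = [1.2, 2.5, 3.7, 4.5]

# Ties (x.5) go to the nearest even integer: 2.5 -> 2, 4.5 -> 4.
print([round(v) for v in values])  # [1, 2, 4, 4]

# Odd neighbors round up on a tie: 1.5 -> 2, 3.5 -> 4.
print([round(v) for v in [1.5, 3.5]])  # [2, 4]
```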

array.reciprocal

Compute the reciprocal (1/x) element-wise.
import mlx.core as mx

a = mx.array([1.0, 2.0, 4.0])
print(a.reciprocal())  # array([1, 0.5, 0.25], dtype=float32)
stream
Stream
The stream on which to schedule the operation.
result
array
The reciprocals.

Cumulative Methods

array.cumsum

Compute the cumulative sum along an axis.
import mlx.core as mx

a = mx.array([1, 2, 3, 4])
print(a.cumsum())  # array([1, 3, 6, 10], dtype=int32)

b = mx.array([[1, 2], [3, 4]])
print(b.cumsum(axis=0))
# array([[1, 2],
#        [4, 6]], dtype=int32)
axis
int
default:"None"
The axis along which to compute the cumulative sum. If not specified, the cumulative sum is computed over the flattened array.
reverse
bool
default:"false"
If True, compute the cumulative sum in reverse.
inclusive
bool
default:"true"
If True, include the current element in the sum.
stream
Stream
The stream on which to schedule the operation.
result
array
The cumulative sum.
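The reverse and inclusive flags, shared by the cumulative methods, can be modeled on a plain Python list. This `cumsum` is a hypothetical reference function; it assumes that inclusive=False shifts the scan by one position and starts from 0, and that reverse=True scans from the last element toward the first.

```python
from itertools import accumulate

def cumsum(values, reverse=False, inclusive=True):
    """1D cumulative sum with optional reversal and exclusive mode."""
    seq = list(reversed(values)) if reverse else list(values)
    running = list(accumulate(seq))
    if not inclusive and running:
        # Exclusive scan: each position sums only the elements before it.
        running = [0] + running[:-1]
    return list(reversed(running)) if reverse else running

data = [1, 2, 3, 4]
print(cumsum(data))                   # [1, 3, 6, 10]
print(cumsum(data, reverse=True))     # [10, 9, 7, 4]
print(cumsum(data, inclusive=False))  # [0, 1, 3, 6]
```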

array.cumprod

Compute the cumulative product along an axis.
import mlx.core as mx

a = mx.array([1, 2, 3, 4])
print(a.cumprod())  # array([1, 2, 6, 24], dtype=int32)
axis
int
default:"None"
The axis along which to compute the cumulative product. If not specified, the cumulative product is computed over the flattened array.
reverse
bool
default:"false"
If True, compute the cumulative product in reverse.
inclusive
bool
default:"true"
If True, include the current element in the product.
stream
Stream
The stream on which to schedule the operation.
result
array
The cumulative product.

array.cummax

Compute the cumulative maximum along an axis.
import mlx.core as mx

a = mx.array([1, 3, 2, 5, 4])
print(a.cummax())  # array([1, 3, 3, 5, 5], dtype=int32)
axis
int
default:"None"
The axis along which to compute the cumulative maximum. If not specified, the cumulative maximum is computed over the flattened array.
reverse
bool
default:"false"
If True, compute the cumulative maximum in reverse.
inclusive
bool
default:"true"
If True, include the current element in the maximum.
stream
Stream
The stream on which to schedule the operation.
result
array
The cumulative maximum.

array.cummin

Compute the cumulative minimum along an axis.
import mlx.core as mx

a = mx.array([5, 3, 4, 1, 2])
print(a.cummin())  # array([5, 3, 3, 1, 1], dtype=int32)
axis
int
default:"None"
The axis along which to compute the cumulative minimum. If not specified, the cumulative minimum is computed over the flattened array.
reverse
bool
default:"false"
If True, compute the cumulative minimum in reverse.
inclusive
bool
default:"true"
If True, include the current element in the minimum.
stream
Stream
The stream on which to schedule the operation.
result
array
The cumulative minimum.

array.logsumexp

Compute the log-sum-exp along specified axes.
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
print(a.logsumexp())  # array(3.40761, dtype=float32)
axis
int | tuple[int, ...]
The axis or axes along which to compute the log-sum-exp.
keepdims
bool
default:"false"
If True, the reduced axes are kept as dimensions with size 1.
stream
Stream
The stream on which to schedule the operation.
result
array
The log-sum-exp along the specified axes.
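Log-sum-exp is log(sum(exp(x))), usually computed by subtracting the maximum first so the exponentials cannot overflow. The plain-Python sketch below shows that standard trick; it is a reference for the formula and makes no claim about how MLX implements the method.

```python
import math

def logsumexp(values):
    """log(sum(exp(v))) computed stably by factoring out the maximum."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

print(round(logsumexp([1.0, 2.0, 3.0]), 5))  # 3.40761

# math.exp(1000.0) overflows, but the shifted form stays finite.
print(logsumexp([1000.0, 1000.0]))  # 1000 + log(2), about 1000.693
```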

array.logcumsumexp

Compute the cumulative log-sum-exp along an axis.
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
print(a.logcumsumexp())
# array([1, 2.31326, 3.40761], dtype=float32)
axis
int
default:"None"
The axis along which to compute the cumulative log-sum-exp. If not specified, it is computed over the flattened array.
reverse
bool
default:"false"
If True, compute in reverse order.
inclusive
bool
default:"true"
If True, include the current element.
stream
Stream
The stream on which to schedule the operation.
result
array
The cumulative log-sum-exp.
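The cumulative variant is a running log-sum-exp: each output is the log of the sum of exponentials of all elements up to that position. A plain-Python sketch of the inclusive, forward case follows; `logaddexp` and `logcumsumexp` here are hypothetical reference functions, not MLX calls.

```python
import math

def logaddexp(a, b):
    """log(exp(a) + exp(b)) without overflowing on large inputs."""
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))

def logcumsumexp(values):
    """Inclusive running log-sum-exp over a 1D sequence."""
    out, running = [], None
    for v in values:
        running = v if running is None else logaddexp(running, v)
        out.append(running)
    return out

print([round(v, 5) for v in logcumsumexp([1.0, 2.0, 3.0])])
# [1.0, 2.31326, 3.40761]
```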

Complex Number Methods

array.real

Return the real part of a complex array.
import mlx.core as mx

a = mx.array([1+2j, 3+4j])
print(a.real)  # array([1, 3], dtype=float32)
result
array
The real part of the array.

array.imag

Return the imaginary part of a complex array.
import mlx.core as mx

a = mx.array([1+2j, 3+4j])
print(a.imag)  # array([2, 4], dtype=float32)
result
array
The imaginary part of the array.

array.conj

Return the complex conjugate element-wise.
import mlx.core as mx

a = mx.array([1+2j, 3+4j])
print(a.conj())  # array([1-2j, 3-4j], dtype=complex64)
stream
Stream
The stream on which to schedule the operation.
result
array
The complex conjugate.

Other Methods

array.split

Split an array into sub-arrays along an axis.
import mlx.core as mx

a = mx.array([1, 2, 3, 4, 5, 6])
result = a.split(3)  # Split into 3 equal parts
for sub in result:
    print(sub)
# array([1, 2], dtype=int32)
# array([3, 4], dtype=int32)
# array([5, 6], dtype=int32)
indices_or_sections
int | list[int]
required
If an integer, the number of equal sections to split into. If a list, the indices at which to split.
axis
int
default:"0"
The axis along which to split.
stream
Stream
The stream on which to schedule the operation.
result
list[array]
List of sub-arrays.
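The two forms of the first argument (a count of equal parts, or a list of split indices) can be modeled on a plain Python list. `split` and its `spec` parameter are hypothetical names for this sketch, and raising on an uneven division is the sketch's own choice.

```python
def split(values, spec):
    """Split a list into equal parts (int) or at given indices (sequence)."""
    if isinstance(spec, int):
        if len(values) % spec != 0:
            raise ValueError("length is not divisible by the number of parts")
        step = len(values) // spec
        bounds = [step * i for i in range(1, spec)]
    else:
        bounds = list(spec)
    # Pair consecutive boundaries to slice out each piece.
    edges = [0] + bounds + [len(values)]
    return [values[lo:hi] for lo, hi in zip(edges, edges[1:])]

data = [1, 2, 3, 4, 5, 6]
print(split(data, 3))       # [[1, 2], [3, 4], [5, 6]]
print(split(data, [1, 4]))  # [[1], [2, 3, 4], [5, 6]]
```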

array.diagonal

Return specified diagonal from the array.
import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(a.diagonal())  # array([1, 5, 9], dtype=int32)
print(a.diagonal(offset=1))  # array([2, 6], dtype=int32)
offset
int
default:"0"
Offset of the diagonal from the main diagonal.
axis1
int
default:"0"
First axis.
axis2
int
default:"1"
Second axis.
stream
Stream
The stream on which to schedule the operation.
result
array
The diagonal elements.
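The offset parameter selects which diagonal is read: positive values move above the main diagonal and negative values below it. A plain-Python sketch for the 2D case follows; `diagonal` here is a hypothetical helper operating on nested lists, not the MLX method.

```python
def diagonal(matrix, offset=0):
    """Elements matrix[i][i + offset] that fall inside the matrix."""
    cols = len(matrix[0])
    return [row[i + offset] for i, row in enumerate(matrix)
            if 0 <= i + offset < cols]

m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(diagonal(m))             # [1, 5, 9]
print(diagonal(m, offset=1))   # [2, 6]
print(diagonal(m, offset=-1))  # [4, 8]
```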

array.diag

Extract diagonal or construct diagonal array.
import mlx.core as mx

# Extract diagonal from 2D array
a = mx.array([[1, 2], [3, 4]])
print(a.diag())  # array([1, 4], dtype=int32)

# Create diagonal matrix from 1D array
b = mx.array([1, 2, 3])
print(b.diag())
# array([[1, 0, 0],
#        [0, 2, 0],
#        [0, 0, 3]], dtype=int32)
k
int
default:"0"
The diagonal offset.
stream
Stream
The stream on which to schedule the operation.
result
array
Extracted diagonal or diagonal matrix.

array.view

Create a view of the array with a different data type.
import mlx.core as mx

a = mx.array([1, 2, 3, 4], dtype=mx.int32)
b = a.view(mx.float32)
print(b)  # Reinterpret bytes as float32
dtype
Dtype
required
The target data type for the view.
stream
Stream
The stream on which to schedule the operation.
result
array
A view of the array with the specified dtype.
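Reinterpreting bytes differs from casting values: a view keeps the bit pattern and changes how it is read, while astype converts the number. The standard-library sketch below illustrates the difference for a single 32-bit value; `reinterpret_int32_as_float32` is a hypothetical helper and assumes IEEE 754 single precision.

```python
import struct

def reinterpret_int32_as_float32(value):
    """Read the 4 bytes of a 32-bit integer as an IEEE 754 float32."""
    return struct.unpack("<f", struct.pack("<i", value))[0]

# 0x3F800000 is the bit pattern of the float32 value 1.0.
print(reinterpret_int32_as_float32(0x3F800000))  # 1.0

# The integer 1 reinterpreted is a tiny subnormal, not 1.0.
print(reinterpret_int32_as_float32(1) < 1e-44)   # True

# A value cast, by contrast, preserves the number.
print(float(1))  # 1.0
```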

array.at

Indexing interface for updates at selected positions. Each update returns a new array and leaves the original unchanged.
import mlx.core as mx

a = mx.array([1, 2, 3, 4, 5])
# Update specific indices
b = a.at[mx.array([0, 2, 4])].add(10)
print(b)  # array([11, 2, 13, 4, 15], dtype=int32)
indexer
ArrayIndexer
An indexing interface that supports the methods add, subtract, multiply, divide, maximum, and minimum. Each method returns a new array with the update applied at the selected indices.
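One reason to use this interface is the handling of repeated indices. The plain-Python sketch below models two behaviors: accumulating every occurrence of an index, and applying the update once per distinct index. It assumes that the at-style add accumulates duplicates while an ordinary indexed `+=` does not. Both function names are hypothetical and neither is an MLX call.

```python
def add_at(values, indices, amount):
    """Accumulating update: every occurrence of an index contributes."""
    out = list(values)  # copy, so the input list is left unchanged
    for i in indices:
        out[i] += amount
    return out

def add_once(values, indices, amount):
    """Non-accumulating update: a repeated index is applied only once."""
    out = list(values)
    for i in set(indices):
        out[i] = values[i] + amount
    return out

print(add_at([1, 2, 3, 4, 5], [0, 2, 4], 10))  # [11, 2, 13, 4, 15]
print(add_at([0, 0], [0, 1, 0, 1], 1))         # [2, 2]
print(add_once([0, 0], [0, 1, 0, 1], 1))       # [1, 1]
```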