Skip to main content
MLX provides a comprehensive set of operations for array manipulation, mathematical computations, and linear algebra. Most operations support broadcasting and can be executed on different devices.

Array Creation

arange

Create a 1D array with evenly spaced values.
import mlx.core as mx

# Range from 0 to 10 (exclusive)
a = mx.arange(10)
print(a)  # array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

# Range with start and stop
b = mx.arange(5, 10)
print(b)  # array([5, 6, 7, 8, 9], dtype=int32)

# Range with step
c = mx.arange(0, 10, 2)
print(c)  # array([0, 2, 4, 6, 8], dtype=int32)

# Floating point range
d = mx.arange(0.0, 1.0, 0.25)
print(d)  # array([0, 0.25, 0.5, 0.75], dtype=float32)
start
int | float
default:"0"
Start of the range.
stop
int | float
required
End of the range (exclusive).
step
int | float
default:"1"
Spacing between values.
dtype
Dtype
The data type of the output array. If not specified, inferred from the inputs.
stream
Stream
The stream on which to schedule the operation.
result
array
1D array of evenly spaced values.

linspace

Create a 1D array with a specified number of evenly spaced values.
import mlx.core as mx

# 5 values from 0 to 1 (inclusive)
a = mx.linspace(0, 1, 5)
print(a)  # array([0, 0.25, 0.5, 0.75, 1], dtype=float32)

# 10 values from -1 to 1
b = mx.linspace(-1, 1, 10)
print(b)
start
float
required
Start of the range.
stop
float
required
End of the range (inclusive).
num
int
default:"50"
Number of values to generate.
dtype
Dtype
default:"float32"
The data type of the output array.
stream
Stream
The stream on which to schedule the operation.
result
array
1D array of evenly spaced values.

zeros

Create an array filled with zeros.
import mlx.core as mx

a = mx.zeros(5)
print(a)  # array([0, 0, 0, 0, 0], dtype=float32)

b = mx.zeros((2, 3))
print(b)
# array([[0, 0, 0],
#        [0, 0, 0]], dtype=float32)

c = mx.zeros((2, 3), dtype=mx.int32)
print(c.dtype)  # mlx.core.int32
shape
int | tuple[int, ...]
required
Shape of the array.
dtype
Dtype
default:"float32"
The data type of the output array.
stream
Stream
The stream on which to schedule the operation.
result
array
Array filled with zeros.

ones

Create an array filled with ones.
import mlx.core as mx

a = mx.ones(5)
print(a)  # array([1, 1, 1, 1, 1], dtype=float32)

b = mx.ones((2, 3), dtype=mx.int32)
print(b)
# array([[1, 1, 1],
#        [1, 1, 1]], dtype=int32)
shape
int | tuple[int, ...]
required
Shape of the array.
dtype
Dtype
default:"float32"
The data type of the output array.
stream
Stream
The stream on which to schedule the operation.
result
array
Array filled with ones.

full

Create an array filled with a specified value.
import mlx.core as mx

a = mx.full((2, 3), 7)
print(a)
# array([[7, 7, 7],
#        [7, 7, 7]], dtype=int32)

b = mx.full((3,), 3.14, dtype=mx.float32)
print(b)  # array([3.14, 3.14, 3.14], dtype=float32)
shape
int | tuple[int, ...]
required
Shape of the array.
vals
scalar | array
required
The fill value. If an array, it must be broadcastable to shape.
dtype
Dtype
The data type of the output array. If not specified, inferred from vals.
stream
Stream
The stream on which to schedule the operation.
result
array
Array filled with the specified value.

eye

Create a 2D identity matrix.
import mlx.core as mx

# 3x3 identity matrix
a = mx.eye(3)
print(a)
# array([[1, 0, 0],
#        [0, 1, 0],
#        [0, 0, 1]], dtype=float32)

# 3x4 matrix with ones on diagonal
b = mx.eye(3, 4)
print(b)
# array([[1, 0, 0, 0],
#        [0, 1, 0, 0],
#        [0, 0, 1, 0]], dtype=float32)

# With offset diagonal
c = mx.eye(3, 3, k=1)
print(c)
# array([[0, 1, 0],
#        [0, 0, 1],
#        [0, 0, 0]], dtype=float32)
n
int
required
Number of rows.
m
int
Number of columns. If not specified, defaults to n.
k
int
default:"0"
Index of the diagonal: 0 refers to the main diagonal, positive values to diagonals above it, and negative values to diagonals below it.
dtype
Dtype
default:"float32"
The data type of the output array.
stream
Stream
The stream on which to schedule the operation.
result
array
2D array with ones on the specified diagonal.

identity

Create a square identity matrix.
import mlx.core as mx

a = mx.identity(4)
print(a)
# array([[1, 0, 0, 0],
#        [0, 1, 0, 0],
#        [0, 0, 1, 0],
#        [0, 0, 0, 1]], dtype=float32)
n
int
required
Size of the square matrix.
dtype
Dtype
default:"float32"
The data type of the output array.
stream
Stream
The stream on which to schedule the operation.
result
array
Square identity matrix of size n x n.

Shape Manipulation

reshape

Reshape an array without changing its data.
import mlx.core as mx

a = mx.arange(12)
print(a.shape)  # (12,)

b = mx.reshape(a, (3, 4))
print(b.shape)  # (3, 4)
print(b)
# array([[0, 1, 2, 3],
#        [4, 5, 6, 7],
#        [8, 9, 10, 11]], dtype=int32)

# Use -1 to infer dimension
c = mx.reshape(a, (2, -1))
print(c.shape)  # (2, 6)
a
array
required
Input array.
shape
tuple[int, ...]
required
New shape. One dimension can be -1, which will be inferred.
stream
Stream
The stream on which to schedule the operation.
result
array
Reshaped array.

flatten

Flatten an array to 1D.
import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6]])
b = mx.flatten(a)
print(b)  # array([1, 2, 3, 4, 5, 6], dtype=int32)

# Flatten specific axes
c = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(c.shape)  # (2, 2, 2)
d = mx.flatten(c, start_axis=1)
print(d.shape)  # (2, 4)
a
array
required
Input array.
start_axis
int
default:"0"
First axis to flatten.
end_axis
int
default:"-1"
Last axis to flatten (inclusive).
stream
Stream
The stream on which to schedule the operation.
result
array
Flattened array.

squeeze

Remove singleton dimensions.
import mlx.core as mx

a = mx.array([[[1, 2, 3]]])
print(a.shape)  # (1, 1, 3)

b = mx.squeeze(a)
print(b.shape)  # (3,)

# Squeeze specific axis
c = mx.squeeze(a, axis=0)
print(c.shape)  # (1, 3)
a
array
required
Input array.
axis
int | tuple[int, ...]
Axis or axes to squeeze. If not specified, all singleton dimensions are removed.
stream
Stream
The stream on which to schedule the operation.
result
array
Squeezed array.

expand_dims

Add singleton dimensions to an array.
import mlx.core as mx

a = mx.array([1, 2, 3])
print(a.shape)  # (3,)

b = mx.expand_dims(a, axis=0)
print(b.shape)  # (1, 3)

c = mx.expand_dims(a, axis=1)
print(c.shape)  # (3, 1)

# Add multiple dimensions
d = mx.expand_dims(a, axis=(0, 2))
print(d.shape)  # (1, 3, 1)
a
array
required
Input array.
axis
int | tuple[int, ...]
required
Position(s) in the output array at which to insert the new singleton axis or axes.
stream
Stream
The stream on which to schedule the operation.
result
array
Array with added dimensions.

transpose

Permute the dimensions of an array.
import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6]])
print(a.shape)  # (2, 3)

# Reverse all dimensions
b = mx.transpose(a)
print(b.shape)  # (3, 2)
print(b)
# array([[1, 4],
#        [2, 5],
#        [3, 6]], dtype=int32)

# Specific permutation for 3D
c = mx.zeros((2, 3, 4))
d = mx.transpose(c, (2, 0, 1))
print(d.shape)  # (4, 2, 3)
a
array
required
Input array.
axes
tuple[int, ...]
Permutation of axes. If not specified, reverses all dimensions.
stream
Stream
The stream on which to schedule the operation.
result
array
Transposed array.

swapaxes

Swap two axes of an array.
import mlx.core as mx

a = mx.zeros((2, 3, 4))
print(a.shape)  # (2, 3, 4)

b = mx.swapaxes(a, 0, 2)
print(b.shape)  # (4, 3, 2)
a
array
required
Input array.
axis1
int
required
First axis.
axis2
int
required
Second axis.
stream
Stream
The stream on which to schedule the operation.
result
array
Array with swapped axes.

moveaxis

Move an axis to a new position.
import mlx.core as mx

a = mx.zeros((2, 3, 4, 5))
print(a.shape)  # (2, 3, 4, 5)

b = mx.moveaxis(a, 0, -1)
print(b.shape)  # (3, 4, 5, 2)

c = mx.moveaxis(a, 2, 0)
print(c.shape)  # (4, 2, 3, 5)
a
array
required
Input array.
source
int
required
Original position of the axis.
destination
int
required
Destination position for the axis.
stream
Stream
The stream on which to schedule the operation.
result
array
Array with moved axis.

Arithmetic Operations

add

Element-wise addition.
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
c = mx.add(a, b)
print(c)  # array([5, 7, 9], dtype=int32)

# Operator overload
d = a + b
print(d)  # array([5, 7, 9], dtype=int32)

# Broadcasting
e = mx.add(a, 10)
print(e)  # array([11, 12, 13], dtype=int32)
a
array
required
First input array.
b
scalar | array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Element-wise sum.

subtract

Element-wise subtraction.
import mlx.core as mx

a = mx.array([5, 7, 9])
b = mx.array([1, 2, 3])
c = mx.subtract(a, b)
print(c)  # array([4, 5, 6], dtype=int32)

# Operator overload
d = a - b
print(d)  # array([4, 5, 6], dtype=int32)
a
array
required
First input array.
b
array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Element-wise difference.

multiply

Element-wise multiplication.
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
c = mx.multiply(a, b)
print(c)  # array([4, 10, 18], dtype=int32)

# Operator overload
d = a * b
print(d)  # array([4, 10, 18], dtype=int32)
a
array
required
First input array.
b
array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Element-wise product.

divide

Element-wise division.
import mlx.core as mx

a = mx.array([10, 20, 30], dtype=mx.float32)
b = mx.array([2, 4, 5], dtype=mx.float32)
c = mx.divide(a, b)
print(c)  # array([5, 5, 6], dtype=float32)

# Operator overload
d = a / b
print(d)  # array([5, 5, 6], dtype=float32)
a
array
required
First input array (numerator).
b
array
required
Second input array (denominator).
stream
Stream
The stream on which to schedule the operation.
result
array
Element-wise quotient.

floor_divide

Element-wise floor division: the quotient rounded toward negative infinity.
import mlx.core as mx

a = mx.array([10, 21, 32])
b = mx.array([3, 4, 5])
c = mx.floor_divide(a, b)
print(c)  # array([3, 5, 6], dtype=int32)
a
array
required
First input array.
b
array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Element-wise floor quotient.

remainder

Element-wise remainder of division.
import mlx.core as mx

a = mx.array([10, 21, 32])
b = mx.array([3, 4, 5])
c = mx.remainder(a, b)
print(c)  # array([1, 1, 2], dtype=int32)

# Operator overload
d = a % b
print(d)  # array([1, 1, 2], dtype=int32)
a
array
required
First input array.
b
array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Element-wise remainder.

power

Element-wise exponentiation.
import mlx.core as mx

a = mx.array([2, 3, 4])
b = mx.array([2, 2, 2])
c = mx.power(a, b)
print(c)  # array([4, 9, 16], dtype=int32)

# Broadcasting
d = mx.power(a, 2)
print(d)  # array([4, 9, 16], dtype=int32)
a
array
required
Base array.
b
scalar | array
required
Exponent array.
stream
Stream
The stream on which to schedule the operation.
result
array
Element-wise power.

Mathematical Functions

abs

Element-wise absolute value.
import mlx.core as mx

a = mx.array([-1, -2, 3, -4])
b = mx.abs(a)
print(b)  # array([1, 2, 3, 4], dtype=int32)
a
array
required
Input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Absolute values.

sqrt

Element-wise square root.
import mlx.core as mx

a = mx.array([1.0, 4.0, 9.0, 16.0])
b = mx.sqrt(a)
print(b)  # array([1, 2, 3, 4], dtype=float32)
a
array
required
Input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Square roots.

rsqrt

Element-wise reciprocal square root (1/sqrt(x)).
import mlx.core as mx

a = mx.array([1.0, 4.0, 9.0])
b = mx.rsqrt(a)
print(b)  # array([1, 0.5, 0.333...], dtype=float32)
a
array
required
Input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Reciprocal square roots.

exp

Element-wise exponential (e^x).
import mlx.core as mx

a = mx.array([0.0, 1.0, 2.0])
b = mx.exp(a)
print(b)  # array([1, 2.71828, 7.38906], dtype=float32)
a
array
required
Input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Exponentials.

log

Element-wise natural logarithm.
import mlx.core as mx

a = mx.array([1.0, 2.71828, 7.38906])
b = mx.log(a)
print(b)  # array([0, 1, 2], dtype=float32)
a
array
required
Input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Natural logarithms.

log2

Element-wise base-2 logarithm.
import mlx.core as mx

a = mx.array([1.0, 2.0, 4.0, 8.0])
b = mx.log2(a)
print(b)  # array([0, 1, 2, 3], dtype=float32)
a
array
required
Input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Base-2 logarithms.

log10

Element-wise base-10 logarithm.
import mlx.core as mx

a = mx.array([1.0, 10.0, 100.0, 1000.0])
b = mx.log10(a)
print(b)  # array([0, 1, 2, 3], dtype=float32)
a
array
required
Input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Base-10 logarithms.

Trigonometric Functions

sin

Element-wise sine.
import mlx.core as mx
import math

a = mx.array([0.0, math.pi/2, math.pi])
b = mx.sin(a)
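print(b)  # approximately array([0, 1, 0], dtype=float32); the last entry is a tiny nonzero value (on the order of 1e-7) because pi is rounded to float32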
a
array
required
Input array in radians.
stream
Stream
The stream on which to schedule the operation.
result
array
Sines of input values.

cos

Element-wise cosine.
import mlx.core as mx
import math

a = mx.array([0.0, math.pi/2, math.pi])
b = mx.cos(a)
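print(b)  # approximately array([1, 0, -1], dtype=float32); the middle entry is a tiny nonzero value (on the order of 1e-8) because pi/2 is rounded to float32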
a
array
required
Input array in radians.
stream
Stream
The stream on which to schedule the operation.
result
array
Cosines of input values.

tan

Element-wise tangent.
import mlx.core as mx
import math

a = mx.array([0.0, math.pi/4])
b = mx.tan(a)
print(b)  # array([0, 1], dtype=float32)
a
array
required
Input array in radians.
stream
Stream
The stream on which to schedule the operation.
result
array
Tangents of input values.

Linear Algebra

matmul

Matrix multiplication.
import mlx.core as mx

# Matrix-matrix multiplication
a = mx.array([[1, 2], [3, 4]])
b = mx.array([[5, 6], [7, 8]])
c = mx.matmul(a, b)
print(c)
# array([[19, 22],
#        [43, 50]], dtype=int32)

# Matrix-vector multiplication
d = mx.array([1, 2])
e = mx.matmul(a, d)
print(e)  # array([5, 11], dtype=int32)
a
array
required
First input array.
b
array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Matrix product of a and b.

outer

Outer product of two vectors.
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.array([4, 5])
c = mx.outer(a, b)
print(c)
# array([[4, 5],
#        [8, 10],
#        [12, 15]], dtype=int32)
a
array
required
First input vector.
b
array
required
Second input vector.
stream
Stream
The stream on which to schedule the operation.
result
array
Outer product.

inner

Inner product of two vectors.
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
c = mx.inner(a, b)
print(c)  # array(32, dtype=int32)  # 1*4 + 2*5 + 3*6
a
array
required
First input vector.
b
array
required
Second input vector.
stream
Stream
The stream on which to schedule the operation.
result
array
Inner product.

Comparison Operations

equal

Element-wise equality comparison.
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.array([1, 0, 3])
c = mx.equal(a, b)
print(c)  # array([True, False, True], dtype=bool)

# Operator overload
d = a == b
print(d)  # array([True, False, True], dtype=bool)
a
array
required
First input array.
b
array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Boolean array of comparison results.

greater

Element-wise greater-than comparison.
import mlx.core as mx

a = mx.array([1, 3, 5])
b = mx.array([2, 2, 4])
c = mx.greater(a, b)
print(c)  # array([False, True, True], dtype=bool)

# Operator overload
d = a > b
print(d)  # array([False, True, True], dtype=bool)
a
array
required
First input array.
b
array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Boolean array of comparison results.

less

Element-wise less-than comparison.
import mlx.core as mx

a = mx.array([1, 3, 5])
b = mx.array([2, 2, 4])
c = mx.less(a, b)
print(c)  # array([True, False, False], dtype=bool)

# Operator overload
d = a < b
print(d)  # array([True, False, False], dtype=bool)
a
array
required
First input array.
b
array
required
Second input array.
stream
Stream
The stream on which to schedule the operation.
result
array
Boolean array of comparison results.

allclose

Test if two arrays are element-wise equal within a tolerance.
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([1.0001, 2.0001, 3.0001])

# Outside the default tolerance: |a - b| is about 1e-4, which exceeds atol + rtol * |b| (about 1e-5)
print(mx.allclose(a, b))  # array(False, dtype=bool)

# Looser relative tolerance
print(mx.allclose(a, b, rtol=1e-3))  # array(True, dtype=bool)
a
array
required
First input array.
b
array
required
Second input array.
rtol
float
default:"1e-5"
Relative tolerance.
atol
float
default:"1e-8"
Absolute tolerance.
equal_nan
bool
default:"False"
If True, NaNs are considered equal.
stream
Stream
The stream on which to schedule the operation.
result
array
Boolean scalar indicating if arrays are close.

Concatenation and Stacking

concatenate

Concatenate arrays along an existing axis.
import mlx.core as mx

a = mx.array([[1, 2], [3, 4]])
b = mx.array([[5, 6]])

# Concatenate along axis 0
c = mx.concatenate([a, b], axis=0)
print(c)
# array([[1, 2],
#        [3, 4],
#        [5, 6]], dtype=int32)

# Concatenate along axis 1
d = mx.array([[5], [6]])
e = mx.concatenate([a, d], axis=1)
print(e)
# array([[1, 2, 5],
#        [3, 4, 6]], dtype=int32)
arrays
list[array]
required
List of arrays to concatenate.
axis
int
default:"0"
Axis along which to concatenate.
stream
Stream
The stream on which to schedule the operation.
result
array
Concatenated array.

stack

Stack arrays along a new axis.
import mlx.core as mx

a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])

# Stack along new axis 0
c = mx.stack([a, b], axis=0)
print(c)
# array([[1, 2, 3],
#        [4, 5, 6]], dtype=int32)
print(c.shape)  # (2, 3)

# Stack along new axis 1
d = mx.stack([a, b], axis=1)
print(d)
# array([[1, 4],
#        [2, 5],
#        [3, 6]], dtype=int32)
print(d.shape)  # (3, 2)
arrays
list[array]
required
List of arrays to stack.
axis
int
default:"0"
Axis along which to stack.
stream
Stream
The stream on which to schedule the operation.
result
array
Stacked array with one additional dimension.

Indexing and Slicing

take

Take elements from an array along an axis.
import mlx.core as mx

a = mx.array([10, 20, 30, 40, 50])
indices = mx.array([0, 2, 4])
b = mx.take(a, indices)
print(b)  # array([10, 30, 50], dtype=int32)

# 2D indexing
c = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
d = mx.take(c, mx.array([0, 2]), axis=0)
print(d)
# array([[1, 2, 3],
#        [7, 8, 9]], dtype=int32)
a
array
required
Input array.
indices
array
required
Indices of elements to take.
axis
int
Axis along which to take elements. If not specified, the array is flattened.
stream
Stream
The stream on which to schedule the operation.
result
array
Array with selected elements.

where

Select elements from x or y depending on condition.
import mlx.core as mx

condition = mx.array([True, False, True, False])
x = mx.array([1, 2, 3, 4])
y = mx.array([10, 20, 30, 40])
result = mx.where(condition, x, y)
print(result)  # array([1, 20, 3, 40], dtype=int32)

# Use with comparisons
a = mx.array([1, 2, 3, 4, 5])
b = mx.where(a > 2, a, 0)
print(b)  # array([0, 0, 3, 4, 5], dtype=int32)
condition
array
required
Boolean condition array.
x
scalar | array
required
Values to select where condition is True.
y
scalar | array
required
Values to select where condition is False.
stream
Stream
The stream on which to schedule the operation.
result
array
Array with elements from x where condition is True, otherwise from y.