For the most part, indexing an MLX array works the same as indexing a NumPy ndarray. However, there are some important differences to be aware of.
Basic Indexing
You can use regular integers and slices to index arrays:
import mlx.core as mx
arr = mx.arange(10)
arr[3]
# array(3, dtype=int32)
arr[-2] # negative indexing works
# array(8, dtype=int32)
arr[2:8:2] # start, stop, stride
# array([2, 4, 6], dtype=int32)
Multi-dimensional Arrays
For multi-dimensional arrays, the ... or Ellipsis syntax works as in NumPy:
arr = mx.arange(8).reshape(2, 2, 2)
arr[:, :, 0]
# array([[0, 2],
# [4, 6]], dtype=int32)
arr[..., 0]
# array([[0, 2],
# [4, 6]], dtype=int32)
Creating New Axes
You can index with None to create a new axis:
arr = mx.arange(8)
arr.shape
# (8,)
arr[None].shape
# (1, 8)
Array Indexing
You can also use an array to index another array:
arr = mx.arange(10)
idx = mx.array([5, 7])
arr[idx]
# array([5, 7], dtype=int32)
Mixing and matching integers, slices, ..., and array indices works just as in NumPy.
Other useful functions for indexing arrays include mx.take() and mx.take_along_axis().
Differences from NumPy
MLX indexing is different from NumPy indexing in two important ways:
- No bounds checking: Indexing out of bounds is undefined behavior. The reason is that exceptions cannot propagate from the GPU, and performing bounds checking before launching kernels would be extremely inefficient.
- Limited boolean mask support: Boolean mask-based indexing is supported for assignment only (see Boolean Mask Assignment below).
MLX has limited support for operations whose output shapes depend on input data. Reading through a boolean mask is one such operation; others that MLX does not yet support include numpy.nonzero and the single-input version of numpy.where.
In-Place Updates
In-place updates to indexed arrays are possible in MLX:
a = mx.array([1, 2, 3])
a[2] = 0
a
# array([1, 2, 0], dtype=int32)
Just as in NumPy, in-place updates will be reflected in all references to the same array:
a = mx.array([1, 2, 3])
b = a
b[2] = 0
b
# array([1, 2, 0], dtype=int32)
a
# array([1, 2, 0], dtype=int32)
Unlike NumPy, slicing an array creates a copy, not a view. So mutating it does not mutate the original array:
a = mx.array([1, 2, 3])
b = a[:]
b[2] = 0
b
# array([1, 2, 0], dtype=int32)
a
# array([1, 2, 3], dtype=int32)
Nondeterministic Updates
Unlike NumPy, updates to the same location are nondeterministic:
a = mx.array([1, 2, 3])
a[[0, 0]] = mx.array([4, 5])
# The first element of 'a' could be 4 or 5
Transformations of functions which use in-place updates are allowed and work as expected:
def fun(x, idx):
    x[idx] = 2.0
    return x.sum()
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
In the above, dfdx will have the correct gradient: zeros at idx and ones elsewhere.
Boolean Mask Assignment
MLX supports boolean indices using NumPy syntax. A mask must already be a bool_ MLX array or a NumPy ndarray with dtype=bool.
a = mx.array([1.0, 2.0, 3.0])
mask = mx.array([True, False, True])
updates = mx.array([5.0, 6.0])
a[mask] = updates
a
# array([5.0, 2.0, 6.0], dtype=float32)
Scalar assignments broadcast to every True entry in the mask. For non-scalar assignments, updates must provide at least as many elements as there are True entries in mask.
a = mx.zeros((2, 3))
mask = mx.array([[True, False, True],
[False, False, True]])
a[mask] = 1.0
a
# array([[1.0, 0.0, 1.0],
# [0.0, 0.0, 1.0]], dtype=float32)
Boolean Mask Semantics
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. The only exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full.
a = mx.arange(1000).reshape(10, 10, 10)
a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape (10, 10) applies to the first two axes, so a[mask] selects the 1-D slices a[i, j, :] where mask[i, j] is True. Shapes such as (1, 10, 10) or (10, 10, 1) do not match the indexed axes and therefore raise errors.