Skip to main content
This page documents the core C++ operations available in MLX. All operations live in the `mlx::core` namespace; the examples below assume the alias `namespace mx = mlx::core;`.

Array Creation

Basic Constructors

// Scalar array
auto x = mx::array(1.0);
auto y = mx::array(42, mx::int32);

// From initializer list
auto z = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});

Constant Arrays

// Fill with zeros
mx::array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
auto x = mx::zeros({3, 4}, mx::float32);

// Fill with ones
mx::array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
auto y = mx::ones({3, 4});

// Fill with custom value
mx::array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {});
auto z = mx::full({3, 4}, 3.14f);

Range Arrays

// Evenly spaced values
mx::array arange(double start, double stop, double step, 
                  Dtype dtype, StreamOrDevice s = {});
                  
auto x = mx::arange(0, 10, 1);        // [0, 1, 2, ..., 9]
auto y = mx::arange(0.0, 1.0, 0.1);   // [0.0, 0.1, 0.2, ...]

// Linearly spaced values
mx::array linspace(double start, double stop, int num = 50,
                    Dtype dtype = float32, StreamOrDevice s = {});
                    
auto z = mx::linspace(0.0, 1.0, 5);   // [0.0, 0.25, 0.5, 0.75, 1.0]

Identity and Diagonal

// Identity matrix
mx::array identity(int n, Dtype dtype, StreamOrDevice s = {});
auto I = mx::identity(3);  // 3x3 identity

// Eye (diagonal ones)
mx::array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
auto E = mx::eye(3, 3, 0);  // Same as identity

Shape Manipulation

Reshaping

// Reshape array
mx::array reshape(const array& a, Shape shape, StreamOrDevice s = {});
auto x = mx::ones({6});
auto y = mx::reshape(x, {2, 3});

// Flatten array
mx::array flatten(const array& a, StreamOrDevice s = {});
auto z = mx::flatten(y);  // Back to {6}

Transposing

// Transpose with custom axis order
mx::array transpose(const array& a, std::vector<int> axes, 
                     StreamOrDevice s = {});
                     
auto x = mx::ones({2, 3, 4});
auto y = mx::transpose(x, {2, 0, 1});  // Shape: {4, 2, 3}

// Reverse all axes
mx::array transpose(const array& a, StreamOrDevice s = {});
auto z = mx::transpose(x);  // Shape: {4, 3, 2}

// Swap two axes
mx::array swapaxes(const array& a, int axis1, int axis2, 
                    StreamOrDevice s = {});

Expanding and Squeezing

// Add dimensions
mx::array expand_dims(const array& a, int axis, StreamOrDevice s = {});
auto x = mx::ones({3, 4});
auto y = mx::expand_dims(x, 0);  // Shape: {1, 3, 4}

// Remove singleton dimensions
mx::array squeeze(const array& a, StreamOrDevice s = {});
auto z = mx::squeeze(y);  // Back to {3, 4}

// Squeeze specific axis
mx::array squeeze(const array& a, int axis, StreamOrDevice s = {});

Arithmetic Operations

Element-wise Operations

// Addition
mx::array add(const array& a, const array& b, StreamOrDevice s = {});
auto z = mx::add(x, y);
auto z = x + y;  // Operator overload

// Subtraction
mx::array subtract(const array& a, const array& b, StreamOrDevice s = {});
auto z = x - y;

// Multiplication
mx::array multiply(const array& a, const array& b, StreamOrDevice s = {});
auto z = x * y;

// Division
mx::array divide(const array& a, const array& b, StreamOrDevice s = {});
auto z = x / y;

// Remainder
mx::array remainder(const array& a, const array& b, StreamOrDevice s = {});
auto z = x % y;

Mathematical Functions

// Exponential and logarithm
mx::array exp(const array& a, StreamOrDevice s = {});
mx::array log(const array& a, StreamOrDevice s = {});
mx::array log2(const array& a, StreamOrDevice s = {});
mx::array log10(const array& a, StreamOrDevice s = {});
mx::array log1p(const array& a, StreamOrDevice s = {});  // log(1 + a)

// Power and roots
mx::array square(const array& a, StreamOrDevice s = {});
mx::array sqrt(const array& a, StreamOrDevice s = {});
mx::array power(const array& a, const array& b, StreamOrDevice s = {});

// Trigonometric
mx::array sin(const array& a, StreamOrDevice s = {});
mx::array cos(const array& a, StreamOrDevice s = {});
mx::array tan(const array& a, StreamOrDevice s = {});
mx::array arcsin(const array& a, StreamOrDevice s = {});
mx::array arccos(const array& a, StreamOrDevice s = {});
mx::array arctan(const array& a, StreamOrDevice s = {});

// Hyperbolic
mx::array sinh(const array& a, StreamOrDevice s = {});
mx::array cosh(const array& a, StreamOrDevice s = {});
mx::array tanh(const array& a, StreamOrDevice s = {});

Linear Algebra

Matrix Operations

// Matrix multiplication
mx::array matmul(const array& a, const array& b, StreamOrDevice s = {});
auto C = mx::matmul(A, B);

// Outer product
mx::array outer(const array& a, const array& b, StreamOrDevice s = {});

// Inner product
mx::array inner(const array& a, const array& b, StreamOrDevice s = {});

// Tensor contraction
mx::array tensordot(const array& a, const array& b, int axis = 2,
                     StreamOrDevice s = {});

Reductions

Common Reductions

// Sum
mx::array sum(const array& a, StreamOrDevice s = {});
mx::array sum(const array& a, int axis, bool keepdims = false,
               StreamOrDevice s = {});
               
auto total = mx::sum(x);           // Sum all elements
auto row_sums = mx::sum(x, 1);     // Sum along axis 1

// Mean
mx::array mean(const array& a, StreamOrDevice s = {});
mx::array mean(const array& a, int axis, bool keepdims = false,
                StreamOrDevice s = {});

// Maximum and minimum
mx::array max(const array& a, StreamOrDevice s = {});
mx::array min(const array& a, StreamOrDevice s = {});

// Product
mx::array prod(const array& a, StreamOrDevice s = {});

Statistical Reductions

// Variance
mx::array var(const array& a, bool keepdims, int ddof = 0,
               StreamOrDevice s = {});
               
// Standard deviation
mx::array std(const array& a, bool keepdims, int ddof = 0,
               StreamOrDevice s = {});
               
// Median
mx::array median(const array& a, StreamOrDevice s = {});

Cumulative Operations

// Cumulative sum
mx::array cumsum(const array& a, int axis, bool reverse = false,
                  bool inclusive = true, StreamOrDevice s = {});
                  
// Cumulative product
mx::array cumprod(const array& a, int axis, bool reverse = false,
                   bool inclusive = true, StreamOrDevice s = {});

Comparison and Logic

Comparison Operations

mx::array equal(const array& a, const array& b, StreamOrDevice s = {});
auto mask = (x == y);

mx::array not_equal(const array& a, const array& b, StreamOrDevice s = {});
auto mask = (x != y);

mx::array greater(const array& a, const array& b, StreamOrDevice s = {});
auto mask = (x > y);

mx::array less(const array& a, const array& b, StreamOrDevice s = {});
auto mask = (x < y);

Logical Operations

mx::array logical_and(const array& a, const array& b, StreamOrDevice s = {});
auto result = (a && b);

mx::array logical_or(const array& a, const array& b, StreamOrDevice s = {});
auto result = (a || b);

mx::array logical_not(const array& a, StreamOrDevice s = {});

Value Checks (NaN/Infinity Tests and Boolean Reductions)

// Test for NaN, infinity
mx::array isnan(const array& a, StreamOrDevice s = {});
mx::array isinf(const array& a, StreamOrDevice s = {});
mx::array isfinite(const array& a, StreamOrDevice s = {});

// All/any reductions
mx::array all(const array& a, StreamOrDevice s = {});
mx::array any(const array& a, StreamOrDevice s = {});

Indexing and Slicing

Basic Indexing

// Slice array
mx::array slice(const array& a, Shape start, Shape stop, Shape strides,
                 StreamOrDevice s = {});
                 
auto x = mx::arange(10);
auto y = mx::slice(x, {2}, {8}, {2});  // [2, 4, 6]

// Take elements
mx::array take(const array& a, const array& indices, int axis,
                StreamOrDevice s = {});
                
auto indices = mx::array({0, 2, 4});
auto y = mx::take(x, indices, 0);

Advanced Indexing

// Gather with indices
mx::array gather(const array& a, const std::vector<array>& indices,
                  const std::vector<int>& axes, const Shape& slice_sizes,
                  StreamOrDevice s = {});

// Scatter updates
mx::array scatter(const array& a, const std::vector<array>& indices,
                   const array& updates, const std::vector<int>& axes,
                   StreamOrDevice s = {});

Type and Data Access

Type Conversion

// Cast to different type
mx::array astype(array a, Dtype dtype, StreamOrDevice s = {});
auto y = mx::astype(x, mx::float16);

Data Access (member functions of `mx::array`)

// Get scalar value (evaluates array)
template <typename T>
T item();

auto x = mx::array(3.14f);
float val = x.item<float>();

// Get data pointer (array must be evaluated)
template <typename T>
const T* data();

auto ptr = x.data<float>();

Array Properties (member functions of `mx::array`)

// Shape and size
const Shape& shape() const;
int shape(int axis) const;
size_t size() const;
int ndim() const;

// Data type
Dtype dtype() const;

// Strides
const Strides& strides() const;

Random Number Generation

The following functions live in the `mlx::core::random` namespace:
// Normal distribution
mx::array random::normal(const Shape& shape, Dtype dtype = float32,
                          StreamOrDevice s = {});
                          
// Uniform distribution
mx::array random::uniform(double low, double high, const Shape& shape,
                           Dtype dtype = float32, StreamOrDevice s = {});
                           
// Set seed
void random::seed(uint64_t seed);

Convolutions

// 1D convolution
mx::array conv1d(const array& input, const array& weight,
                  int stride = 1, int padding = 0, int dilation = 1,
                  int groups = 1, StreamOrDevice s = {});

// 2D convolution
mx::array conv2d(const array& input, const array& weight,
                  const std::pair<int, int>& stride = {1, 1},
                  const std::pair<int, int>& padding = {0, 0},
                  const std::pair<int, int>& dilation = {1, 1},
                  int groups = 1, StreamOrDevice s = {});

See Also