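All functions and types documented below live in the mlx::core namespace; the examples refer to it through the alias mx (e.g. namespace mx = mlx::core;). Only one overload of each function is listed, so some examples call a different overload than the one shown — for instance, mx::ones({3, 4}) omits the Dtype argument that the listed signature takes.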
Array Creation
Basic Constructors
// Scalar array
auto x = mx::array(1.0);
auto y = mx::array(42, mx::int32);
// From initializer list
auto z = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
Constant Arrays
// Fill with zeros
mx::array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
auto x = mx::zeros({3, 4}, mx::float32);
// Fill with ones
mx::array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
auto y = mx::ones({3, 4});
// Fill with custom value
mx::array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {});
auto z = mx::full({3, 4}, 3.14f);
Range Arrays
// Evenly spaced values
mx::array arange(double start, double stop, double step,
Dtype dtype, StreamOrDevice s = {});
auto x = mx::arange(0, 10, 1); // [0, 1, 2, ..., 9]
auto y = mx::arange(0.0, 1.0, 0.1); // [0.0, 0.1, 0.2, ...]
// Linearly spaced values
mx::array linspace(double start, double stop, int num = 50,
Dtype dtype = float32, StreamOrDevice s = {});
auto z = mx::linspace(0.0, 1.0, 5); // [0.0, 0.25, 0.5, 0.75, 1.0]
Identity and Diagonal
// Identity matrix
mx::array identity(int n, Dtype dtype, StreamOrDevice s = {});
auto I = mx::identity(3); // 3x3 identity
// Eye (diagonal ones)
mx::array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
auto E = mx::eye(3, 3, 0); // Same as identity
Shape Manipulation
Reshaping
// Reshape array
mx::array reshape(const array& a, Shape shape, StreamOrDevice s = {});
auto x = mx::ones({6});
auto y = mx::reshape(x, {2, 3});
// Flatten array
mx::array flatten(const array& a, StreamOrDevice s = {});
auto z = mx::flatten(y); // Back to {6}
Transposing
// Transpose with custom axis order
mx::array transpose(const array& a, std::vector<int> axes,
StreamOrDevice s = {});
auto x = mx::ones({2, 3, 4});
auto y = mx::transpose(x, {2, 0, 1}); // Shape: {4, 2, 3}
// Reverse all axes
mx::array transpose(const array& a, StreamOrDevice s = {});
auto z = mx::transpose(x); // Shape: {4, 3, 2}
// Swap two axes
mx::array swapaxes(const array& a, int axis1, int axis2,
StreamOrDevice s = {});
Expanding and Squeezing
// Add dimensions
mx::array expand_dims(const array& a, int axis, StreamOrDevice s = {});
auto x = mx::ones({3, 4});
auto y = mx::expand_dims(x, 0); // Shape: {1, 3, 4}
// Remove singleton dimensions
mx::array squeeze(const array& a, StreamOrDevice s = {});
auto z = mx::squeeze(y); // Back to {3, 4}
// Squeeze specific axis
mx::array squeeze(const array& a, int axis, StreamOrDevice s = {});
Arithmetic Operations
Element-wise Operations
// Addition
mx::array add(const array& a, const array& b, StreamOrDevice s = {});
auto z = mx::add(x, y);
auto z = x + y; // Operator overload
// Subtraction
mx::array subtract(const array& a, const array& b, StreamOrDevice s = {});
auto z = x - y;
// Multiplication
mx::array multiply(const array& a, const array& b, StreamOrDevice s = {});
auto z = x * y;
// Division
mx::array divide(const array& a, const array& b, StreamOrDevice s = {});
auto z = x / y;
// Remainder
mx::array remainder(const array& a, const array& b, StreamOrDevice s = {});
auto z = x % y;
Mathematical Functions
// Exponential and logarithm
mx::array exp(const array& a, StreamOrDevice s = {});
mx::array log(const array& a, StreamOrDevice s = {});
mx::array log2(const array& a, StreamOrDevice s = {});
mx::array log10(const array& a, StreamOrDevice s = {});
mx::array log1p(const array& a, StreamOrDevice s = {}); // log(1 + a)
// Power and roots
mx::array square(const array& a, StreamOrDevice s = {});
mx::array sqrt(const array& a, StreamOrDevice s = {});
mx::array power(const array& a, const array& b, StreamOrDevice s = {});
// Trigonometric
mx::array sin(const array& a, StreamOrDevice s = {});
mx::array cos(const array& a, StreamOrDevice s = {});
mx::array tan(const array& a, StreamOrDevice s = {});
mx::array arcsin(const array& a, StreamOrDevice s = {});
mx::array arccos(const array& a, StreamOrDevice s = {});
mx::array arctan(const array& a, StreamOrDevice s = {});
// Hyperbolic
mx::array sinh(const array& a, StreamOrDevice s = {});
mx::array cosh(const array& a, StreamOrDevice s = {});
mx::array tanh(const array& a, StreamOrDevice s = {});
Linear Algebra
Matrix Operations
// Matrix multiplication
mx::array matmul(const array& a, const array& b, StreamOrDevice s = {});
auto C = mx::matmul(A, B);
// Outer product
mx::array outer(const array& a, const array& b, StreamOrDevice s = {});
// Inner product
mx::array inner(const array& a, const array& b, StreamOrDevice s = {});
// Tensor contraction
mx::array tensordot(const array& a, const array& b, int axis = 2,
StreamOrDevice s = {});
Reductions
Common Reductions
// Sum
mx::array sum(const array& a, StreamOrDevice s = {});
mx::array sum(const array& a, int axis, bool keepdims = false,
StreamOrDevice s = {});
auto total = mx::sum(x); // Sum all elements
auto row_sums = mx::sum(x, 1); // Sum along axis 1
// Mean
mx::array mean(const array& a, StreamOrDevice s = {});
mx::array mean(const array& a, int axis, bool keepdims = false,
StreamOrDevice s = {});
// Maximum and minimum
mx::array max(const array& a, StreamOrDevice s = {});
mx::array min(const array& a, StreamOrDevice s = {});
// Product
mx::array prod(const array& a, StreamOrDevice s = {});
Statistical Reductions
// Variance
mx::array var(const array& a, bool keepdims, int ddof = 0,
StreamOrDevice s = {});
// Standard deviation
mx::array std(const array& a, bool keepdims, int ddof = 0,
StreamOrDevice s = {});
// Median
mx::array median(const array& a, StreamOrDevice s = {});
Cumulative Operations
// Cumulative sum
mx::array cumsum(const array& a, int axis, bool reverse = false,
bool inclusive = true, StreamOrDevice s = {});
// Cumulative product
mx::array cumprod(const array& a, int axis, bool reverse = false,
bool inclusive = true, StreamOrDevice s = {});
Comparison and Logic
Comparison Operations
mx::array equal(const array& a, const array& b, StreamOrDevice s = {});
auto mask = (x == y);
mx::array not_equal(const array& a, const array& b, StreamOrDevice s = {});
auto mask = (x != y);
mx::array greater(const array& a, const array& b, StreamOrDevice s = {});
auto mask = (x > y);
mx::array less(const array& a, const array& b, StreamOrDevice s = {});
auto mask = (x < y);
Logical Operations
mx::array logical_and(const array& a, const array& b, StreamOrDevice s = {});
auto result = (a && b);
mx::array logical_or(const array& a, const array& b, StreamOrDevice s = {});
auto result = (a || b);
mx::array logical_not(const array& a, StreamOrDevice s = {});
Testing
// Test for NaN, infinity
mx::array isnan(const array& a, StreamOrDevice s = {});
mx::array isinf(const array& a, StreamOrDevice s = {});
mx::array isfinite(const array& a, StreamOrDevice s = {});
// All/any reductions
mx::array all(const array& a, StreamOrDevice s = {});
mx::array any(const array& a, StreamOrDevice s = {});
Indexing and Slicing
Basic Indexing
// Slice array
mx::array slice(const array& a, Shape start, Shape stop, Shape strides,
StreamOrDevice s = {});
auto x = mx::arange(10);
auto y = mx::slice(x, {2}, {8}, {2}); // [2, 4, 6]
// Take elements
mx::array take(const array& a, const array& indices, int axis,
StreamOrDevice s = {});
auto indices = mx::array({0, 2, 4});
auto y = mx::take(x, indices, 0);
Advanced Indexing
// Gather with indices
mx::array gather(const array& a, const std::vector<array>& indices,
const std::vector<int>& axes, const Shape& slice_sizes,
StreamOrDevice s = {});
// Scatter updates
mx::array scatter(const array& a, const std::vector<array>& indices,
const array& updates, const std::vector<int>& axes,
StreamOrDevice s = {});
Type and Data Access
Type Conversion
// Cast to different type
mx::array astype(array a, Dtype dtype, StreamOrDevice s = {});
auto y = mx::astype(x, mx::float16);
Data Access
// Get scalar value (evaluates array)
template <typename T>
T item();
auto x = mx::array(3.14f);
float val = x.item<float>();
// Get data pointer (array must be evaluated)
template <typename T>
const T* data();
auto ptr = x.data<float>();
Array Properties
// Shape and size
const Shape& shape() const;
int shape(int axis) const;
size_t size() const;
int ndim() const;
// Data type
Dtype dtype() const;
// Strides
const Strides& strides() const;
Random Number Generation
See the mlx::core::random namespace:
// Normal distribution
mx::array random::normal(const Shape& shape, Dtype dtype = float32,
StreamOrDevice s = {});
// Uniform distribution
mx::array random::uniform(double low, double high, const Shape& shape,
Dtype dtype = float32, StreamOrDevice s = {});
// Set seed
void random::seed(uint64_t seed);
Convolutions
// 1D convolution
mx::array conv1d(const array& input, const array& weight,
int stride = 1, int padding = 0, int dilation = 1,
int groups = 1, StreamOrDevice s = {});
// 2D convolution
mx::array conv2d(const array& input, const array& weight,
const std::pair<int, int>& stride = {1, 1},
const std::pair<int, int>& padding = {0, 0},
const std::pair<int, int>& dilation = {1, 1},
int groups = 1, StreamOrDevice s = {});
See Also
- For the Python API reference, see Operations
- For custom operations, see Building Extensions
- For GPU kernels, see Metal Kernels