Skip to main content

Documentation Index

Fetch the complete documentation index at: https://mintlify.com/ml-explore/mlx/llms.txt

Use this file to discover all available pages before exploring further.

You can extend MLX with custom operations on the CPU or GPU. This guide explains how to build custom C++ extensions with a complete example.

Overview

MLX extensions consist of two main components:
  1. Operations: Front-end functions that operate on arrays
  2. Primitives: Building blocks that define computation and transformations
Operations build the computation graph, while primitives provide the rules for evaluating and transforming that graph.

Example: Axpby Operation

We’ll implement a custom operation that computes z = alpha * x + beta * y, combining two scaled arrays.

Simple Implementation

First, let’s implement it using existing MLX operations:
#include "mlx/mlx.h"

namespace mx = mlx::core;

// Reference implementation of z = alpha * x + beta * y built only from
// existing MLX ops (two multiplies and an add).
//
// x, y   : input arrays (combined with MLX's usual broadcasting/promotion)
// alpha  : scale applied to x
// beta   : scale applied to y
// s      : stream or device to schedule the ops on ({} = default)
mx::array axpby(
    const mx::array& x,
    const mx::array& y,
    const float alpha,
    const float beta,
    mx::StreamOrDevice s = {}
) {
  // Scale x and y
  auto ax = mx::multiply(mx::array(alpha), x, s);
  auto by = mx::multiply(mx::array(beta), y, s);

  // Add and return
  return mx::add(ax, by, s);
}
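This works, but each call adds separate multiply and add operations to the graph, and each of them materializes an intermediate array. For better performance, we’ll create a custom primitive that computes the result in a single pass.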

Creating a Primitive

Define the Primitive Class

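A primitive inherits from `Primitive` and implements evaluation and transformation methods. Every method declared below must also be defined: this guide shows `eval_cpu`, `eval_gpu`, `jvp`, and `vjp`, but `vmap` and `is_equivalent` need definitions too, otherwise the extension will fail to link: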
axpby.h
#pragma once

#include "mlx/mlx.h"

namespace mx = mlx::core;

/**
 * Compute alpha * x + beta * y element-wise using the Axpby primitive.
 *
 * x and y are promoted to a common dtype (non-floating inputs are upcast to
 * float32) and broadcast against each other before the primitive is attached.
 *
 * @param x      First input array.
 * @param y      Second input array.
 * @param alpha  Scale applied to x.
 * @param beta   Scale applied to y.
 * @param s      Stream or device on which to schedule the operation. No
 *               default is declared here because the definition in
 *               axpby.cpp supplies `s = {}`; callers in other files pass `{}`
 *               explicitly to use the default stream.
 * @return       An array whose value is produced by the Axpby primitive.
 */
mx::array axpby(
    const mx::array& x,
    const mx::array& y,
    const float alpha,
    const float beta,
    mx::StreamOrDevice s);

/**
 * Primitive computing out = alpha * inputs[0] + beta * inputs[1].
 *
 * The axpby() operation casts both inputs to the output dtype and broadcasts
 * them to the output shape before constructing this primitive.
 */
class Axpby : public mx::Primitive {
 public:
  explicit Axpby(mx::Stream stream, float alpha, float beta)
      : Primitive(stream), alpha_(alpha), beta_(beta) {}

  // Evaluate on CPU
  void eval_cpu(
      const std::vector<mx::array>& inputs,
      std::vector<mx::array>& outputs) override;

  // Evaluate on GPU
  void eval_gpu(
      const std::vector<mx::array>& inputs,
      std::vector<mx::array>& outputs) override;

  // Jacobian-vector product (forward mode AD)
  std::vector<mx::array> jvp(
      const std::vector<mx::array>& primals,
      const std::vector<mx::array>& tangents,
      const std::vector<int>& argnums) override;

  // Vector-Jacobian product (reverse mode AD)
  std::vector<mx::array> vjp(
      const std::vector<mx::array>& primals,
      const std::vector<mx::array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<mx::array>& outputs) override;

  // Vectorization
  std::pair<std::vector<mx::array>, std::vector<int>> vmap(
      const std::vector<mx::array>& inputs,
      const std::vector<int>& axes) override;

  // Name for debugging
  const char* name() const override { return "Axpby"; }

  // Equivalence check
  bool is_equivalent(const Primitive& other) const override;

 private:
  float alpha_;  // scale applied to inputs[0] (x)
  float beta_;   // scale applied to inputs[1] (y)
};

Implement the Operation

The operation promotes the input dtypes, broadcasts the inputs to a common shape, and then builds the output array with the `Axpby` primitive attached:
axpby.cpp
array axpby(
    const array& x,
    const array& y,
    const float alpha,
    const float beta,
    StreamOrDevice s = {}
) {
  // Promote dtypes between x and y
  auto promoted_dtype = mx::promote_types(x.dtype(), y.dtype());
  
  // Upcast to float32 for non-floating point inputs
  auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
      ? promoted_dtype
      : mx::promote_types(promoted_dtype, mx::float32);
  
  // Cast x and y to output dtype
  auto x_casted = mx::astype(x, out_dtype, s);
  auto y_casted = mx::astype(y, out_dtype, s);
  
  // Broadcast shapes
  auto broadcasted_inputs = mx::broadcast_arrays({x_casted, y_casted}, s);
  auto out_shape = broadcasted_inputs[0].shape();
  
  // Construct output array with Axpby primitive
  return array(
      out_shape,
      out_dtype,
      std::make_shared<Axpby>(mx::to_stream(s), alpha, beta),
      broadcasted_inputs
  );
}

CPU Implementation

CPU Kernel

Implement the element-wise operation on CPU:
// CPU kernel for out = alpha * x + beta * y, with T the output element type.
//
// x and y may be strided (e.g. broadcast); each output element is mapped back
// to an input offset with mx::elem_to_loc using the output shape. The loop
// runs inside a task handed to the stream's CPU command encoder.
template <typename T>
void axpby_impl(
    const mx::array& x,
    const mx::array& y,
    mx::array& out,
    float alpha_,
    float beta_,
    mx::Stream stream
) {
  // The output buffer must exist before its data pointer is taken below.
  out.set_data(mx::allocator::malloc(out.nbytes()));

  // Declare x and y as inputs and out as the output of the task dispatched
  // on this stream's CPU command encoder.
  auto& enc = mx::cpu::get_command_encoder(stream);
  enc.set_input_array(x);
  enc.set_input_array(y);
  enc.set_output_array(out);

  const T* src_x = x.data<T>();
  const T* src_y = y.data<T>();
  T* dst = out.data<T>();

  // Everything the task needs is captured by value.
  enc.dispatch([src_x,
                src_y,
                dst,
                count = out.size(),
                out_shape = out.shape(),
                stride_x = x.strides(),
                stride_y = y.strides(),
                a = alpha_,
                b = beta_]() {
    const T scale_x = static_cast<T>(a);
    const T scale_y = static_cast<T>(b);

    for (size_t i = 0; i < count; ++i) {
      const auto loc_x = mx::elem_to_loc(i, out_shape, stride_x);
      const auto loc_y = mx::elem_to_loc(i, out_shape, stride_y);
      dst[i] = scale_x * src_x[loc_x] + scale_y * src_y[loc_y];
    }
  });
}

CPU Evaluation

Dispatch to the correct type:
// CPU entry point: select the axpby_impl instantiation that matches the
// output dtype. Any dtype other than float32, float16, bfloat16 or complex64
// is rejected with std::runtime_error.
void Axpby::eval_cpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs
) {
  const mx::array& x = inputs[0];
  const mx::array& y = inputs[1];
  mx::array& out = outputs[0];
  auto s = stream();

  switch (out.dtype()) {
    case mx::float32:
      axpby_impl<float>(x, y, out, alpha_, beta_, s);
      break;
    case mx::float16:
      axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, s);
      break;
    case mx::bfloat16:
      axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, s);
      break;
    case mx::complex64:
      axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, s);
      break;
    default:
      throw std::runtime_error(
          "Axpby only supports floating point types.");
  }
}

GPU Implementation

Metal Kernel

Write a Metal kernel for GPU execution:
axpby.metal
#include <metal_stdlib>

// Provides elem_to_loc and the instantiate_kernel macro used below; it is
// resolvable through the MLX include dirs passed to mlx_build_metallib.
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

// out[index] = alpha * x[loc_x] + beta * y[loc_y], with one thread per output
// element. x and y may be strided (e.g. broadcast); their offsets are derived
// from the output index via elem_to_loc using the output shape.
template <typename T>
[[kernel]] void axpby_general(
    device const T* x [[buffer(0)]],
    device const T* y [[buffer(1)]],
    device T* out [[buffer(2)]],
    constant const float& alpha [[buffer(3)]],
    constant const float& beta [[buffer(4)]],
    constant const int* shape [[buffer(5)]],
    constant const int64_t* x_strides [[buffer(6)]],
    constant const int64_t* y_strides [[buffer(7)]],
    constant const int& ndim [[buffer(8)]],
    uint index [[thread_position_in_grid]]
) {
  // Convert linear index to offsets
  auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
  auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

  // Compute and write result
  out[index] = static_cast<T>(alpha) * x[x_offset] +
               static_cast<T>(beta) * y[y_offset];
}

// Instantiate for each type; the host name suffix matches type_to_name(out)
// used by Axpby::eval_gpu to look the kernel up.
instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)

GPU Evaluation

void Axpby::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs
) {
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];
  
  auto& s = stream();
  auto& d = mx::metal::device(s.device);
  
  // Allocate output
  out.set_data(mx::allocator::malloc(out.nbytes()));
  
  // Get kernel name
  std::string kname = "axpby_general_" + mx::type_to_name(out);
  
  // Load metal library and kernel
  auto lib = d.get_library("mlx_ext");
  auto kernel = d.get_kernel(kname, lib);
  
  // Get command encoder
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);
  
  // Set kernel arguments
  int ndim = out.ndim();
  size_t nelem = out.size();
  
  compute_encoder.set_input_array(x, 0);
  compute_encoder.set_input_array(y, 1);
  compute_encoder.set_output_array(out, 2);
  compute_encoder.set_bytes(alpha_, 3);
  compute_encoder.set_bytes(beta_, 4);
  compute_encoder.set_vector_bytes(x.shape(), 5);
  compute_encoder.set_vector_bytes(x.strides(), 6);
  compute_encoder.set_vector_bytes(y.strides(), 7);
  compute_encoder.set_bytes(ndim, 8);
  
  // Launch kernel
  size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
  MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
  MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
  
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

Automatic Differentiation

Forward Mode (JVP)

std::vector<array> Axpby::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums
) {
  // If only one argument
  if (argnums.size() == 1) {
    auto scale = argnums[0] == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, tangents[0].dtype());
    return {mx::multiply(scale_arr, tangents[0], stream())};
  }
  
  // If both arguments
  return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
}

Reverse Mode (VJP)

std::vector<array> Axpby::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& /* unused */
) {
  std::vector<array> vjps;
  for (auto arg : argnums) {
    auto scale = arg == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, cotangents[0].dtype());
    vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
  }
  return vjps;
}

Building with CMake

Directory Structure

extensions/
├── axpby/
│   ├── axpby.cpp
│   ├── axpby.h
│   └── axpby.metal
├── mlx_sample_extensions/
│   └── __init__.py
├── bindings.cpp
├── CMakeLists.txt
├── requirements.txt
└── setup.py

CMakeLists.txt

# Add C++ library
add_library(mlx_ext)

# PRIVATE: compile axpby.cpp into mlx_ext only. PUBLIC would add it to
# INTERFACE_SOURCES, so every target linking mlx_ext (e.g. _ext below) would
# compile its own copy as well.
target_sources(
    mlx_ext
    PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)

target_include_directories(
    mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)

# Link to MLX
target_link_libraries(mlx_ext PUBLIC mlx)

# Build Metal library
if(MLX_BUILD_METAL)
  mlx_build_metallib(
      TARGET mlx_ext_metallib
      TITLE mlx_ext
      SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
      INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
      OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
  )

  add_dependencies(mlx_ext mlx_ext_metallib)
endif()

# Python bindings: build the `_ext` module that setup.py packages as
# mlx_sample_extensions._ext. Assumes nanobind has already been located
# earlier in this file (e.g. find_package(nanobind CONFIG REQUIRED)).
# NB_DOMAIN mlx lets the module share nanobind type registrations (such as
# mx.array) with MLX's own Python bindings.
nanobind_add_module(
    _ext
    NB_STATIC
    NB_DOMAIN mlx
    ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)

Python Bindings

Use nanobind to create Python bindings:
bindings.cpp
#include <nanobind/nanobind.h>
// Type caster for std::variant; the `stream` argument is an
// mx::StreamOrDevice, which MLX defines as a std::variant.
#include <nanobind/stl/variant.h>

// bindings.cpp sits at the extensions/ root, next to the axpby/ directory.
#include "axpby/axpby.h"

namespace nb = nanobind;
// Brings the "name"_a argument-annotation literal into scope.
using namespace nb::literals;

// Python module exposing the axpby operation as mlx_sample_extensions._ext.
NB_MODULE(_ext, m) {
  m.doc() = "Custom MLX extension";

  m.def(
      "axpby",
      &axpby,
      "x"_a,
      "y"_a,
      "alpha"_a,
      "beta"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      R"(
      Scale and sum two vectors element-wise.
      
      Computes: z = alpha * x + beta * y
      
      Args:
          x (array): Input array
          y (array): Input array  
          alpha (float): Scaling factor for x
          beta (float): Scaling factor for y
          stream (Stream, optional): Stream or device on which to schedule
              the operation. Defaults to the default stream.
      
      Returns:
          array: alpha * x + beta * y
      )"
  );
}

setup.py

setup.py
from mlx import extension
from setuptools import setup

if __name__ == "__main__":
    # The compiled module is built through MLX's CMake helpers; the shared
    # libraries and metallib produced by the build are shipped inside the
    # Python package via package_data.
    cmake_modules = [extension.CMakeExtension("mlx_sample_extensions._ext")]
    bundled_files = {
        "mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"],
    }

    setup(
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ extension for MLX",
        ext_modules=cmake_modules,
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_data=bundled_files,
        zip_safe=False,
        python_requires=">=3.8",
    )

Building the Extension

1

Install build dependencies

# Install the Python packages listed in requirements.txt
pip install -r requirements.txt
2

Build in place for development

# Compile the extension and place the built module in the source tree (--inplace), using up to 8 parallel jobs (-j8)
python setup.py build_ext -j8 --inplace
3

Or install the package

# Build and install the package from the current directory
pip install .

Usage

Now you can use your custom operation:
import mlx.core as mx
from mlx_sample_extensions import axpby

shape = (3, 4)
a = mx.ones(shape)
b = mx.ones(shape)

# Every element is 4.0 * 1 + 2.0 * 1
c = axpby(a, b, 4.0, 2.0)

print(f"Result: {c}")  # Should be all 6.0
print(f"Correct: {mx.all(c == 6.0).item()}")  # True

Performance

Custom primitives can significantly improve performance by fusing operations:
import time

import mlx.core as mx

# Custom primitive version
from mlx_sample_extensions import axpby


def simple_axpby(x, y, alpha, beta):
    """Reference axpby composed from MLX's built-in multiply and add."""
    return alpha * x + beta * y


def total_seconds(impl, x, y, warmup=5, iters=100):
    """Run `warmup` untimed calls, then return wall-clock seconds for `iters` evaluated calls."""
    for _ in range(warmup):
        mx.eval(impl(x, y, 4.0, 2.0))

    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(impl(x, y, 4.0, 2.0))
    return time.perf_counter() - start


M, N = 4096, 4096
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))

for impl in (simple_axpby, axpby):
    elapsed = total_seconds(impl, x, y)
    # seconds for 100 iterations -> milliseconds per iteration
    print(f"{impl.__name__}: {elapsed*10:.3f} ms")
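Expected speedup: roughly 2-3x for this fused operation, but the actual figure depends on your hardware, array sizes, and dtype, so run the benchmark on your own machine.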

Next Steps

Metal Kernels

Write custom Metal GPU kernels

C++ Operations

Browse the C++ API reference

Resources