Skip to main content
You can extend MLX with custom operations on the CPU or GPU. This guide explains how to build custom C++ extensions with a complete example.

Overview

MLX extensions consist of two main components:
  1. Operations: Front-end functions that operate on arrays
  2. Primitives: Building blocks that define computation and transformations
Operations build the computation graph, while primitives provide the rules for evaluating and transforming that graph.

Example: Axpby Operation

We’ll implement a custom operation that computes z = alpha * x + beta * y, combining two scaled arrays.

Simple Implementation

First, let’s implement it using existing MLX operations:
#include "mlx/mlx.h"

namespace mx = mlx::core;

/**
 * @brief Compute `alpha * x + beta * y` from existing MLX operations.
 *
 * Every name from MLX is spelled through the `mx::` alias declared above;
 * unqualified `array` / `StreamOrDevice` would not resolve here.
 *
 * @param x     First input array.
 * @param y     Second input array.
 * @param alpha Scale applied to `x`.
 * @param beta  Scale applied to `y`.
 * @param s     Stream or device to schedule the operations on.
 * @return      An array holding `alpha * x + beta * y`.
 */
mx::array axpby(
    const mx::array& x,
    const mx::array& y,
    const float alpha,
    const float beta,
    mx::StreamOrDevice s = {}
) {
  // Scale x and y; mx::array(alpha) wraps the float in a scalar array.
  auto ax = mx::multiply(mx::array(alpha), x, s);
  auto by = mx::multiply(mx::array(beta), y, s);

  // Add and return
  return mx::add(ax, by, s);
}
This works, but it adds three separate operations (two multiplies and an add) to the graph, and each one materializes an intermediate array. For better performance, we’ll fuse them into a single custom primitive.

Creating a Primitive

Define the Primitive Class

A primitive inherits from Primitive and implements evaluation and transformation methods:
axpby.h
#pragma once

#include "mlx/mlx.h"

namespace mx = mlx::core;

/**
 * @brief Compute `alpha * x + beta * y` element-wise using the Axpby primitive.
 *
 * Declared here so that other translation units (for example the nanobind
 * bindings) can reference `axpby`. The `s = {}` default is given on the
 * definition in axpby.cpp (a later declaration may add a default argument,
 * but repeating it in both places is a redefinition error). Callers in other
 * translation units therefore pass `s` explicitly.
 *
 * @param x     First input array.
 * @param y     Second input array; broadcast against `x`.
 * @param alpha Scale applied to `x`.
 * @param beta  Scale applied to `y`.
 * @param s     Stream or device to schedule the operation on.
 * @return      An array holding `alpha * x + beta * y`.
 */
mx::array axpby(
    const mx::array& x,
    const mx::array& y,
    const float alpha,
    const float beta,
    mx::StreamOrDevice s);

/**
 * Primitive that evaluates `alpha * x + beta * y` in a single kernel and
 * supplies the differentiation and vectorization rules for it.
 * `alpha` and `beta` are stored on the primitive; the arrays are its inputs.
 */
class Axpby : public mx::Primitive {
 public:
  explicit Axpby(mx::Stream stream, float alpha, float beta)
      : Primitive(stream), alpha_(alpha), beta_(beta) {}

  // Evaluate on CPU
  void eval_cpu(
      const std::vector<mx::array>& inputs,
      std::vector<mx::array>& outputs) override;

  // Evaluate on GPU
  void eval_gpu(
      const std::vector<mx::array>& inputs,
      std::vector<mx::array>& outputs) override;

  // Jacobian-vector product (forward mode AD)
  std::vector<mx::array> jvp(
      const std::vector<mx::array>& primals,
      const std::vector<mx::array>& tangents,
      const std::vector<int>& argnums) override;

  // Vector-Jacobian product (reverse mode AD)
  std::vector<mx::array> vjp(
      const std::vector<mx::array>& primals,
      const std::vector<mx::array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<mx::array>& outputs) override;

  // Vectorization
  std::pair<std::vector<mx::array>, std::vector<int>> vmap(
      const std::vector<mx::array>& inputs,
      const std::vector<int>& axes) override;

  // Name for debugging
  const char* name() const override { return "Axpby"; }

  // Equivalence check
  bool is_equivalent(const Primitive& other) const override;

 private:
  float alpha_;  // scale applied to the first input (x)
  float beta_;   // scale applied to the second input (y)
};

Implement the Operation

The operation handles type promotion and broadcasting. It then returns an output array whose value is computed by an `Axpby` primitive:
axpby.cpp
array axpby(
    const array& x,
    const array& y,
    const float alpha,
    const float beta,
    StreamOrDevice s = {}
) {
  // Promote dtypes between x and y
  auto promoted_dtype = mx::promote_types(x.dtype(), y.dtype());
  
  // Upcast to float32 for non-floating point inputs
  auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
      ? promoted_dtype
      : mx::promote_types(promoted_dtype, mx::float32);
  
  // Cast x and y to output dtype
  auto x_casted = mx::astype(x, out_dtype, s);
  auto y_casted = mx::astype(y, out_dtype, s);
  
  // Broadcast shapes
  auto broadcasted_inputs = mx::broadcast_arrays({x_casted, y_casted}, s);
  auto out_shape = broadcasted_inputs[0].shape();
  
  // Construct output array with Axpby primitive
  return array(
      out_shape,
      out_dtype,
      std::make_shared<Axpby>(mx::to_stream(s), alpha, beta),
      broadcasted_inputs
  );
}

CPU Implementation

CPU Kernel

Implement the element-wise operation on CPU:
/**
 * CPU kernel for Axpby: writes alpha * x + beta * y into every element of
 * `out`. Each input is addressed through elem_to_loc with its own strides,
 * so inputs need not be contiguous.
 *
 * Allocates `out`, registers the three arrays with the stream's CPU command
 * encoder, and enqueues the element loop through encoder.dispatch.
 *
 * @tparam T Element type read from x and y and written to out.
 */
template <typename T>
void axpby_impl(
    const mx::array& x,
    const mx::array& y,
    mx::array& out,
    float alpha_,
    float beta_,
    mx::Stream stream
) {
  // Convert the float scales to the element type once, up front.
  const T alpha = static_cast<T>(alpha_);
  const T beta = static_cast<T>(beta_);

  // The output buffer must exist before its data pointer is captured below.
  out.set_data(mx::allocator::malloc(out.nbytes()));

  auto& encoder = mx::cpu::get_command_encoder(stream);
  encoder.set_input_array(x);
  encoder.set_input_array(y);
  encoder.set_output_array(out);

  // The task captures raw data pointers plus by-value copies of the shape and
  // strides, rather than the array objects themselves.
  encoder.dispatch([src_x = x.data<T>(),
                    src_y = y.data<T>(),
                    dst = out.data<T>(),
                    count = out.size(),
                    dims = out.shape(),
                    stride_x = x.strides(),
                    stride_y = y.strides(),
                    a = alpha,
                    b = beta]() {
    for (size_t i = 0; i < count; ++i) {
      const auto loc_x = mx::elem_to_loc(i, dims, stride_x);
      const auto loc_y = mx::elem_to_loc(i, dims, stride_y);
      dst[i] = a * src_x[loc_x] + b * src_y[loc_y];
    }
  });
}

CPU Evaluation

Dispatch to the `axpby_impl` instantiation that matches the output dtype:
/**
 * Evaluate Axpby on the CPU by selecting the axpby_impl instantiation that
 * matches the output dtype. axpby() casts both inputs to that dtype before
 * building the primitive, so a single element type serves x, y and out.
 *
 * @throws std::runtime_error if the output dtype is not float32, float16,
 *         bfloat16 or complex64.
 */
void Axpby::eval_cpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs
) {
  const mx::array& x = inputs[0];
  const mx::array& y = inputs[1];
  mx::array& out = outputs[0];
  const auto dtype = out.dtype();

  if (dtype == mx::float32) {
    return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
  }
  if (dtype == mx::float16) {
    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
  }
  if (dtype == mx::bfloat16) {
    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
  }
  if (dtype == mx::complex64) {
    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
  }

  throw std::runtime_error(
      "Axpby only supports floating point types.");
}

GPU Implementation

Metal Kernel

Write a Metal kernel for GPU execution:
axpby.metal
#include <metal_stdlib>

// MLX's Metal kernel utilities. This header is expected to provide
// elem_to_loc, the instantiate_kernel macro, and the float16_t / bfloat16_t /
// complex64_t element types used below. It is found through the
// ${MLX_INCLUDE_DIRS} entry passed to mlx_build_metallib in CMakeLists.txt.
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/**
 * out[index] = alpha * x + beta * y for one output element per thread.
 * `index` is the thread's linear position in the grid. x and y are
 * addressed through their own strides, so non-contiguous inputs are
 * supported. Buffer slots must match the set_* calls in Axpby::eval_gpu.
 */
template <typename T>
[[kernel]] void axpby_general(
    device const T* x [[buffer(0)]],
    device const T* y [[buffer(1)]],
    device T* out [[buffer(2)]],
    constant const float& alpha [[buffer(3)]],
    constant const float& beta [[buffer(4)]],
    constant const int* shape [[buffer(5)]],
    constant const int64_t* x_strides [[buffer(6)]],
    constant const int64_t* y_strides [[buffer(7)]],
    constant const int& ndim [[buffer(8)]],
    uint index [[thread_position_in_grid]]
) {
  // Convert linear index to offsets
  auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
  auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

  // Compute and write result
  out[index] = static_cast<T>(alpha) * x[x_offset] +
               static_cast<T>(beta) * y[y_offset];
}

// Instantiate for each type; names must match "axpby_general_" + dtype name
// as built in Axpby::eval_gpu.
instantiate_kernel("axpby_general_float32", axpby_general, float)
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)

GPU Evaluation

void Axpby::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs
) {
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];
  
  auto& s = stream();
  auto& d = mx::metal::device(s.device);
  
  // Allocate output
  out.set_data(mx::allocator::malloc(out.nbytes()));
  
  // Get kernel name
  std::string kname = "axpby_general_" + mx::type_to_name(out);
  
  // Load metal library and kernel
  auto lib = d.get_library("mlx_ext");
  auto kernel = d.get_kernel(kname, lib);
  
  // Get command encoder
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);
  
  // Set kernel arguments
  int ndim = out.ndim();
  size_t nelem = out.size();
  
  compute_encoder.set_input_array(x, 0);
  compute_encoder.set_input_array(y, 1);
  compute_encoder.set_output_array(out, 2);
  compute_encoder.set_bytes(alpha_, 3);
  compute_encoder.set_bytes(beta_, 4);
  compute_encoder.set_vector_bytes(x.shape(), 5);
  compute_encoder.set_vector_bytes(x.strides(), 6);
  compute_encoder.set_vector_bytes(y.strides(), 7);
  compute_encoder.set_bytes(ndim, 8);
  
  // Launch kernel
  size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
  MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
  MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
  
  compute_encoder.dispatch_threads(grid_dims, group_dims);
}

Automatic Differentiation

Forward Mode (JVP)

std::vector<array> Axpby::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums
) {
  // If only one argument
  if (argnums.size() == 1) {
    auto scale = argnums[0] == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, tangents[0].dtype());
    return {mx::multiply(scale_arr, tangents[0], stream())};
  }
  
  // If both arguments
  return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
}

Reverse Mode (VJP)

std::vector<array> Axpby::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& /* unused */
) {
  std::vector<array> vjps;
  for (auto arg : argnums) {
    auto scale = arg == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, cotangents[0].dtype());
    vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
  }
  return vjps;
}

Building with CMake

Directory Structure

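extensions/
├── axpby/
│   ├── axpby.cpp
│   ├── axpby.h
│   └── axpby.metal
├── mlx_sample_extensions/
│   └── __init__.py
├── bindings.cpp
├── CMakeLists.txt
├── requirements.txt
└── setup.py

The `mlx_sample_extensions/__init__.py` file re-exports the compiled function (for example `from ._ext import axpby`), so it can be imported as shown in the Usage section.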

CMakeLists.txt

cmake_minimum_required(VERSION 3.27)

project(_ext LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# mlx_ext is linked into a Python extension module, so it must be PIC.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

# ----------------------------- Dependencies -----------------------------
find_package(
  Python 3.8
  COMPONENTS Interpreter Development.Module
  REQUIRED)

# Ask the installed nanobind and mlx Python packages where their CMake
# config files live.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)

execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

# ----------------------------- C++ library -----------------------------
add_library(mlx_ext)

# PRIVATE: the sources belong to mlx_ext only. PUBLIC would also add
# axpby.cpp to every target that links mlx_ext (such as _ext below).
target_sources(
    mlx_ext
    PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)

target_include_directories(
    mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)

# Link to MLX
target_link_libraries(mlx_ext PUBLIC mlx)

# ----------------------------- Metal -----------------------------
if(MLX_BUILD_METAL)
  mlx_build_metallib(
      TARGET mlx_ext_metallib
      TITLE mlx_ext
      SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
      INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
      OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
  )

  add_dependencies(mlx_ext mlx_ext_metallib)
endif()

# ----------------------------- Python bindings -----------------------------
# Build bindings.cpp into the `_ext` module that setup.py packages as
# mlx_sample_extensions._ext. NB_DOMAIN mlx is intended to share nanobind
# type registrations with mlx's own bindings so mx.array arguments convert.
nanobind_add_module(
  _ext
  NB_STATIC
  STABLE_ABI
  LTO
  NOMINSIZE
  NB_DOMAIN
  mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE mlx_ext)

if(BUILD_SHARED_LIBS)
  # Let _ext locate the shared mlx_ext library placed next to it.
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()

Python Bindings

Use nanobind to create Python bindings:
bindings.cpp
#include <nanobind/nanobind.h>
// std::variant type casters, for the mx::StreamOrDevice `stream` argument.
// NOTE(review): assumes StreamOrDevice is a std::variant, as in current MLX.
#include <nanobind/stl/variant.h>

// Relative to the include directory added in CMakeLists.txt (the extensions/
// root); the header lives in the axpby/ subdirectory.
#include "axpby/axpby.h"

namespace nb = nanobind;
// Brings the "name"_a argument-annotation literal into scope.
using namespace nb::literals;

NB_MODULE(_ext, m) {
  m.doc() = "Custom MLX extension";

  m.def(
      "axpby",
      &axpby,
      "x"_a,
      "y"_a,
      "alpha"_a,
      "beta"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      R"(
      Scale and sum two vectors element-wise.
      
      Computes: z = alpha * x + beta * y
      
      Args:
          x (array): Input array
          y (array): Input array  
          alpha (float): Scaling factor for x
          beta (float): Scaling factor for y
      
      Returns:
          array: alpha * x + beta * y
      )"
  );
}

setup.py

setup.py
from mlx import extension
from setuptools import setup

if __name__ == "__main__":
    # Package metadata, kept separate from the build wiring below.
    metadata = dict(
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ extension for MLX",
        python_requires=">=3.8",
    )
    # A single CMake-built module, placed inside the Python package, plus the
    # compiled artifacts (shared libraries and the Metal library) it ships with.
    setup(
        **metadata,
        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        zip_safe=False,
    )

Building the Extension

1. Install the build dependencies:

pip install -r requirements.txt

2. Build the extension in place for development:

python setup.py build_ext -j8 --inplace

3. Or install the package:

pip install .

Usage

Now you can use your custom operation:
import mlx.core as mx
from mlx_sample_extensions import axpby

# Two arrays of ones: 4.0 * 1 + 2.0 * 1 == 6.0 in every position.
shape = (3, 4)
a = mx.ones(shape)
b = mx.ones(shape)
c = axpby(a, b, 4.0, 2.0)

print(f"Result: {c}")  # Should be all 6.0
print(f"Correct: {mx.all(c == 6.0).item()}")  # True

Performance

Custom primitives can significantly improve performance by fusing operations:
import time

import mlx.core as mx
from mlx_sample_extensions import axpby  # custom primitive version

M, N = 4096, 4096
WARMUP_ITERS = 5
TIMED_ITERS = 100


def simple_axpby(x, y, alpha, beta):
    """Reference version built from basic MLX ops (two multiplies and an add)."""
    return alpha * x + beta * y


x = mx.random.normal((M, N))
y = mx.random.normal((M, N))

# Benchmark
for impl in [simple_axpby, axpby]:
    # Warm up so one-time setup costs are excluded from the timing.
    for _ in range(WARMUP_ITERS):
        z = impl(x, y, 4.0, 2.0)
        mx.eval(z)

    # Time; mx.eval forces the lazily built graph to actually run.
    start = time.perf_counter()
    for _ in range(TIMED_ITERS):
        z = impl(x, y, 4.0, 2.0)
        mx.eval(z)
    elapsed = time.perf_counter() - start
    # Average wall time per call in milliseconds, derived from TIMED_ITERS
    # rather than a hard-coded factor.
    print(f"{impl.__name__}: {elapsed / TIMED_ITERS * 1e3:.3f} ms")
Expected speedup: roughly 2-3x for this fused operation. Exact numbers vary with hardware and array size.

Next Steps

- Metal Kernels: write custom Metal GPU kernels.
- C++ Operations: browse the C++ API reference.

Resources