First, let’s implement `axpby`, which computes `alpha * x + beta * y`, using existing MLX operations:
#include "mlx/mlx.h"namespace mx = mlx::core;array axpby( const array& x, const array& y, const float alpha, const float beta, StreamOrDevice s = {}) { // Scale x and y auto ax = mx::multiply(array(alpha), x, s); auto by = mx::multiply(array(beta), y, s); // Add and return return mx::add(ax, by, s);}
This works, but it adds three separate operations (two multiplies and an add) to the graph. For better performance, we’ll create a custom primitive that computes the whole expression as a single operation.
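The new `axpby` handles type promotion and broadcasting: it promotes the dtypes of `x` and `y`, broadcasts them to a common shape, and returns a single array produced by the `Axpby` primitive: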
axpby.cpp
array axpby( const array& x, const array& y, const float alpha, const float beta, StreamOrDevice s = {}) { // Promote dtypes between x and y auto promoted_dtype = mx::promote_types(x.dtype(), y.dtype()); // Upcast to float32 for non-floating point inputs auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32) ? promoted_dtype : mx::promote_types(promoted_dtype, mx::float32); // Cast x and y to output dtype auto x_casted = mx::astype(x, out_dtype, s); auto y_casted = mx::astype(y, out_dtype, s); // Broadcast shapes auto broadcasted_inputs = mx::broadcast_arrays({x_casted, y_casted}, s); auto out_shape = broadcasted_inputs[0].shape(); // Construct output array with Axpby primitive return array( out_shape, out_dtype, std::make_shared<Axpby>(mx::to_stream(s), alpha, beta), broadcasted_inputs );}
# ---- C++ extension library ------------------------------------------------
add_library(mlx_ext)

target_sources(
  mlx_ext
  PUBLIC
    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)

# Put this directory on the include path of mlx_ext and (PUBLIC) of anything
# that links against it.
target_include_directories(
  mlx_ext
  PUBLIC
    ${CMAKE_CURRENT_LIST_DIR}
)

# The extension builds on the core MLX library.
target_link_libraries(
  mlx_ext
  PUBLIC
    mlx
)

# ---- Metal kernels (only when MLX_BUILD_METAL is set) ----------------------
if(MLX_BUILD_METAL)
  mlx_build_metallib(
    TARGET           mlx_ext_metallib
    TITLE            mlx_ext
    SOURCES          ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
    INCLUDE_DIRS     ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
    OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
  )

  # Ensure the metallib target is built before mlx_ext.
  add_dependencies(
    mlx_ext
    mlx_ext_metallib
  )
endif()
#include <nanobind/nanobind.h>
// Type caster for std::variant. Presumably required because the `stream`
// parameter of axpby is MLX's StreamOrDevice variant, which must accept the
// nb::none() default; confirm against axpby.h.
#include <nanobind/stl/variant.h>

#include "axpby.h"

namespace nb = nanobind;
// Brings the `_a` argument-name literal into scope; the "x"_a, "y"_a, ...
// annotations below do not compile without it.
using namespace nb::literals;

// Python module `_ext` exposing the C++ axpby function. `stream` is
// keyword-only and defaults to None.
NB_MODULE(_ext, m) {
  m.doc() = "Custom MLX extension";

  m.def(
      "axpby",
      &axpby,
      "x"_a,
      "y"_a,
      "alpha"_a,
      "beta"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      R"(
        Scale and sum two vectors element-wise.

        Computes: z = alpha * x + beta * y

        Args:
            x (array): Input array
            y (array): Input array
            alpha (float): Scaling factor for x
            beta (float): Scaling factor for y

        Returns:
            array: alpha * x + beta * y
      )");
}
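Custom primitives can improve performance by fusing operations. The following script compares both implementations on 4096×4096 inputs. For each one it does 5 warmup runs and then 100 timed runs, and prints the average time per call: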
import mlx.core as mx
import time


def simple_axpby(x, y, alpha, beta):
    """Baseline axpby built from ordinary MLX arithmetic: alpha * x + beta * y."""
    return alpha * x + beta * y


# Custom-primitive implementation from the compiled extension.
from mlx_sample_extensions import axpby

# Both implementations run on the same random M x N inputs.
M, N = 4096, 4096
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))

for impl in (simple_axpby, axpby):
    # Warm-up runs so one-time setup costs are not timed. mx.eval forces the
    # computation to actually run.
    for _ in range(5):
        z = impl(x, y, 4.0, 2.0)
        mx.eval(z)

    # Timed runs.
    start = time.perf_counter()
    for _ in range(100):
        z = impl(x, y, 4.0, 2.0)
        mx.eval(z)
    elapsed = time.perf_counter() - start

    # elapsed is in seconds for 100 calls, so elapsed * 10 is ms per call.
    print(f"{impl.__name__}: {elapsed*10:.3f} ms")