MLX has an API to export and import functions to and from a file. This lets you run computations written in one MLX front-end (e.g. Python) in another MLX front-end (e.g. C++). This guide walks through the basics of the MLX export API with some examples.

Basics of Exporting

Let’s start with a simple example:
import mlx.core as mx

def fun(x, y):
    return x + y

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("add.mlxfn", fun, x, y)
To export a function, provide sample input arrays that the function can be called with. The data doesn’t matter, but the shapes and types of the arrays do. In the above example, we exported fun with two float32 scalar arrays.

Importing and Running

We can then import the function and run it:
add_fun = mx.import_function("add.mlxfn")

out, = add_fun(mx.array(1.0), mx.array(2.0))
print(out)
# array(3, dtype=float32)

out, = add_fun(mx.array(1.0), mx.array(3.0))
print(out)
# array(4, dtype=float32)
The following calls raise exceptions because the type or shape of an input differs from the corresponding example input:
# Raises an exception - type mismatch
add_fun(mx.array(1), mx.array(3.0))

# Raises an exception - shape mismatch
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
Even though the original fun returns a single output array, the imported function always returns a tuple of one or more arrays.

Input Specifications

The inputs to mx.export_function() and to an imported function can be specified as variable positional arguments or as a tuple of arrays:
def fun(x, y):
    return x + y

x = mx.array(1.0)
y = mx.array(1.0)

# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)

# Same as above
mx.export_function("add.mlxfn", fun, (x, y))

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y)

# Also ok
out, = imported_fun((x, y))

Keyword Arguments

You can pass example inputs to functions as positional or keyword arguments. If you use keyword arguments to export the function, then you have to use the same keyword arguments when calling the imported function.
def fun(x, y):
    return x + y

# One argument to fun is positional, the other is a kwarg
mx.export_function("add.mlxfn", fun, x, y=y)

imported_fun = mx.import_function("add.mlxfn")

# Ok
out, = imported_fun(x, y=y)

# Also ok
out, = imported_fun((x,), {"y": y})

# Raises since the keyword argument is missing
out, = imported_fun(x, y)

# Raises since the keyword argument has the wrong key
out, = imported_fun(x, z=y)

Exporting Modules

An mlx.nn.Module can be exported with or without the parameters included in the exported function.

With Parameters

Here’s an example:
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x):
    return model(x)

mx.export_function("model.mlxfn", call, mx.zeros(4))
In the above example, the nn.Linear module is exported. Its parameters are also saved to the model.mlxfn file.
For enclosed arrays inside an exported function, be extra careful to ensure they are evaluated. The computation graph that gets exported will include the computation that produces enclosed inputs. If the above example were missing mx.eval(model.parameters()), the exported function would include the random initialization of the nn.Module parameters.

Without Parameters

If you only want to export the Module.__call__ function without the parameters, pass them as inputs to the call wrapper:
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

model = nn.Linear(4, 4)
mx.eval(model.parameters())

def call(x, **params):
    # Set the model's parameters to the input parameters
    model.update(tree_unflatten(list(params.items())))
    return model(x)

params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

Exporting with a Callback

To inspect the exported graph, you can pass a callback instead of a file path to mx.export_function().
def fun(x):
    return x.astype(mx.int32)

def callback(args):
    print(args)

mx.export_function(callback, fun, mx.array([1.0, 2.0]))
The argument to the callback (args) is a dictionary which includes a type field. The possible types are:
  • "inputs": The ordered positional inputs to the exported function
  • "keyword_inputs": The keyword specified inputs to the exported function
  • "outputs": The ordered outputs of the exported function
  • "constants": Any graph constants
  • "primitives": Inner graph nodes representing the operations
Each type has additional fields in the args dictionary.

Shapeless Exports

Just like mx.compile(), functions can also be exported for dynamically shaped inputs. Pass shapeless=True to mx.export_function() or mx.exporter() to export a function which can be used for inputs with variable shapes:
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")

# Ok
out, = imported_abs(mx.array([-1.0]))

# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With shapeless=False (which is the default), the second call to imported_abs would raise an exception with a shape mismatch.
Shapeless exporting works the same as shapeless compilation and should be used carefully. See the documentation on shapeless compilation for more information.

Exporting Multiple Traces

In some cases, functions build different computation graphs for different input arguments. A simple way to manage this is to export each set of inputs to a new file. This is a fine option in many cases, but it can be suboptimal if the exported functions share a large amount of constant data (for example, the parameters of an nn.Module). The export API in MLX lets you export multiple traces of the same function to a single file by creating an exporting context manager with mx.exporter():
def fun(x, y=None):
    constant = mx.array(3.0)
    if y is not None:
        x += y
    return x + constant

with mx.exporter("fun.mlxfn", fun) as exporter:
    exporter(mx.array(1.0))
    exporter(mx.array(1.0), y=mx.array(0.0))

imported_function = mx.import_function("fun.mlxfn")

# Call the function with y=None
out, = imported_function(mx.array(1.0))
print(out)

# Call the function with y specified
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
print(out)
In the above example, the function constant data (i.e. constant) is only saved once.

Transformations with Imported Functions

Function transformations like mx.grad(), mx.vmap(), and mx.compile() work on imported functions just like regular Python functions:
def fun(x):
    return mx.sin(x)

x = mx.array(0.0)
mx.export_function("sine.mlxfn", fun, x)

imported_fun = mx.import_function("sine.mlxfn")

# Take the derivative of the imported function
dfdx = mx.grad(lambda x: imported_fun(x)[0])
print(dfdx(x))
# array(1, dtype=float32)

# Compile the imported function
compiled_fun = mx.compile(imported_fun)
print(compiled_fun(x)[0])
# array(0, dtype=float32)

Importing Functions in C++

Importing and running functions in C++ works much the same way as in Python. First, follow the MLX instructions to set up a simple C++ project that uses MLX as a library. Next, export a simple function from Python:
def fun(x, y):
    return mx.exp(x + y)

x = mx.array(1.0)
y = mx.array(1.0)
mx.export_function("fun.mlxfn", fun, x, y)
Import and run the function in C++ with only a few lines of code:
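#include <iostream>
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto fun = mx::import_function("fun.mlxfn");
  // Two float32 scalars, matching the example inputs used for export
  auto inputs = {mx::array(1.0f), mx::array(1.0f)};
  auto outputs = fun(inputs);

  // Prints exp(1 + 1), approximately 7.389
  std::cout << outputs[0] << std::endl;
}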
Imported functions can be transformed in C++ just like in Python. Use std::vector<mx::array> for positional arguments and std::map<std::string, mx::array> for keyword arguments when calling imported functions in C++.

More Examples

Here are a few more complete examples exporting more complex functions from Python and importing and running them in C++: