NumPy
Arrays can be converted to NumPy with np.array and back to MLX with mx.array.

Since NumPy does not support bfloat16 arrays, you will need to convert to float16 or float32 first. Otherwise, you will receive an error like:

Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.

Creating Array Views
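By default, NumPy copies data to a new array. This can be prevented by creating an array view, which shares memory with the original MLX array instead of copying it, so writes to the view show up in the MLX array.

NumPy arrays with type float64 are converted by default to MLX arrays with type float32.

Gradient Considerations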
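Writing to an array's memory through a view happens outside of MLX, so such changes cannot be reflected in gradients. Let’s demonstrate this with a function f that squares its input x in place through a NumPy view and then returns x.sum(). Because f modifies x indirectly through a memory view, the modification is not reflected in the gradient: the gradient comes out as 1.0, the gradient of the sum operation alone, rather than 2x. The squaring of x occurs outside of MLX, so no gradient is incorporated for it.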
It’s important to note that a similar issue arises during array conversion and copying. For instance, a function defined as mx.array(np.array(x)**2).sum() would also result in an incorrect gradient, even though no in-place operations on MLX memory are executed.
PyTorch
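PyTorch supports the buffer protocol, but it requires an explicit memoryview. Converting a PyTorch tensor back to an MLX array goes through an intermediate NumPy array obtained with numpy().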
JAX
JAX fully supports the buffer protocol.

TensorFlow
TensorFlow supports the buffer protocol, but it requires an explicit memoryview.