Overview
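The mlx.core.fft module (available as mx.fft after import mlx.core as mx) provides functions for computing discrete Fourier transforms. It supports 1D, 2D, and n-dimensional transforms for both complex and real-valued inputs, plus helpers for reordering the frequency components of a spectrum.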
fft
mx.fft.fft(
a: array,
n: int = None,
axis: int = -1,
stream: StreamOrDevice = None
) -> array
Compute the one-dimensional discrete Fourier Transform.
Parameters:
a (array): Input array; may be real or complex.
n (int, optional): Length of the transformed axis. If n is smaller than the input length along axis, the input is cropped; if larger, it is zero-padded. If None, the input length along axis is used.
axis (int, default -1): Axis over which to compute the FFT.
stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array containing the FFT
Example:
import mlx.core as mx
# FFT of a real signal
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.fft(x)
# With padding
X = mx.fft.fft(x, n=8)
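To make the definition concrete, the plain-Python sketch below computes the same quantity as mx.fft.fft on a 1D input, using the O(n^2) definition instead of a fast algorithm. The name naive_dft is hypothetical and not part of MLX; it assumes the standard unnormalized forward convention and is meant only for checking small results by hand, not as a description of how MLX computes the transform.

```python
import cmath


def naive_dft(x, n=None):
    """Reference O(n^2) DFT of a 1D sequence, mirroring the cropping and
    zero-padding behavior of the n parameter.

    X[k] = sum_t x[t] * exp(-2j * pi * k * t / n), with no normalization.
    """
    if n is None:
        n = len(x)
    # Crop to the first n samples, or zero-pad up to n samples.
    x = list(x[:n]) + [0.0] * max(0, n - len(x))
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


X = naive_dft([1.0, 2.0, 3.0, 4.0])        # approximately [10, -2+2j, -2, -2-2j]
X8 = naive_dft([1.0, 2.0, 3.0, 4.0], n=8)  # input zero-padded to 8 samples
```

The zero-frequency term X[0] is simply the sum of the samples, which is a quick sanity check for any FFT output.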
ifft
mx.fft.ifft(
a: array,
n: int = None,
axis: int = -1,
stream: StreamOrDevice = None
) -> array
Compute the one-dimensional inverse discrete Fourier Transform.
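Parameters:
a (array): Input array, typically complex-valued.
n (int, optional): Length of the transformed axis. The input is cropped or zero-padded to this length, as for fft. If None, the input length along axis is used.
axis (int, default -1): Axis over which to compute the inverse FFT.
stream (StreamOrDevice, default None): Stream or device to use.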
Returns: Complex-valued array containing the inverse FFT
Example:
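import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.fft(x)
x_reconstructed = mx.fft.ifft(X)  # complex; real part approximately equals x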
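The round trip works because of how the two transforms are normalized. Under the NumPy-style convention, which we assume MLX follows here, the forward transform is unnormalized and the inverse carries the 1/n factor. The plain-Python sketch below (hypothetical helpers naive_dft and naive_idft, not MLX functions) shows that ifft(fft(x)) returns x up to rounding, with complex values whose imaginary parts are near zero.

```python
import cmath


def naive_dft(x, sign=-1):
    """Unnormalized DFT; sign=-1 is the forward kernel, sign=+1 the inverse kernel."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def naive_idft(X):
    """Inverse DFT: x[t] = (1/n) * sum_k X[k] * exp(+2j * pi * k * t / n)."""
    n = len(X)
    return [v / n for v in naive_dft(X, sign=+1)]


x = [1.0, 2.0, 3.0, 4.0]
x_reconstructed = naive_idft(naive_dft(x))  # approximately 1, 2, 3, 4 (as complex numbers)
```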
rfft
mx.fft.rfft(
a: array,
n: int = None,
axis: int = -1,
stream: StreamOrDevice = None
) -> array
Compute the one-dimensional FFT of a real-valued input.
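This function is more efficient than fft for real inputs because it computes only the non-negative frequency components; for a real input, each negative-frequency component is the complex conjugate of a positive one and is therefore redundant. Along the transformed axis, the output contains n // 2 + 1 complex values.
Parameters:
a (array): Real-valued input array.
n (int, optional): Length of the transformed axis. The input is cropped or zero-padded to this length. If None, the input length along axis is used.
axis (int, default -1): Axis over which to compute the FFT.
stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array with n // 2 + 1 entries along axis; all other dimensions are unchanged.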
Example:
import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.rfft(x)  # Output shape: (3,)
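The redundancy that rfft exploits can be checked directly. The plain-Python sketch below (hypothetical helpers naive_dft and naive_rfft, not MLX code) verifies the Hermitian symmetry X[n - k] == conj(X[k]) for real inputs and shows that keeping the first n // 2 + 1 bins gives 3 values for both n = 4 and n = 5. It slices a full DFT, so it illustrates what rfft returns rather than how efficiently it is computed.

```python
import cmath


def naive_dft(x):
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]


def naive_rfft(x):
    """Keep only the n // 2 + 1 non-negative-frequency bins of a real input's DFT."""
    return naive_dft(x)[: len(x) // 2 + 1]


for x in ([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0, 5.0]):
    full = naive_dft(x)
    n = len(x)
    # Hermitian symmetry: bin n - k is the complex conjugate of bin k.
    assert all(abs(full[n - k] - full[k].conjugate()) < 1e-9 for k in range(1, n))
    print(n, len(naive_rfft(x)))  # prints "4 3", then "5 3"
```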
irfft
mx.fft.irfft(
a: array,
n: int = None,
axis: int = -1,
stream: StreamOrDevice = None
) -> array
Compute the inverse of rfft.
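Parameters:
a (array): Complex-valued input array, typically the output of rfft.
n (int, optional): Length of the output along axis. If None, it is computed as 2 * (m - 1), where m is the length of the input along the transformed axis. Because this default is always even, pass n explicitly to recover an odd-length signal.
axis (int, default -1): Axis over which to compute the inverse FFT.
stream (StreamOrDevice, default None): Stream or device to use.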
Returns: Real-valued array
Example:
import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.rfft(x)
x_reconstructed = mx.fft.irfft(X)  # real-valued, shape (4,)
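The default output length matters when the original signal had odd length. The plain-Python sketch below (hypothetical helpers naive_rfft and naive_irfft, not MLX functions) rebuilds the missing negative frequencies from Hermitian symmetry and applies an inverse DFT with the 1/n factor. It assumes the crop/pad-to-n // 2 + 1 behavior and the 2 * (m - 1) default described above.

```python
import cmath


def naive_rfft(x):
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n // 2 + 1)]


def naive_irfft(H, n=None):
    """Rebuild a real signal of length n from the half spectrum H."""
    if n is None:
        n = 2 * (len(H) - 1)  # default output length
    m = n // 2 + 1
    # Crop or zero-pad the half spectrum to n // 2 + 1 bins.
    H = list(H[:m]) + [0j] * max(0, m - len(H))
    # Restore negative frequencies via Hermitian symmetry: X[n - k] = conj(X[k]).
    full = H + [H[n - k].conjugate() for k in range(m, n)]
    return [
        (sum(full[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n).real
        for t in range(n)
    ]


x5 = [1.0, 2.0, 3.0, 4.0, 5.0]
print(len(naive_irfft(naive_rfft(x5))))       # 4, not 5: the default length is even
print(len(naive_irfft(naive_rfft(x5), n=5)))  # 5: pass n to recover odd lengths
```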
fft2
mx.fft.fft2(
a: array,
s: tuple = None,
axes: tuple = (-2, -1),
stream: StreamOrDevice = None
) -> array
Compute the two-dimensional discrete Fourier Transform.
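Parameters:
a (array): Input array; may be real or complex.
s (tuple, optional): Sizes of the output along the transformed axes; the input is cropped or zero-padded along each axis to match. If None, the input sizes along axes are used.
axes (tuple, default (-2, -1)): Axes over which to compute the FFT.
stream (StreamOrDevice, default None): Stream or device to use.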
Returns: Complex-valued array containing the 2D FFT
Example:
import mlx.core as mx
x = mx.random.uniform(shape=(8, 8))
X = mx.fft.fft2(x)
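A 2D DFT is separable: it equals 1D transforms along one axis followed by 1D transforms along the other. The plain-Python sketch below (hypothetical naive_fft2, not MLX code) shows this for the default axes=(-2, -1) on a small list-of-rows input.

```python
import cmath


def dft_1d(x):
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]


def naive_fft2(a):
    """2D DFT of a list of rows: 1D DFT along the last axis, then along the first."""
    rows = [dft_1d(row) for row in a]                 # transform each row
    cols = [dft_1d(list(col)) for col in zip(*rows)]  # transform each column
    return [list(r) for r in zip(*cols)]              # transpose back to rows


X = naive_fft2([[1.0, 2.0], [3.0, 4.0]])  # approximately [[10, -2], [-4, 0]]
```

Because each pass is an independent 1D transform, the order of the two passes does not change the result.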
ifft2
mx.fft.ifft2(
a: array,
s: tuple = None,
axes: tuple = (-2, -1),
stream: StreamOrDevice = None
) -> array
Compute the two-dimensional inverse discrete Fourier Transform.
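Parameters:
a (array): Input array, typically complex-valued.
s (tuple, optional): Sizes of the output along the transformed axes; the input is cropped or zero-padded to match. If None, the input sizes along axes are used.
axes (tuple, default (-2, -1)): Axes over which to compute the inverse FFT.
stream (StreamOrDevice, default None): Stream or device to use.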
Returns: Complex-valued array containing the 2D inverse FFT
Example:
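import mlx.core as mx
x = mx.random.uniform(shape=(8, 8))
X = mx.fft.fft2(x)
x_reconstructed = mx.fft.ifft2(X)  # complex; real part approximately equals x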
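The sketch below checks the 2D round trip in plain Python and also shows a standard identity: an inverse DFT can be computed with a forward transform as conj(fft2(conj(X))) / N, where N is the number of elements. The helpers dft_2d, naive_ifft2, and ifft2_via_conj are hypothetical, and the 1 / (rows * cols) scaling assumes the NumPy-style normalization convention.

```python
import cmath


def dft_2d(a, sign=-1):
    """Unnormalized 2D DFT of a list of rows; sign=+1 selects the inverse kernel."""
    def dft_1d(x):
        n = len(x)
        return [sum(x[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
                for k in range(n)]
    rows = [dft_1d(r) for r in a]
    cols = [dft_1d(list(c)) for c in zip(*rows)]
    return [list(r) for r in zip(*cols)]


def naive_ifft2(X):
    """Inverse 2D DFT, scaled by 1 / (rows * cols)."""
    size = len(X) * len(X[0])
    return [[v / size for v in row] for row in dft_2d(X, sign=+1)]


def ifft2_via_conj(X):
    """Equivalent inverse that reuses the forward transform: conj(fft2(conj(X))) / size."""
    size = len(X) * len(X[0])
    conj = [[v.conjugate() for v in row] for row in X]
    return [[v.conjugate() / size for v in row] for row in dft_2d(conj)]


x = [[1.0, 2.0], [3.0, 4.0]]
X = dft_2d(x)
x_reconstructed = naive_ifft2(X)  # approximately [[1, 2], [3, 4]], stored as complex
```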
rfft2
mx.fft.rfft2(
a: array,
s: tuple = None,
axes: tuple = (-2, -1),
stream: StreamOrDevice = None
) -> array
Compute the two-dimensional FFT of a real-valued input.
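Parameters:
a (array): Real-valued input array.
s (tuple, optional): Sizes of the transformed axes; the input is cropped or zero-padded to match. If None, the input sizes along axes are used.
axes (tuple, default (-2, -1)): Axes over which to compute the FFT.
stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array. The last axis in axes has s[-1] // 2 + 1 entries; the other transformed axis has the size given by s (which defaults to the input sizes).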
Example:
import mlx.core as mx
x = mx.random.uniform(shape=(8, 8))
X = mx.fft.rfft2(x)  # shape (8, 5)
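A 2D real FFT can be understood as a half-spectrum transform along the last axis followed by a full complex transform along the first. The plain-Python sketch below (hypothetical helpers, not MLX's implementation) computes it that way and confirms it matches the first n // 2 + 1 columns of a full 2D DFT. This holds because the column transforms act on each column independently, so dropping columns before or after them gives the same result.

```python
import cmath


def dft_1d(x, keep=None):
    """Unnormalized DFT; if keep is given, return only the first `keep` bins."""
    n = len(x)
    bins = range(n if keep is None else keep)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in bins]


def naive_rfft2(a):
    """Half spectrum along the last axis, then a full DFT along the first axis."""
    keep = len(a[0]) // 2 + 1
    rows = [dft_1d(r, keep=keep) for r in a]
    return [list(r) for r in zip(*[dft_1d(list(c)) for c in zip(*rows)])]


def naive_fft2(a):
    rows = [dft_1d(r) for r in a]
    return [list(r) for r in zip(*[dft_1d(list(c)) for c in zip(*rows)])]


x = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0], [3.0, 1.0, -2.0, 1.5]]
half, full = naive_rfft2(x), naive_fft2(x)
print(len(half), len(half[0]))  # 3 3: all 3 rows, 4 // 2 + 1 = 3 columns
# The half spectrum matches the first 3 columns of the full 2D DFT.
assert all(abs(half[i][j] - full[i][j]) < 1e-9 for i in range(3) for j in range(3))
```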
irfft2
mx.fft.irfft2(
a: array,
s: tuple = None,
axes: tuple = (-2, -1),
stream: StreamOrDevice = None
) -> array
Compute the inverse of rfft2.
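Parameters:
a (array): Complex-valued input array, typically the output of rfft2.
s (tuple, optional): Sizes of the output along the transformed axes. If None, the first transformed axis keeps its input size and the last axis in axes has output length 2 * (m - 1), where m is its input length; pass s explicitly when that axis of the original signal had odd length.
axes (tuple, default (-2, -1)): Axes over which to compute the inverse FFT.
stream (StreamOrDevice, default None): Stream or device to use.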
Returns: Real-valued array
Example:
import mlx.core as mx
x = mx.random.uniform(shape=(8, 8))
X = mx.fft.rfft2(x)
x_reconstructed = mx.fft.irfft2(X)  # real-valued, shape (8, 8)
fftn
mx.fft.fftn(
a: array,
s: tuple = None,
axes: tuple = None,
stream: StreamOrDevice = None
) -> array
Compute the n-dimensional discrete Fourier Transform.
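Parameters:
a (array): Input array; may be real or complex.
s (tuple, optional): Sizes of the output along the transformed axes; the input is cropped or zero-padded to match. If None, the input sizes along axes are used.
axes (tuple, optional): Axes over which to compute the FFT. If None, the FFT is computed over all axes.
stream (StreamOrDevice, default None): Stream or device to use.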
Returns: Complex-valued array containing the n-D FFT
Example:
import mlx.core as mx
x = mx.random.uniform(shape=(4, 4, 4))
X = mx.fft.fftn(x)
# Transform specific axes
X = mx.fft.fftn(x, axes=(0, 2))
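An n-dimensional DFT over a chosen set of axes is a sequence of 1D DFTs, one axis at a time, with untouched axes left alone. The plain-Python sketch below stores an array as a flat row-major list plus a shape and applies a 1D DFT along each requested axis. The helpers dft_along_axis and naive_fftn are hypothetical and intended only for small test inputs.

```python
import cmath
import itertools


def dft_1d(x):
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]


def dft_along_axis(data, shape, axis):
    """Apply a 1D DFT along `axis` of a row-major flat list with the given shape."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    out = list(data)
    # Visit every 1D line along `axis` by fixing that axis's index at 0.
    ranges = [range(1) if d == axis else range(s) for d, s in enumerate(shape)]
    for idx in itertools.product(*ranges):
        base = sum(i * st for i, st in zip(idx, strides))
        line = [data[base + k * strides[axis]] for k in range(shape[axis])]
        for k, v in enumerate(dft_1d(line)):
            out[base + k * strides[axis]] = v
    return out


def naive_fftn(data, shape, axes=None):
    """n-D DFT as successive 1D DFTs over `axes` (all axes when None)."""
    axes = range(len(shape)) if axes is None else axes
    for ax in axes:
        data = dft_along_axis(data, shape, ax % len(shape))
    return data


shape = (2, 2, 2)
x = [float(i) for i in range(8)]
X_all = naive_fftn(x, shape)             # X_all[0] equals the sum of all elements (28)
X_02 = naive_fftn(x, shape, axes=(0, 2))  # axis 1 is left untransformed
```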
ifftn
mx.fft.ifftn(
a: array,
s: tuple = None,
axes: tuple = None,
stream: StreamOrDevice = None
) -> array
Compute the n-dimensional inverse discrete Fourier Transform.
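Parameters:
a (array): Input array, typically complex-valued.
s (tuple, optional): Sizes of the output along the transformed axes; the input is cropped or zero-padded to match. If None, the input sizes along axes are used.
axes (tuple, optional): Axes over which to compute the inverse FFT. If None, the transform is computed over all axes.
stream (StreamOrDevice, default None): Stream or device to use.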
Returns: Complex-valued array containing the n-D inverse FFT
Example:
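import mlx.core as mx
x = mx.random.uniform(shape=(4, 4, 4))
X = mx.fft.fftn(x)
x_reconstructed = mx.fft.ifftn(X)  # complex; real part approximately equals x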
rfftn
mx.fft.rfftn(
a: array,
s: tuple = None,
axes: tuple = None,
stream: StreamOrDevice = None
) -> array
Compute the n-dimensional FFT of a real-valued input.
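Parameters:
a (array): Real-valued input array.
s (tuple, optional): Sizes of the transformed axes; the input is cropped or zero-padded to match. If None, the input sizes along axes are used.
axes (tuple, optional): Axes over which to compute the FFT. If None, the FFT is computed over all axes.
stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array. The last axis in axes has s[-1] // 2 + 1 entries; the other transformed axes have the sizes given by s (which defaults to the input sizes).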
Example:
import mlx.core as mx
x = mx.random.uniform(shape=(4, 4, 4))
X = mx.fft.rfftn(x)  # shape (4, 4, 3)
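When planning buffer sizes it can help to compute the output shape of an n-D real FFT up front. The helper below is a hypothetical pure-Python sketch, not an MLX function. It assumes, as in NumPy, that transformed axes take their sizes from s and that only the last axis listed in axes is reduced to s[-1] // 2 + 1 entries; the result therefore depends on the order of axes.

```python
def rfftn_output_shape(shape, s=None, axes=None):
    """Output shape of an n-D real FFT under the assumptions stated above."""
    ndim = len(shape)
    axes = list(range(ndim)) if axes is None else [a % ndim for a in axes]
    s = [shape[a] for a in axes] if s is None else list(s)
    if len(s) != len(axes):
        raise ValueError("s and axes must have the same length")
    out = list(shape)
    for a, size in zip(axes, s):
        out[a] = size
    out[axes[-1]] = s[-1] // 2 + 1  # only the last transformed axis is halved
    return tuple(out)


print(rfftn_output_shape((4, 4, 4)))            # (4, 4, 3)
print(rfftn_output_shape((8, 8), axes=(0, 1)))  # (8, 5)
print(rfftn_output_shape((8, 8), axes=(1, 0)))  # (5, 8)
```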
irfftn
mx.fft.irfftn(
a: array,
s: tuple = None,
axes: tuple = None,
stream: StreamOrDevice = None
) -> array
Compute the inverse of rfftn.
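Parameters:
a (array): Complex-valued input array, typically the output of rfftn.
s (tuple, optional): Sizes of the output along the transformed axes. If None, the other transformed axes keep their input sizes and the last axis in axes has output length 2 * (m - 1), where m is its input length; pass s explicitly when that axis of the original signal had odd length.
axes (tuple, optional): Axes over which to compute the inverse FFT. If None, the transform is computed over all axes.
stream (StreamOrDevice, default None): Stream or device to use.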
Returns: Real-valued array
Example:
import mlx.core as mx
x = mx.random.uniform(shape=(4, 4, 4))
X = mx.fft.rfftn(x)
x_reconstructed = mx.fft.irfftn(X)  # real-valued, shape (4, 4, 4)
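The default output shape of an n-D inverse real FFT can be sketched the same way. The helper below is hypothetical pure Python, not an MLX function. It assumes the n-D default mirrors the 1D irfft rule: transformed axes keep their input sizes except the last one in axes, which becomes 2 * (m - 1). It makes the odd-length pitfall visible: a (4, 4, 5) signal becomes (4, 4, 3) after rfftn and comes back as (4, 4, 4) unless s is given.

```python
def irfftn_output_shape(shape, s=None, axes=None):
    """Output shape of an n-D inverse real FFT under the assumptions stated above."""
    ndim = len(shape)
    axes = list(range(ndim)) if axes is None else [a % ndim for a in axes]
    if s is None:
        # Keep input sizes, except the last transformed axis becomes 2 * (m - 1).
        s = [shape[a] for a in axes]
        s[-1] = 2 * (s[-1] - 1)
    out = list(shape)
    for a, size in zip(axes, s):
        out[a] = size
    return tuple(out)


# Half spectrum of shape (4, 4, 3), produced from either (4, 4, 4) or (4, 4, 5).
print(irfftn_output_shape((4, 4, 3)))               # (4, 4, 4)
print(irfftn_output_shape((4, 4, 3), s=(4, 4, 5)))  # (4, 4, 5)
```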
Helper Functions
fftshift
mx.fft.fftshift(
a: array,
axes: tuple = None,
stream: StreamOrDevice = None
) -> array
Shift the zero-frequency component to the center of the spectrum.
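This function rotates each axis in axes by n // 2 positions, where n is that axis's length, moving the zero-frequency component from index 0 to the center (index n // 2).
Parameters:
a (array): Input array, typically the output of an FFT.
axes (tuple, optional): Axes over which to shift. If None, all axes are shifted.
stream (StreamOrDevice, default None): Stream or device to use.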
Returns: Shifted array
Example:
import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.fft(x)
X_centered = mx.fft.fftshift(X)  # zero-frequency term moved to index 2
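For a single axis, the shift is a rotation to the right by n // 2. The plain-Python sketch below (hypothetical fftshift_1d operating on a list, not MLX code) applies it to the frequency-index order of a DFT output, which runs 0, 1, ..., then the negative frequencies. After the shift, frequencies appear in increasing order with 0 at index n // 2.

```python
def fftshift_1d(x):
    """Rotate right by n // 2 so the zero-frequency bin lands at index n // 2."""
    n = len(x)
    k = n // 2
    return x[n - k:] + x[:n - k]


print(fftshift_1d([0, 1, -2, -1]))     # [-2, -1, 0, 1]
print(fftshift_1d([0, 1, 2, -2, -1]))  # [-2, -1, 0, 1, 2]
```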
ifftshift
mx.fft.ifftshift(
a: array,
axes: tuple = None,
stream: StreamOrDevice = None
) -> array
The inverse of fftshift.
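This function undoes the effect of fftshift, moving the zero-frequency component from the center back to index 0. For even lengths it gives the same result as fftshift; for odd lengths the two differ, so use ifftshift (not a second fftshift) to undo a shift.
Parameters:
a (array): Input array, typically a centered spectrum.
axes (tuple, optional): Axes over which to shift. If None, all axes are shifted.
stream (StreamOrDevice, default None): Stream or device to use.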
Returns: Shifted array
Example:
import mlx.core as mx
X = mx.fft.fft(mx.array([1.0, 2.0, 3.0, 4.0]))
X_centered = mx.fft.fftshift(X)
X_original = mx.fft.ifftshift(X_centered)  # equal to X
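The difference between the two shifts shows up only for odd lengths. In the plain-Python sketch below (hypothetical single-axis list helpers, not MLX code), fftshift rotates right by n // 2 and ifftshift rotates left by n // 2. For n = 5 these are different rotations, so applying fftshift twice does not restore the original order.

```python
def fftshift_1d(x):
    k = len(x) // 2
    return x[len(x) - k:] + x[:len(x) - k]  # rotate right by n // 2


def ifftshift_1d(x):
    k = len(x) // 2
    return x[k:] + x[:k]  # rotate left by n // 2


freqs = [0, 1, 2, -2, -1]        # frequency order of a length-5 DFT
centered = fftshift_1d(freqs)    # [-2, -1, 0, 1, 2]
print(ifftshift_1d(centered))    # [0, 1, 2, -2, -1]: original order restored
print(fftshift_1d(centered))     # [1, 2, -2, -1, 0]: a second fftshift does not undo it
```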