
[BUG] FFT fails on certain array lengths #1800

Open
cdcapano opened this issue Jan 27, 2025 · 3 comments
cdcapano commented Jan 27, 2025

Describe the bug
Doing an FFT on array lengths 2^(21) and 2^(22) results in a kernel failure, but larger array sizes work.

To Reproduce

A simple script to reproduce:

import mlx.core as mx
size = int(2**21)
x = mx.ones(size)
mx.eval(mx.fft.fft(x, stream=mx.gpu))

This will result in the following error:

Terminating due to uncaught exception: [metal::Device] Unable to load function four_step_mem_8192_float2_float2_0_false
Function four_step_mem_8192_float2_float2_0_false was not found in the library

Abort trap: 6

A similar thing happens for an array that is 2**22 long. However, the code succeeds for arrays that have length 2**23, 2**24, 2**25, etc., up to 2**28. (I don't have enough memory to test beyond that.) By "succeed" I mean the function runs without failure. I haven't checked that the output is actually correct.
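For what it's worth, a quick way to spot-check the output for the sizes that do run would be to compare against NumPy. A rough sketch (untested here, assuming float32 input and that NumPy is available):

import mlx.core as mx
import numpy as np

size = int(2**23)  # one of the sizes that runs without error
x = mx.ones(size)
gpu_out = mx.fft.fft(x, stream=mx.gpu)
mx.eval(gpu_out)
ref = np.fft.fft(np.ones(size, dtype=np.float32))
# Print the largest absolute deviation from NumPy's result; it should be
# small relative to the size of the transform if the output is correct.
print(np.max(np.abs(np.array(gpu_out) - ref)))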

Expected behavior
The FFT should work for 2**21 and 2**22 if larger array sizes work. At the very least, the error should be caught appropriately with a more graceful exit.

Desktop (please complete the following information):

  • OS version: macOS 15.1.1
  • MLX version: 0.22.0

Additional context
Digging into the code a bit, I can see why it's failing. For a size of 2**21, plan.n1 here will get set to 2048. Later on, that will cause threadgroup_mem_size to get set to 8192 here. However, I don't know why that doesn't cause the assert at line 641 to raise an error.

I see the comment at line 640 that says // FFTs up to 2^20 are currently supported, so I'm not sure why the 2^23 FFTs are running at all. Even if the assert worked properly, why the limit of 2^20? In the research application we're trying to use this for, we will be evaluating arrays of length 2^21 - 2^25, so it would be ideal if these array sizes could be handled.

@awni awni added the bug Something isn't working label Jan 27, 2025
awni (Member) commented Jan 27, 2025

Indeed looks like a bug.

hriverg commented Feb 16, 2025

why the assert doesn't work
The assert only works in debug builds, not release builds.
Building with this command makes the assert work:
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace --debug

why the 2^21 error happens
In this case the algorithm needs the FFT kernel 'four_step_mem_8192_float2_float2_0_false',
but the library cannot find it because it is never instantiated.
Adding instantiate_ffts(8192) here would instantiate it.

But as explained here, this kernel won't fit into 32KB of threadgroup memory, which is the limit on most Mac devices' Metal GPUs.
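Rough arithmetic, assuming threadgroup_mem_size counts complex float2 elements (I'm inferring that from the kernel name rather than from the code):

elems = 8192                    # from four_step_mem_8192_float2_float2_0_false
bytes_needed = elems * 8        # two 32-bit floats per element
print(bytes_needed, 32 * 1024)  # 65536 vs 32768 -- roughly 2x over the limit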

but why 2^23 works
The FFT algorithm breaks a large FFT into smaller FFTs recursively.
For 2^23, it breaks into (128 x 64) x 1024.
For 2^21, it breaks into 2048 x 1024.
The algorithm here chooses not to break 2048 down further since it is not larger than MAX_STOCKHAM_FFT_SIZE, which is set to 4096.
That is also why 2^22 = 4096 x 1024 fails.
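A tiny standalone sketch of that split behavior (not the actual MLX planner code, just a reconstruction of the sizes listed above, assuming the length is first factored as n1 x 1024 and n1 is only split again when it exceeds MAX_STOCKHAM_FFT_SIZE):

import math

MAX_STOCKHAM_FFT_SIZE = 4096

def describe_split(n, base=1024):
    n1 = n // base
    if n1 > MAX_STOCKHAM_FFT_SIZE:
        # split n1 into two roughly equal power-of-two factors
        h = int(math.log2(n1))
        return f"({2**(h - h // 2)} x {2**(h // 2)}) x {base}"
    return f"{n1} x {base}"

for p in (21, 22, 23):
    print(f"2**{p}: {describe_split(2**p)}")
# 2**21: 2048 x 1024        -- single 2048-point stage, too big for threadgroup memory
# 2**22: 4096 x 1024        -- single 4096-point stage, also too big
# 2**23: (128 x 64) x 1024  -- split again, so each stage fits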

barronalex (Collaborator) commented:
I will fix this properly soon, but in the meantime here are two options:

  1. Run on CPU
  2. Run a four step FFT with MLX ops (this is roughly what I'm going to implement in the C++ backend):
import math
import mlx.core as mx

def four_step_fft(x, axis: int = 0):
    n = x.shape[axis]
    assert n & (n - 1) == 0, "Only supports powers of two"
    # Split n into two power-of-two factors n1 * n2 == n
    log_n = math.log2(n) / 2
    n1, n2 = 2**(math.ceil(log_n)), 2**(math.floor(log_n))
    orig_shape = x.shape
    shape = x.shape[:axis] + (n1, n2) + x.shape[axis+1:]
    x = x.reshape(shape)
    # Twiddle factors exp(-2j * pi * k1 * i2 / n)
    ij = mx.arange(n1)[:, mx.newaxis] * mx.arange(n2)[mx.newaxis]
    twiddles = mx.exp(mx.array(-2j * mx.pi * ij / n))
    # FFT over the n1 axis, apply twiddles, transpose, then FFT over the n2 axis
    step_one = mx.fft.fft(x, axis=axis) * twiddles
    step_two = mx.fft.fft(mx.swapaxes(step_one, axis, axis + 1), axis=axis)
    return step_two.reshape(orig_shape)

The problem at the moment is that the strided four step FFT implementation runs out of threadgroup memory when the constituent FFTs are larger than 1024 (hence the 1024*1024 = 2**20 limit).
I'll implement a nested four step FFT as above to fix this.
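For example, a quick sketch of using the workaround above on one of the failing sizes and checking it against the CPU path (which doesn't hit the Metal kernel limit); this reuses the four_step_fft definition and imports from the previous block:

import numpy as np

x = mx.ones(int(2**21))
out = four_step_fft(x)
ref = mx.fft.fft(x, stream=mx.cpu)
mx.eval(out, ref)
# Largest absolute difference between the workaround and the CPU FFT;
# it should be small relative to the transform size.
print(np.max(np.abs(np.array(out) - np.array(ref))))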
