How to save the frequency domain image after FFT operation? #2

Open
AbandonedWarlord opened this issue Jun 20, 2023 · 1 comment

Comments

@AbandonedWarlord

```python
import os

import numpy as np
import torch
from PIL import Image

# `image`, `output_folder` and `filename` are defined elsewhere (not shown)

# 2D FFT
x_freq = torch.fft.fft2(image)
# shift low frequency to the center
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
# mask a portion of frequencies
x_freq_masked = x_freq
# restore the original frequency order
x_freq_masked = torch.fft.ifftshift(x_freq_masked, dim=(-2, -1))
# 2D iFFT (only keep the real part)
x_corrupted = torch.fft.ifft2(x_freq_masked).real
x_corrupted = torch.clamp(x_corrupted, min=0., max=1.)
x_np = x_corrupted.numpy()

im = Image.fromarray((x_np * 255).astype(np.uint8))
im.save(os.path.join(output_folder, filename))
```

The image I save with the above code looks very different from the cat example. Could you please provide a demo?
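Why the snippet reproduces the input: `x_freq_masked = x_freq` leaves every coefficient unchanged, so `ifftshift` undoes `fftshift` and `ifft2` undoes `fft2`. Up to floating-point error, the saved file is the original image, not its spectrum. If the goal is to actually mask frequencies, here is a minimal sketch. It assumes a circular low-pass mask centred on the shifted spectrum (one possible choice, not necessarily what this project uses). The function name `low_pass_filter` and its `radius` parameter are hypothetical.

```python
import torch


def low_pass_filter(image: torch.Tensor, radius: int) -> torch.Tensor:
    """Zero every frequency farther than `radius` from the spectrum centre.

    image: real tensor of shape (..., H, W) with values in [0, 1].
    `low_pass_filter` and `radius` are hypothetical names for illustration.
    """
    h, w = image.shape[-2:]
    # 2D FFT, then shift the zero-frequency (DC) term to index (h // 2, w // 2)
    x_freq = torch.fft.fftshift(torch.fft.fft2(image), dim=(-2, -1))

    # Circular mask on the shifted grid: 1.0 keeps a frequency, 0.0 removes it
    yy, xx = torch.meshgrid(
        torch.arange(h, device=image.device),
        torch.arange(w, device=image.device),
        indexing="ij",
    )
    dist2 = (yy - h // 2) ** 2 + (xx - w // 2) ** 2
    mask = (dist2 <= radius ** 2).float()

    # Unlike `x_freq_masked = x_freq`, this actually changes the spectrum
    x_freq_masked = x_freq * mask
    # Restore the original frequency order, invert, keep the real part
    x_freq_masked = torch.fft.ifftshift(x_freq_masked, dim=(-2, -1))
    out = torch.fft.ifft2(x_freq_masked).real
    return torch.clamp(out, min=0.0, max=1.0)
```

With a radius larger than the distance from the centre to any corner, the mask keeps everything and the output matches the input. With `radius=0`, only the DC term survives and every pixel becomes the per-channel mean.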

@KeiChiTse

You can use the script below:

```python
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from PIL import Image


def fft(x):
    # x: Tensor, (..., H, W), values in 0-1
    # 2D FFT over the last two dims
    x_freq = torch.fft.fft2(x)
    # shift low frequency to the center
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
    return x_freq


def show_image(image, save_path):
    plt.imshow(image, cmap='viridis')
    plt.axis('off')
    plt.colorbar()

    fig = plt.gcf()
    fig.set_size_inches(image.shape[1] / 100, image.shape[0] / 100)
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
    plt.savefig(save_path, dpi=100, bbox_inches='tight', pad_inches=0)
    plt.close(fig)  # free the figure so repeated calls don't draw on top of each other


img = Image.open("path_to_image")
img = img.convert('L')  # convert to L channel
img = T.ToTensor()(img)  # (1, H, W), values in 0-1

fft_img = fft(img)
fft_img = torch.abs(fft_img)  # magnitude
fft_img = torch.log1p(fft_img)  # convert to log scale for better visualization

max_val, min_val = torch.max(fft_img), torch.min(fft_img)
fft_img = torch.div(fft_img - min_val, max_val - min_val)  # to 0-1

fft_img = fft_img.squeeze(0).numpy()  # (1, H, W) -> (H, W) so imshow applies the colormap
show_image(fft_img, "save_path")
```
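If a plain image file without the colorbar is enough, the spectrum can also be written directly with Pillow instead of going through matplotlib. The sketch below follows the same steps as the script (`fft2`, `fftshift`, magnitude, `log1p`, min-max scaling), then rescales to 0-255 and saves a grayscale PNG. `spectrum_to_uint8` and `save_spectrum` are hypothetical helper names, and dropping the viridis colormap is a simplification made here.

```python
import numpy as np
import torch
from PIL import Image


def spectrum_to_uint8(x: torch.Tensor) -> np.ndarray:
    """Log-magnitude spectrum of an (H, W) grayscale tensor as a uint8 array."""
    # Same steps as the script: FFT, centre the low frequencies, magnitude, log scale
    x_freq = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
    mag = torch.log1p(torch.abs(x_freq))
    # Min-max normalise to [0, 1]; the clamp guards against a constant spectrum
    mag = (mag - mag.min()) / (mag.max() - mag.min()).clamp(min=1e-12)
    return (mag * 255).round().to(torch.uint8).numpy()


def save_spectrum(x: torch.Tensor, save_path: str) -> None:
    # A 2D uint8 array is saved by Pillow as an 8-bit grayscale ("L") image
    Image.fromarray(spectrum_to_uint8(x)).save(save_path)
```

With `img` from the script (shape `(1, H, W)`), `save_spectrum(img.squeeze(0), "spectrum.png")` writes an H x W image. For a non-negative input the DC term has the largest magnitude, so the centre pixel is the brightest.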
