type_of_windows in the pseudo code #48

Open
Aquila96 opened this issue Dec 6, 2023 · 2 comments

Aquila96 commented Dec 6, 2023

I am confused about how the pseudocode defines type_of_windows.

Given that the input to the attention has shape:

input_shape = (B * num_windows[0] * num_windows[1] * num_windows[2], window_size[0] * window_size[1] * window_size[2])

and type_of_windows defined in the pseudocode:

self.type_of_windows = (input_shape[0] // window_size[0]) * (input_shape[1] // window_size[1])

Its size would be:

B * num_windows[0] * num_windows[1] * num_windows[2] * window_size[2]

which contradicts the size prescribed in the paper, Mlat*Mpl, which presumably should be:

B * num_windows[0] * num_windows[1] 

It would not work in actual code unless type_of_windows is set to:

B * num_windows[0] * num_windows[1] * num_windows[2]

which, in turn, is not how position_index is designed to work.
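The arithmetic above can be checked with a few lines of Python. This is only a sketch: the values of B, num_windows and window_size are illustrative assumptions, not taken from the pseudocode or the paper.

```python
# Illustrative sizes only (assumptions, not values from the pseudocode).
B = 1
num_windows = (2, 3, 4)
window_size = (2, 6, 12)

# Shape of the attention input as described above.
input_shape = (
    B * num_windows[0] * num_windows[1] * num_windows[2],
    window_size[0] * window_size[1] * window_size[2],
)

# The pseudocode formula, applied to that shape.
from_pseudocode = (input_shape[0] // window_size[0]) * (input_shape[1] // window_size[1])

# The closed form derived above.
derived = B * num_windows[0] * num_windows[1] * num_windows[2] * window_size[2]

# The size expected under the Mlat*Mpl reading of the paper.
expected = B * num_windows[0] * num_windows[1]

print(from_pseudocode, derived, expected)  # 288 288 6
```

With these sizes the pseudocode formula and the closed form agree (288), and both differ from the expected 6.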

Any help is appreciated.

@Aquila96 Aquila96 changed the title type type_of_windows in the pseudo code Dec 6, 2023
Aquila96 (Author) commented

After some digging in an ONNX file, I've figured out why this is. For others wondering, Pangu's window_partition returns a 4-D tensor instead of the standard 3-D tensor used in other Swin transformers:

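# Split each spatial axis into (num_windows, window_size).
x_window = x.reshape(x.shape[0],
                     self.num_windows[0],
                     self.window_size[0],
                     self.num_windows[1],
                     self.window_size[1],
                     self.num_windows[2],
                     self.window_size[2],
                     x.shape[-1])
# Reorder to (B, num_windows[2], num_windows[0], num_windows[1],
#             window_size[0], window_size[1], window_size[2], C).
x_window = x_window.permute(0, 5, 1, 3, 2, 4, 6, 7).contiguous()
# Collapse to 4-D: (-1, type_of_windows, tokens per window, C).
x_window = x_window.reshape(-1,
                            self.type_of_windows,
                            self.window_size[0] * self.window_size[1] * self.window_size[2],
                            x_window.shape[-1])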

Similarly, you would have to modify reverse_window_partition to match, as well as the permute order applied after attn @ v.
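For anyone who wants to check the shapes, here is a minimal NumPy sketch of the partition above together with one possible inverse. It assumes x is laid out as (B, Z, H, W, C) and that type_of_windows equals num_windows[0] * num_windows[1]. The function signatures and the sizes are my own illustration, not the released implementation.

```python
import numpy as np

def window_partition(x, num_windows, window_size, type_of_windows):
    """(B, Z, H, W, C) -> (B * nW2, type_of_windows, ws0 * ws1 * ws2, C)."""
    b, c = x.shape[0], x.shape[-1]
    x = x.reshape(b,
                  num_windows[0], window_size[0],
                  num_windows[1], window_size[1],
                  num_windows[2], window_size[2],
                  c)
    # Same axis order as the permute call above.
    x = x.transpose(0, 5, 1, 3, 2, 4, 6, 7)
    tokens = window_size[0] * window_size[1] * window_size[2]
    return x.reshape(-1, type_of_windows, tokens, c)

def reverse_window_partition(windows, num_windows, window_size):
    """Inverse of window_partition: back to (B, Z, H, W, C)."""
    c = windows.shape[-1]
    x = windows.reshape(-1,
                        num_windows[2], num_windows[0], num_windows[1],
                        window_size[0], window_size[1], window_size[2],
                        c)
    # Inverse permutation of (0, 5, 1, 3, 2, 4, 6, 7).
    x = x.transpose(0, 2, 4, 3, 5, 1, 6, 7)
    return x.reshape(-1,
                     num_windows[0] * window_size[0],
                     num_windows[1] * window_size[1],
                     num_windows[2] * window_size[2],
                     c)

# Illustrative sizes (assumptions): B=2, C=5.
num_windows = (2, 3, 4)
window_size = (2, 2, 2)
type_of_windows = num_windows[0] * num_windows[1]

x = np.arange(2 * 4 * 6 * 8 * 5, dtype=np.float32).reshape(2, 4, 6, 8, 5)
windows = window_partition(x, num_windows, window_size, type_of_windows)
restored = reverse_window_partition(windows, num_windows, window_size)

print(windows.shape)                 # (8, 6, 8, 5)
print(np.array_equal(restored, x))   # True
```

The leading axis folds the batch together with num_windows[2], so each of the type_of_windows entries on axis 1 keeps its own slot.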

198808xc (Owner) commented

Hi, thanks for the clarification. As far as I can see, the understanding is correct. Just to say, the window_partition function naturally returns a 4D tensor because we are dealing with 3D data while the standard Swin transformer is dealing with 2D images.
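For contrast, a sketch of the commonly used 2D Swin-style partition, which yields a 3-D tensor. This is a NumPy re-expression written for illustration, with a hypothetical function name and made-up sizes. It is not copied from any particular repository.

```python
import numpy as np

def window_partition_2d(x, window_size):
    """(B, H, W, C) -> (B * num_windows, window_size * window_size, C)."""
    b, h, w, c = x.shape
    x = x.reshape(b,
                  h // window_size, window_size,
                  w // window_size, window_size,
                  c)
    # Bring the two window-count axes together, then the two in-window axes.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, window_size * window_size, c)

# Illustrative sizes (assumptions): B=2, H=W=8, C=3, window_size=4.
images = np.zeros((2, 8, 8, 3), dtype=np.float32)
windows_2d = window_partition_2d(images, window_size=4)

print(windows_2d.shape)  # (8, 16, 3)
```

Here all windows are folded into the leading axis, so there is no separate window-type axis as in the 4-D case.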
