ComplexBatchNorm1d Error #20

Open
Metamorphosis-chm opened this issue Apr 12, 2022 · 3 comments
@Metamorphosis-chm

File "D:/Pycharm/coplexcnn/train.py", line 124, in
y_hat = net(X)
File "C:\Users\MyPC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "D:/Pycharm/complexcnn/train.py", line 79, in forward
x = self.bn1(x)
File "C:\Users\MyPC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "D:\Pycharm\complexcnn\complexLayers.py", line 294, in forward
self.running_mean = exponential_average_factor * mean
RuntimeError: The size of tensor a (253) must match the size of tensor b (32) at non-singleton dimension 1

How to solve it? Thanks.

Metamorphosis-chm changed the title from ComplexBatchNorm1d Erro to ComplexBatchNorm1d Error on Apr 12, 2022
@Glen9010

Maybe you can try again; the code was updated 3 months ago.
Besides, there is still a memory leak in ComplexBatchNorm1d.
You can add "with torch.no_grad():" after code line 254.
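
To illustrate what that wrapping does (a minimal sketch of my own, not the actual complexLayers.py code): the running-statistics update should happen outside the autograd graph, otherwise the buffer ends up referencing tensors that keep each batch's graph alive.

import torch
from torch import nn


class RunningMeanSketch(nn.Module):
    """Minimal sketch of the pattern suggested above; not the library code."""

    def __init__(self, num_features, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        self.register_buffer("running_mean", torch.zeros(num_features))

    def forward(self, x):
        mean = x.mean(dim=0)  # depends on x, so it carries autograd history
        if self.training:
            # Wrapping the update in torch.no_grad() keeps the buffer free of
            # graph history; without it, memory grows with every batch.
            with torch.no_grad():
                self.running_mean[:] = (
                    self.momentum * mean + (1 - self.momentum) * self.running_mean
                )
        return x - mean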

@saugatkandel

In case someone is reading this, I had to change the batch norm code slightly to make it work properly. Here are my changes to get ComplexBatchNorm2d working (ComplexBatchNorm1d is similar). The formatting is a bit weird because I use Black as my default formatter.

import torch
from torch.nn import Module, Parameter, init


class _ComplexBatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(_ComplexBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features, 3))
            self.bias = Parameter(torch.Tensor(num_features, 2))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer("running_mean_r", torch.zeros(num_features))
            self.register_buffer("running_mean_i", torch.zeros(num_features))
            self.register_buffer("running_covar", torch.zeros(num_features, 3))
            self.running_covar[:, 0] = 1 / 1.4142135623730951
            self.running_covar[:, 1] = 1 / 1.4142135623730951
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean_r", None)
            self.register_parameter("running_mean_i", None)
            self.register_parameter("running_covar", None)
            self.register_parameter("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean_r.zero_()
            self.running_mean_i.zero_()
            self.running_covar.zero_()
            self.running_covar[:, :2] = 1 / 1.4142135623730951
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.constant_(self.weight[:, :2], 1 / 1.4142135623730951)
            init.zeros_(self.weight[:, 2])
            init.zeros_(self.bias)


class ComplexBatchNorm2d(_ComplexBatchNorm):
    def forward(self, inputs):
        exponential_average_factor = 0.0 if self.momentum is None else self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training or (not self.training and not self.track_running_stats):
            # calculate mean of real and imaginary part
            # mean does not support automatic differentiation for outputs with complex dtype.

            mean_r = inputs.real.mean([0, 2, 3])
            mean_i = inputs.imag.mean([0, 2, 3])
        else:
            mean_r = self.running_mean_r.clone()
            mean_i = self.running_mean_i.clone()

        if self.training and self.track_running_stats:
            # update running mean
            with torch.no_grad():

                self.running_mean_r[:] = (
                    exponential_average_factor * mean_r
                    + (1 - exponential_average_factor) * self.running_mean_r
                )
                self.running_mean_i[:] = (
                    exponential_average_factor * mean_i
                    + (1 - exponential_average_factor) * self.running_mean_i
                )

        inputs = inputs - (mean_r + 1j * mean_i)[None, :, None, None]

        if self.training or (not self.training and not self.track_running_stats):
            # Elements of the covariance matrix (biased for train)

            # n = input.numel() / input.size(1)
            Crr = inputs.real.pow(2).mean(dim=[0, 2, 3]) + self.eps
            Cii = inputs.imag.pow(2).mean(dim=[0, 2, 3]) + self.eps
            Cri = (inputs.real * inputs.imag).mean(dim=[0, 2, 3])
        else:
            Crr = self.running_covar[:, 0] + self.eps
            Cii = self.running_covar[:, 1] + self.eps
            Cri = self.running_covar[:, 2]  # +self.eps

        if self.training and self.track_running_stats:
            with torch.no_grad():
                self.running_covar[:, 0] = (
                    exponential_average_factor * Crr
                    + (1 - exponential_average_factor) * self.running_covar[:, 0]
                )

                self.running_covar[:, 1] = (
                    exponential_average_factor * Cii
                    + (1 - exponential_average_factor) * self.running_covar[:, 1]
                )

                self.running_covar[:, 2] = (
                    exponential_average_factor * Cri
                    + (1 - exponential_average_factor) * self.running_covar[:, 2]
                )

        # calculate the inverse square root the covariance matrix
        det = Crr * Cii - Cri.pow(2)

        s = torch.sqrt(det)
        t = torch.sqrt(Cii + Crr + 2 * s)
        inverse_st = 1.0 / (s * t)
        Rrr = (Cii + s) * inverse_st
        Rii = (Crr + s) * inverse_st
        Rri = -Cri * inverse_st

        inputs = (
            Rrr[None, :, None, None] * inputs.real
            + Rri[None, :, None, None] * inputs.imag
        ).type(torch.complex64) + 1j * (
            Rii[None, :, None, None] * inputs.imag
            + Rri[None, :, None, None] * inputs.real
        ).type(
            torch.complex64
        )

        if self.affine:
            inputs = (
                self.weight[None, :, 0, None, None] * inputs.real
                + self.weight[None, :, 2, None, None] * inputs.imag
                + self.bias[None, :, 0, None, None]
            ).type(torch.complex64) + 1j * (
                self.weight[None, :, 2, None, None] * inputs.real
                + self.weight[None, :, 1, None, None] * inputs.imag
                + self.bias[None, :, 1, None, None]
            ).type(
                torch.complex64
            )

        return inputs
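
As a quick sanity check (my own addition, not part of the comment above), the module can be exercised on a random complex batch:

# Hypothetical smoke test for the ComplexBatchNorm2d defined above.
bn = ComplexBatchNorm2d(8)
x = torch.randn(16, 8, 10, 10) + 1j * torch.randn(16, 8, 10, 10)  # complex64

bn.train()
y = bn(x)        # training pass updates running_mean_r/_i and running_covar
bn.eval()
y_eval = bn(x)   # eval pass uses the stored running statistics instead

print(y.shape, y.dtype)  # torch.Size([16, 8, 10, 10]) torch.complex64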

The exact changes are as follows:

  1. From my reading of the linked paper and the associated code (https://github.com/ChihebTrabelsi/deep_complex_networks), running_covar and the weight should be initialized to 1/sqrt(2), not sqrt(2).
  2. The running mean buffer registration and update. The buffer assignment was the hardest part to figure out: assigning
    self.running_mean_r = ...
    in the forward step does not work, but
    self.running_mean_r[:] = ...
    does. Something to do with the PyTorch internals, I guess (a small sketch of the two update styles follows below).
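
For what it is worth, the difference between the two assignments can be reproduced outside the batch norm code (a small sketch of my own, not from the repository): slicing updates the registered buffer in place, while plain attribute assignment replaces the stored tensor with whatever the right-hand side happens to be.

import torch
from torch import nn

m = nn.Module()
m.register_buffer("running_mean_r", torch.zeros(4))
buffer_before = m.running_mean_r

new_stats = torch.arange(4.0)

# In-place update: the registered buffer keeps its identity (and therefore its
# shape, dtype and device); this is the form that works in the forward pass.
with torch.no_grad():
    m.running_mean_r[:] = 0.1 * new_stats + 0.9 * m.running_mean_r
print(m.running_mean_r is buffer_before)  # True

# Plain assignment instead rebinds the buffer to a brand-new tensor each call,
# the variant reported above as not working inside forward().
m.running_mean_r = 0.1 * new_stats + 0.9 * m.running_mean_r
print(m.running_mean_r is buffer_before)  # False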

@karli262

Hello,
I had the same problem as you described. I noticed that ComplexBatchNorm2d, which is designed for 4D data (N, C, H, W), calculates the mean and the variance over three dimensions, e.g. mean_r = input.real.mean([0, 2, 3]), and applies those statistics with matching broadcasting, e.g. input = input - mean[None, :, None, None].

ComplexBatchNorm1d, however, only does this for 2D data, and my data was 3D (N, C, L), which caused the problem in my case. To give an example: I changed mean_r = input.real.mean(dim=0).type(torch.complex64) to mean_r = input.real.mean([0, 2]).type(torch.complex64), and when applying the mean I changed input = input - mean[None, ...] to input = input - mean[None, :, None].
By doing the same for the imaginary part and the covariance I obtained the normalized output for 3D data (a sketch of the resulting forward pass is below).
I hope this is helpful.
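
Putting those changes together, here is a sketch of how the forward pass could look for 3D input (N, C, L), reusing the _ComplexBatchNorm base class from the earlier comment. This is my adaptation of the 2d code above, not a tested patch.

class ComplexBatchNorm1d(_ComplexBatchNorm):
    # Same structure as ComplexBatchNorm2d above; only the reduction dims
    # ([0, 2] instead of [0, 2, 3]) and the broadcasting pattern
    # ([None, :, None] instead of [None, :, None, None]) change.
    def forward(self, inputs):
        f = 0.0 if self.momentum is None else self.momentum
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # cumulative moving average
                    f = 1.0 / float(self.num_batches_tracked)

        if self.training or not self.track_running_stats:
            mean_r = inputs.real.mean([0, 2])
            mean_i = inputs.imag.mean([0, 2])
        else:
            mean_r = self.running_mean_r.clone()
            mean_i = self.running_mean_i.clone()

        if self.training and self.track_running_stats:
            with torch.no_grad():
                self.running_mean_r[:] = f * mean_r + (1 - f) * self.running_mean_r
                self.running_mean_i[:] = f * mean_i + (1 - f) * self.running_mean_i

        inputs = inputs - (mean_r + 1j * mean_i)[None, :, None]

        if self.training or not self.track_running_stats:
            Crr = inputs.real.pow(2).mean(dim=[0, 2]) + self.eps
            Cii = inputs.imag.pow(2).mean(dim=[0, 2]) + self.eps
            Cri = (inputs.real * inputs.imag).mean(dim=[0, 2])
        else:
            Crr = self.running_covar[:, 0] + self.eps
            Cii = self.running_covar[:, 1] + self.eps
            Cri = self.running_covar[:, 2]

        if self.training and self.track_running_stats:
            with torch.no_grad():
                self.running_covar[:, 0] = f * Crr + (1 - f) * self.running_covar[:, 0]
                self.running_covar[:, 1] = f * Cii + (1 - f) * self.running_covar[:, 1]
                self.running_covar[:, 2] = f * Cri + (1 - f) * self.running_covar[:, 2]

        # Inverse square root of the 2x2 covariance matrix (whitening).
        det = Crr * Cii - Cri.pow(2)
        s = torch.sqrt(det)
        t = torch.sqrt(Cii + Crr + 2 * s)
        inverse_st = 1.0 / (s * t)
        Rrr = (Cii + s) * inverse_st
        Rii = (Crr + s) * inverse_st
        Rri = -Cri * inverse_st

        inputs = (
            Rrr[None, :, None] * inputs.real + Rri[None, :, None] * inputs.imag
        ).type(torch.complex64) + 1j * (
            Rii[None, :, None] * inputs.imag + Rri[None, :, None] * inputs.real
        ).type(torch.complex64)

        if self.affine:
            inputs = (
                self.weight[None, :, 0, None] * inputs.real
                + self.weight[None, :, 2, None] * inputs.imag
                + self.bias[None, :, 0, None]
            ).type(torch.complex64) + 1j * (
                self.weight[None, :, 2, None] * inputs.real
                + self.weight[None, :, 1, None] * inputs.imag
                + self.bias[None, :, 1, None]
            ).type(torch.complex64)

        return inputs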
