Skip to content

Commit

Permalink
fix memory leak with torch.no_grad()
Browse files Browse the repository at this point in the history
  • Loading branch information
wavefrontshaping committed Sep 17, 2019
1 parent 431f2d6 commit 6ed32b0
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions complexLayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,9 @@ def forward(self, input_r, input_i):
mean = torch.stack((mean_r,mean_i),dim=1)

# update running mean
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
with torch.no_grad():
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean

input_r = input_r-mean_r[None, :, None, None]
input_i = input_i-mean_i[None, :, None, None]
Expand All @@ -226,14 +227,15 @@ def forward(self, input_r, input_i):
Cii = 1./n*input_i.pow(2).sum(dim=[0,2,3])+self.eps
Cri = (input_r.mul(input_i)).mean(dim=[0,2,3])

self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,0]
with torch.no_grad():
self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,0]

self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,1]
self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,1]

self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,2]
self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,2]

else:
mean = self.running_mean
Expand Down Expand Up @@ -291,8 +293,9 @@ def forward(self, input_r, input_i):
mean = torch.stack((mean_r,mean_i),dim=1)

# update running mean
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
with torch.no_grad():
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean

# zero mean values
input_r = input_r-mean_r[None, :]
Expand All @@ -305,14 +308,15 @@ def forward(self, input_r, input_i):
Cii = input_i.var(dim=0,unbiased=False)+self.eps
Cri = (input_r.mul(input_i)).mean(dim=0)

self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,0]
with torch.no_grad():
self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,0]

self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,1]
self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,1]

self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,2]
self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_covar[:,2]

else:
mean = self.running_mean
Expand Down

0 comments on commit 6ed32b0

Please sign in to comment.