-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmasking_model.py
63 lines (49 loc) · 2.32 KB
/
masking_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import models
import utils
from .models import register
from .masking_utils import encoder_wrapper
@register('masking-model')
class MaskingModel(nn.Module):
    """Encode a few-shot episode and mask pseudo-shot features per class.

    An encoder embeds support ("shot"), pseudo-shot and query inputs.
    For every class, the mean shot feature and the mean pseudo-shot feature
    are concatenated along the channel axis. A masking network turns that
    pair into a mask, which is multiplied onto every pseudo-shot feature of
    the class.
    """

    def __init__(self, encoder, encoder_args, masking, masking_args):
        """Build the encoder and the masking network.

        Args:
            encoder: registry name of the encoder, passed to ``models.make``.
            encoder_args: keyword arguments for the encoder.
            masking: registry name of the masking network, passed to
                ``models.make``.
            masking_args: keyword arguments for the masking network. Its
                ``'inplanes'`` entry is always set to
                ``2 * self.encoder.out_dim``, because forward() concatenates
                two encoder features along channels. The caller's mapping is
                not modified.
        """
        super().__init__()
        self.encoder = models.make(encoder, **encoder_args)
        self.encoder_name = encoder
        # Work on a copy: writing into the caller's dict would leak 'inplanes'
        # back into whatever (possibly shared) config mapping was passed in.
        masking_args = {**masking_args, 'inplanes': self.encoder.out_dim * 2}
        self.masking_model = models.make(masking, **masking_args)

    def forward(self, x_shot, x_pseudo, x_query):
        """Encode an episode batch and apply the per-class mask to pseudo-shots.

        Args:
            x_shot: num episodes x N classes x k shot(s) x 3 channels x
                84 pixels x 84 pixels.
            x_pseudo: num episodes x N classes x p shot(s) x 3 channels x
                84 pixels x 84 pixels.
            x_query: num episodes x Nq x 3 channels x 84 pixels x 84 pixels.

        Returns:
            Tuple ``(x_shot, x_pseudo, x_query)`` of encoded features, viewed
            as ``[ep, n_way, n_shot, *F]``, ``[ep, n_way, n_pseudo, *F]`` and
            ``[ep, n_query, *F]``, where ``F`` is the last three dims of the
            encoded shot features. ``x_pseudo`` has been multiplied by the
            mask predicted for its class.
        """
        x_shot, x_pseudo, x_query = encoder_wrapper(
            self.encoder_name, self.encoder, x_shot, x_pseudo, x_query)

        # NOTE(review): assumes encoder_wrapper keeps the episode layout, i.e.
        # returns [ep, n_way, n_shot|n_pseudo, C, H, W] and
        # [ep, n_query, C, H, W] - confirm against masking_utils.
        ep_per_batch = x_shot.shape[0]
        n_way = x_shot.shape[1]
        n_shot = x_shot.shape[2]
        n_pseudo = x_pseudo.shape[2]
        n_query = x_query.shape[1]

        # Per-class averages over the shot / pseudo-shot axis.
        a_shot = torch.mean(x_shot, dim=-4)
        a_pseudo = torch.mean(x_pseudo, dim=-4)
        # Channel-wise concatenation, e.g. [2, 5, 1280, 5, 5]; this is why
        # 'inplanes' is 2 * encoder.out_dim in __init__.
        total = torch.cat((a_shot, a_pseudo), dim=-3)
        batch_shape = total.shape[:2]
        feat_shape = total.shape[2:]
        # Flatten (episode, class) into one batch axis for the masking net,
        # e.g. [10, 1280, 5, 5].
        total = total.view(-1, *feat_shape)
        mask = self.masking_model(total)
        # Restore (episode, class) and add a singleton axis so one mask
        # broadcasts over all pseudo-shots of its class.
        mask = mask.view(*batch_shape, *mask.shape[1:]).unsqueeze(dim=2)
        x_pseudo = torch.mul(x_pseudo, mask)  # [ep, n_way, n_pseudo, 640, 5, 5]

        img_shape = x_shot.shape[-3:]
        x_shot = x_shot.view(ep_per_batch, n_way, n_shot, *img_shape)
        x_pseudo = x_pseudo.view(ep_per_batch, n_way, n_pseudo, *img_shape)
        x_query = x_query.view(ep_per_batch, n_query, *img_shape)
        return x_shot, x_pseudo, x_query