Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
LZHgrla committed Apr 14, 2023
1 parent fab082d commit 4dbb9b2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 184 deletions.
16 changes: 0 additions & 16 deletions cfg/dy-yolov7-w6-step1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,6 @@ head:
[-13, 1, Conv, [768, 3, 1]],
[-4, 1, Conv, [1024, 3, 1]],

# [-35, 1, Conv, [320, 3, 1]],
# [-48, 1, Conv, [640, 3, 1]],
# [-61, 1, Conv, [960, 3, 1]],
# [-74, 1, Conv, [1280, 3, 1]],

# [[-8, -7, -6, -5, -4, -3, -2, -1], 1, IAuxDetect, [nc, anchors]], # Detect(P3, P4, P5, P6)
# ]

[[-4, -3, -2, -1], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5, P6)
]

Expand Down Expand Up @@ -325,13 +317,5 @@ head2:
[-13, 1, Conv, [768, 3, 1]],
[-4, 1, Conv, [1024, 3, 1]],

# [-35, 1, Conv, [320, 3, 1]],
# [-48, 1, Conv, [640, 3, 1]],
# [-61, 1, Conv, [960, 3, 1]],
# [-74, 1, Conv, [1280, 3, 1]],

# [[-8, -7, -6, -5, -4, -3, -2, -1], 1, IAuxDetect, [nc, anchors]], # Detect(P3, P4, P5, P6)
# ]

[[-4, -3, -2, -1], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5, P6)
]
177 changes: 9 additions & 168 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,128 +135,6 @@ def convert(self, z):
return (box, score)


class IAuxDetect(nn.Module):
    """YOLOv7-style detection head with an auxiliary branch.

    Holds a "lead" set of 1x1 output convs (``self.m``, preceded by implicit
    knowledge modules ImplicitA/ImplicitM) and an "aux" set (``self.m2``) fed
    by the second half of the input channels. During training both branches'
    raw maps are returned; at inference only the lead branch is decoded into
    boxes.

    NOTE(review): ``ImplicitA``/``ImplicitM`` are defined elsewhere in this
    file/project — their exact semantics are not visible here.
    """

    # Class-level flags; ``stride`` is filled in during model build
    # (see Model.__init__), the rest toggle export-time output formats.
    stride = None  # strides computed during build
    export = False  # onnx export
    end2end = False
    include_nms = False
    concat = False

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        """Build the head.

        Args:
            nc: number of classes.
            anchors: per-layer anchor list; ``len(anchors)`` detection layers,
                each with ``len(anchors[0]) // 2`` (w, h) anchor pairs.
            ch: input channel counts — first ``nl`` entries feed the lead
                branch, the remaining ``nl`` feed the aux branch.
        """
        super(IAuxDetect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor (box4 + obj1 + nc)
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid (lazily rebuilt per feature-map size)
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[:self.nl])  # lead output conv
        self.m2 = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[self.nl:])  # aux output conv

        # Implicit knowledge modules are applied only on the lead branch.
        self.ia = nn.ModuleList(ImplicitA(x) for x in ch[:self.nl])
        self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch[:self.nl])

    def forward(self, x):
        """Run both branches; decode lead-branch predictions when not training.

        Args:
            x: list of 2*nl feature maps (lead maps first, then aux maps).
               Mutated in place: entries are reshaped to (bs, na, ny, nx, no).

        Returns:
            Training (or export): the mutated list ``x`` of raw maps.
            Inference: ``(cat(z, 1), x[:nl])`` — decoded boxes plus lead maps.
        """
        # x = x.copy() # for profiling
        z = []  # inference output
        self.training |= self.export  # export forces raw (training-style) outputs
        for i in range(self.nl):
            x[i] = self.m[i](self.ia[i](x[i]))  # conv
            x[i] = self.im[i](x[i])
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            # Aux branch: same reshape, no implicit modules.
            # NOTE(review): reuses bs/ny/nx from the lead map — assumes aux
            # map i has the same spatial size; confirm against the config.
            x[i+self.nl] = self.m2[i](x[i+self.nl])
            x[i+self.nl] = x[i+self.nl].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                # Rebuild the cached grid only when the feature-map size changed.
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                if not torch.onnx.is_in_onnx_export():
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:
                    # ONNX path: algebraically equivalent decode without
                    # in-place slice assignment (export-friendly).
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
                    xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5))  # new xy
                    wh = wh ** 2 * (4 * self.anchor_grid[i].data)  # new wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x[:self.nl])

    def fuseforward(self, x):
        """Forward pass after fuse(): lead branch only, implicit modules folded
        into the convs. Output format is selected by the export flags.

        Args:
            x: list of feature maps; mutated in place like forward().
        """
        # x = x.copy() # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv (ia/im already fused into weights/bias)
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                if not torch.onnx.is_in_onnx_export():
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        # Select output shape based on export mode flags.
        if self.training:
            out = x
        elif self.end2end:
            out = torch.cat(z, 1)
        elif self.include_nms:
            z = self.convert(z)
            out = (z, )
        elif self.concat:
            out = torch.cat(z, 1)
        else:
            out = (torch.cat(z, 1), x)

        return out

    def fuse(self):
        """Fold ImplicitA (additive) and ImplicitM (multiplicative) modules
        into the lead convs' weights/biases, so fuseforward() can skip them.

        NOTE(review): mutates nn.Parameter tensors in place — presumably the
        caller wraps this in torch.no_grad(); verify at the call site.
        """
        print("IAuxDetect.fuse")
        # fuse ImplicitA and Convolution: bias += W @ implicit_a
        for i in range(len(self.m)):
            c1,c2,_,_ = self.m[i].weight.shape
            c1_,c2_, _,_ = self.ia[i].implicit.shape
            self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1)

        # fuse ImplicitM and Convolution: scale bias and weight per channel
        for i in range(len(self.m)):
            c1,c2, _,_ = self.im[i].implicit.shape
            self.m[i].bias *= self.im[i].implicit.reshape(c2)
            self.m[i].weight *= self.im[i].implicit.transpose(0,1)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return an (1, 1, ny, nx, 2) grid of (x, y) cell offsets.

        NOTE(review): meshgrid is called without indexing=; relies on the
        default 'ij' behavior (deprecation warning on newer torch).
        """
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

    def convert(self, z):
        """Convert decoded predictions to (box, score) pairs for NMS export.

        Boxes are mapped from (cx, cy, w, h) to (x1, y1, x2, y2) via a fixed
        matrix; class scores are multiplied by objectness in place.
        """
        z = torch.cat(z, 1)
        box = z[:, :, :4]
        conf = z[:, :, 4:5]
        score = z[:, :, 5:]
        score *= conf  # class score = cls_prob * objectness
        convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                      dtype=torch.float32,
                                      device=z.device)
        box @= convert_matrix  # xywh -> xyxy
        return (box, score)


class Model(nn.Module):
def __init__(self, cfg, ch=3, nc=None): # model, input channels, number of classes
super(Model, self).__init__()
Expand Down Expand Up @@ -303,19 +181,6 @@ def __init__(self, cfg, ch=3, nc=None): # model, input channels, number of clas
self.stride = m.stride
self._initialize_biases() # only run once
# print('Strides: %s' % m.stride.tolist())
if isinstance(m, IAuxDetect):
s = 256 # 2x min stride
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))[:4]]) # forward
#print(m.stride)
check_anchor_order(m)
m.anchors /= m.stride.view(-1, 1, 1)
m2.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))[:4]]) # forward
#print(m2.stride)
check_anchor_order(m2)
m2.anchors /= m2.stride.view(-1, 1, 1)
self.stride = m.stride
self._initialize_aux_biases() # only run once
# print('Strides: %s' % m.stride.tolist())
# Init weights, biases
initialize_weights(self)
# self.initialize_cblinear()
Expand Down Expand Up @@ -353,7 +218,7 @@ def forward_once(self, x, profile=False):
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers

if profile:
c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
c = isinstance(m, IDetect)
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
for _ in range(10):
m(x.copy() if c else x)
Expand Down Expand Up @@ -387,7 +252,7 @@ def forward_once(self, x, profile=False):
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers

if profile:
c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
c = isinstance(m, IDetect)
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
for _ in range(10):
m(x.copy() if c else x)
Expand All @@ -407,7 +272,7 @@ def forward_once(self, x, profile=False):
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers

if profile:
c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
c = isinstance(m, IDetect)
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
for _ in range(10):
m(x.copy() if c else x)
Expand All @@ -433,7 +298,7 @@ def forward_once(self, x, profile=False):
x = y[cur_f] if isinstance(cur_f, int) else [x if j == -1 else y[j] for j in cur_f] # from earlier layers

if profile:
c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
c = isinstance(m, IDetect)
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
for _ in range(10):
m(x.copy() if c else x)
Expand Down Expand Up @@ -477,30 +342,6 @@ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is
b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

def _initialize_aux_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
    """Initialize objectness/class biases of both aux detect heads.

    Applies the RetinaNet-style prior (https://arxiv.org/abs/1708.02002,
    section 3.3) to the lead (``m.m``) and aux (``m.m2``) conv biases of the
    last module of ``self.model_h`` and ``self.model_h2``.

    Args:
        cf: optional per-class frequency tensor; when None a uniform
            0.6/(nc-0.99) prior is used for the class logits.
    """
    # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
    m = self.model_h[-1]  # Detect() module of head 1
    for mi, mi2, s in zip(m.m, m.m2, m.stride):  # lead conv, aux conv, stride per layer
        b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
        b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
        b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
        mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        b2 = mi2.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
        b2.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
        b2.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
        mi2.bias = torch.nn.Parameter(b2.view(-1), requires_grad=True)
    # Same initialization for the second head's detect module.
    m = self.model_h2[-1]  # Detect() module of head 2
    for mi, mi2, s in zip(m.m, m.m2, m.stride):  # from
        b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
        b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
        b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
        mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        b2 = mi2.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
        b2.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
        b2.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
        mi2.bias = torch.nn.Parameter(b2.view(-1), requires_grad=True)

def _print_biases(self):
m = self.model_h[-1] # Detect() module
for mi in m.m: # from
Expand All @@ -527,7 +368,7 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.fuseforward # update forward
elif isinstance(m, (IDetect, IAuxDetect)):
elif isinstance(m, IDetect):
m.fuse()
m.forward = m.fuseforward
self.info()
Expand Down Expand Up @@ -599,7 +440,7 @@ def parse_model(d, ch_b): # model_dict, input_channels(3)
c2 = sum([ch_b[x] for x in f])
elif m is Shortcut:
c2 = ch_b[f[0]]
elif m in [IDetect, IAuxDetect]:
elif m is IDetect:
args.append([ch_b[x] for x in f])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
Expand Down Expand Up @@ -661,7 +502,7 @@ def parse_model(d, ch_b): # model_dict, input_channels(3)
elif m is Shortcut:
assert len(chs) == 1
c2 = chs[0][f[0]]
elif m in [IDetect, IAuxDetect]:
elif m is IDetect:
args.append([ch[x] for x, ch in zip(f, chs)])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
Expand Down Expand Up @@ -730,7 +571,7 @@ def parse_model(d, ch_b): # model_dict, input_channels(3)
elif m is Shortcut:
assert len(chs) == 1
c2 = chs[0][f[0]]
elif m in [IDetect, IAuxDetect]:
elif m is IDetect:
args.append([ch[x] for x, ch in zip(f, chs)])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
Expand Down Expand Up @@ -792,7 +633,7 @@ def parse_model(d, ch_b): # model_dict, input_channels(3)
elif m is Shortcut:
assert len(chs) == 1
c2 = chs[0][f[0]]
elif m in [IDetect, IAuxDetect]:
elif m is IDetect:
args.append([ch[x] for x, ch in zip(f, chs)])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
Expand Down

0 comments on commit 4dbb9b2

Please sign in to comment.