Skip to content

Commit

Permalink
script: make more dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafael Stahl committed Aug 26, 2020
1 parent e04e6b9 commit ad1b146
Showing 1 changed file with 43 additions and 27 deletions.
70 changes: 43 additions & 27 deletions scripts/find_breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,38 @@
from ortools.linear_solver import pywraplp

# Number of nodes.
N = 10

# Test model
#S = [100*100*3, 100*100*20, 50*50*20, 50*50*50, 25*25*50, 25*25*100, 12*12*100, 12*12*500]
#W = [3*3*20*3, 0, 3*3*50*20, 0, 3*3*100*50, 0, 3*3*500*100]
#CONV = [1, 1, 1, 1, 1, 1, 1]

# YOLOv2
S = [608*608*3, 608*608*32, 304*304*32, 304*304*64, 152*152*64, 152*152*128, 152*152*64, 152*152*128, 76*76*128, 76*76*256, 76*76*128, 76*76*256, 38*38*256, 38*38*512, 38*38*256, 38*38*512, 38*38*256, 38*38*512, 19*19*512, 19*19*1024, 19*19*512, 19*19*1024, 19*19*512, 19*19*1024, 19*19*1024, 19*19*1024, 38*38*512, 38*38*64, 19*19*256, 19*19*1280, 19*19*1024, 19*19*425, 0]
W = [3*3*32*3, 0, 3*3*64*32, 0, 3*3*128*64, 1*1*64*128, 3*3*128*64, 0, 3*3*256*128, 1*1*128*256, 3*3*256*128, 0, 3*3*512*256, 1*1*256*512, 3*3*512*256, 1*1*256*512, 3*3*512*256, 0, 3*3*1024*512, 1*1*512*1024, 3*3*1024*512, 1*1*512*1024, 3*3*1024*512, 3*3*1024*1024, 3*3*1024*1024, 0, 1*1*64*512, 0, 0, 3*3*1024*1280, 1*1*425*1024, 0]
CONV = [1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0]

# VGG-16
#S = [0, 224*224*3, 224*224*64, 224*224*64, 112*112*64, 112*112*128, 112*112*128, 56*56*128, 56*56*256, 56*56*256, 56*56*256, 28*28*256, 28*28*512, 28*28*512, 28*28*512, 14*14*512, 14*14*512, 14*14*512, 14*14*512, 7*7*512]
#W = [0, 3*3*64*3, 3*3*64*64, 0, 3*3*128*64, 3*3*128*128, 0, 3*3*256*128, 3*3*256*256, 3*3*256*256, 0, 3*3*512*256, 3*3*512*512, 3*3*512*512, 0, 3*3*512*512, 3*3*512*512, 3*3*512*512, 0]
#CONV = [0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0]

# Darknet Extraction
#S = [224*224*3, 112*112*64, 56*56*64, 56*56*192, 28*28*192, 28*28*128, 28*28*256, 28*28*256, 28*28*512, 14*14*512, 14*14*256, 14*14*512, 14*14*256, 14*14*512, 14*14*256, 14*14*512, 14*14*256, 14*14*512, 14*14*512, 14*14*1024, 7*7*1024, 7*7*512, 7*7*1024, 7*7*512, 7*7*1024, 7*7*1000]
#W = [7*7*64*3, 0, 3*3*192*64, 0, 1*1*128*192, 3*3*256*128, 1*1*256*256, 3*3*512*256, 0, 1*1*256*512, 3*3*512*256, 1*1*256*512, 3*3*512*256, 1*1*256*512, 3*3*512*256, 1*1*256*512, 3*3*512*256, 1*1*512*512, 3*3*1024*512, 0, 1*1*512*1024, 3*3*1024*512, 1*1*512*1024, 3*3*1024*512, 1*1*1000*1024]
#CONV = [1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1]

# AlexNet
#S = [227*227*3, 55*55*96, 27*27*96, 27*27*256, 13*13*256, 13*13*384, 13*13*384, 13*13*256, 6*6*256]
#W = [11*11*96*3, 0, 5*5*256*96, 0, 3*3*384*256, 3*3*384*384, 3*3*256*384, 0]
#CONV = [1, 0, 1, 0, 1, 1, 1, 0]
N = 1
# Model to use.
model = "yolo" # test, yolo, vgg, extract, alex

if model == "test":
# Test model
S = [100*100*3, 100*100*20, 50*50*20, 50*50*50, 25*25*50, 25*25*100, 12*12*100, 12*12*500]
W = [3*3*20*3, 0, 3*3*50*20, 0, 3*3*100*50, 0, 3*3*500*100]
CONV = [1, 1, 1, 1, 1, 1, 1]
elif model == "yolo":
# YOLOv2
S = [608*608*3, 608*608*32, 304*304*32, 304*304*64, 152*152*64, 152*152*128, 152*152*64, 152*152*128, 76*76*128, 76*76*256, 76*76*128, 76*76*256, 38*38*256, 38*38*512, 38*38*256, 38*38*512, 38*38*256, 38*38*512, 19*19*512, 19*19*1024, 19*19*512, 19*19*1024, 19*19*512, 19*19*1024, 19*19*1024, 19*19*1024, 38*38*512, 38*38*64, 19*19*256, 19*19*1280, 19*19*1024, 19*19*425, 0]
W = [3*3*32*3, 0, 3*3*64*32, 0, 3*3*128*64, 1*1*64*128, 3*3*128*64, 0, 3*3*256*128, 1*1*128*256, 3*3*256*128, 0, 3*3*512*256, 1*1*256*512, 3*3*512*256, 1*1*256*512, 3*3*512*256, 0, 3*3*1024*512, 1*1*512*1024, 3*3*1024*512, 1*1*512*1024, 3*3*1024*512, 3*3*1024*1024, 3*3*1024*1024, 0, 1*1*64*512, 0, 0, 3*3*1024*1280, 1*1*425*1024, 0]
CONV = [1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0]
elif model == "vgg":
# VGG-16
S = [0, 224*224*3, 224*224*64, 224*224*64, 112*112*64, 112*112*128, 112*112*128, 56*56*128, 56*56*256, 56*56*256, 56*56*256, 28*28*256, 28*28*512, 28*28*512, 28*28*512, 14*14*512, 14*14*512, 14*14*512, 14*14*512, 7*7*512]
W = [0, 3*3*64*3, 3*3*64*64, 0, 3*3*128*64, 3*3*128*128, 0, 3*3*256*128, 3*3*256*256, 3*3*256*256, 0, 3*3*512*256, 3*3*512*512, 3*3*512*512, 0, 3*3*512*512, 3*3*512*512, 3*3*512*512, 0]
CONV = [0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0]
elif model == "extract":
# Darknet Extraction
S = [224*224*3, 112*112*64, 56*56*64, 56*56*192, 28*28*192, 28*28*128, 28*28*256, 28*28*256, 28*28*512, 14*14*512, 14*14*256, 14*14*512, 14*14*256, 14*14*512, 14*14*256, 14*14*512, 14*14*256, 14*14*512, 14*14*512, 14*14*1024, 7*7*1024, 7*7*512, 7*7*1024, 7*7*512, 7*7*1024, 7*7*1000]
W = [7*7*64*3, 0, 3*3*192*64, 0, 1*1*128*192, 3*3*256*128, 1*1*256*256, 3*3*512*256, 0, 1*1*256*512, 3*3*512*256, 1*1*256*512, 3*3*512*256, 1*1*256*512, 3*3*512*256, 1*1*256*512, 3*3*512*256, 1*1*512*512, 3*3*1024*512, 0, 1*1*512*1024, 3*3*1024*512, 1*1*512*1024, 3*3*1024*512, 1*1*1000*1024]
CONV = [1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1]
elif model == "alex":
# AlexNet
S = [227*227*3, 55*55*96, 27*27*96, 27*27*256, 13*13*256, 13*13*384, 13*13*384, 13*13*256, 6*6*256]
W = [11*11*96*3, 0, 5*5*256*96, 0, 3*3*384*256, 3*3*384*384, 3*3*256*384, 0]
CONV = [1, 0, 1, 0, 1, 1, 1, 0]
else:
print("unknown model")
exit(1)

genId = 0

Expand Down Expand Up @@ -93,6 +99,7 @@ def solve():
a = []
b = []
for l in range(0, L):
print("W[", l, "]: ", W[l])
a.append(solver.IntVar(0, solver.Infinity(), 'a' + str(l)))
b.append(solver.BoolVar('b' + str(l)))

Expand All @@ -111,6 +118,8 @@ def solve():
ct1.SetCoefficient(b[l], 1)
ct1.SetCoefficient(b[l-1], -1)

#forceTo(solver, b[12], 1)

# max(a[l]) + sum(b[l]W[l]/N + (1-b[l])W[l])
obj = solver.Objective()
obj.SetMinimization()
Expand All @@ -136,9 +145,16 @@ def solve():
max_a = 0
w_tail = 0
for l in range(0, L):
if not CONV[l]:
continue
a_i = M[l] + K[l]
if b[l].solution_value() == 0:
a_i /= N
w_tail += W[l] / N
else:
w_tail += W[l]
max_a = a_i if a_i > max_a else max_a
w_tail += W[l]
print("l:", l, "i/o:", (M[l]+K[l])/1000, "w:", W[l]/1000)
print("max_a =", max_a)
print("w_tail =", w_tail)
print("total (single device) =", max_a + w_tail)
Expand Down

0 comments on commit ad1b146

Please sign in to comment.