# prune.py — benchmark the effect of global L1 unstructured pruning on a
# Baseline model's conv/deconv layers.
# (GitHub page chrome and the copied line-number gutter were removed from
# this scrape; the code below is the actual file content.)
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torchvision.models import densenet201
from baseline import *
from modules import *
import time
def printing(model):
    """Print the weight sparsity (percentage of zeroed entries) of each
    conv/deconv layer of `model`.

    Expects `model` to expose attributes `conv1`..`conv4`, each wrapping an
    inner `.conv` module, and `deconv1`..`deconv4`, each wrapping an inner
    `.deconv` module, where the inner module has a `.weight` tensor.

    Output format is identical to the original eight copy-pasted print
    blocks, e.g. "Sparsity in conv1.weight: 20.00%".
    """
    # Build (label, inner-module) pairs instead of repeating the same
    # three-line print stanza eight times.
    layers = [("conv{}".format(i), getattr(model, "conv{}".format(i)).conv)
              for i in range(1, 5)]
    layers += [("deconv{}".format(i), getattr(model, "deconv{}".format(i)).deconv)
               for i in range(1, 5)]
    for name, module in layers:
        w = module.weight
        # Sparsity = fraction of exactly-zero entries (pruning writes zeros).
        sparsity = 100. * float(torch.sum(w == 0)) / float(w.nelement())
        print("Sparsity in {}.weight: {:.2f}%".format(name, sparsity))
# Fixed full-resolution RGB inputs (1 x 3 x 1920 x 2160) used for the timing
# runs below: `r` for the unpruned baseline, `r0` for the pruned passes.
# Fall back to CPU when CUDA is absent instead of crashing on a hard-coded
# 'cuda' transfer (matching the device fallback the script computes later).
_bench_device = "cuda" if torch.cuda.is_available() else "cpu"
r = torch.rand(1, 3, 1920, 2160).to(_bench_device)
r0 = torch.rand(1, 3, 1920, 2160).to(_bench_device)
def measure(model, r, warmup=25, runs=10):
    """Run `model(r)` repeatedly and return the accumulated output total.

    First calls the model `warmup` times with the results discarded (to
    amortize cold-start effects), then `runs` more times, summing the
    outputs element-wise across calls and returning the grand total.

    The model's output is assumed to be an iterable of numbers —
    presumably per-stage timings, given the "time" labels at the call
    sites; TODO confirm against the model implementation.

    Fixes the original's hard-coded `arr = [0] * 9`, which silently
    truncated (via `zip`) any output longer than 9 elements; any output
    length is now accumulated correctly. Defaults preserve the original
    behavior (25 warmup calls, 10 measured calls).
    """
    for _ in range(warmup):
        model(r)
    totals = None
    for _ in range(runs):
        out = list(model(r))
        if totals is None:
            totals = out
        else:
            totals = [a + b for a, b in zip(totals, out)]
    # With runs=0 there is nothing accumulated; return 0 like the original.
    return sum(totals) if totals is not None else 0
# Pick the GPU when available. NOTE(review): the benchmark tensors `r`/`r0`
# above are moved to CUDA on GPU machines, so in practice the script expects
# model and inputs to land on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fix: use the `device` computed above instead of ignoring it with a
# hard-coded "cuda" (the original left `device` unused).
model = Baseline().to(device)
print("Unpruned time", measure(model, r), "\n")

# Every prunable (module, parameter-name) pair: weights AND biases of the
# four conv blocks and the four deconv blocks.
parameters_to_prune = (
    (model.conv1.conv, 'weight'),
    (model.conv1.conv, 'bias'),
    (model.conv2.conv, 'weight'),
    (model.conv2.conv, 'bias'),
    (model.conv3.conv, 'weight'),
    (model.conv3.conv, 'bias'),
    (model.conv4.conv, 'weight'),
    (model.conv4.conv, 'bias'),
    (model.deconv1.deconv, 'weight'),
    (model.deconv1.deconv, 'bias'),
    (model.deconv2.deconv, 'weight'),
    (model.deconv2.deconv, 'bias'),
    (model.deconv3.deconv, 'weight'),
    (model.deconv3.deconv, 'bias'),
    (model.deconv4.deconv, 'weight'),
    (model.deconv4.deconv, 'bias'),
)

# First pass: globally zero the 20% smallest-magnitude entries (L1 norm)
# across all listed parameters, then re-time and report sparsity.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
print("Pruned time", measure(model, r0))
printing(model)
print()

# Second pass: prune a further 70%. NOTE(review): successive prunings stack
# their masks (already-zero entries can be re-selected), so the resulting
# sparsity is not simply 20% + 70% — verify via the printing() output.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.7,
)
print("Pruned time", measure(model, r0))
printing(model)