forked from alibaba/TinyNeuralNetwork
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: modifier_speed_test.py
127 lines (93 loc) · 3.81 KB
/
modifier_speed_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import time
import unittest
from operator import add
import torch
import torch.nn as nn
import copy
import torchvision
import numpy as np
import random
import torch.nn.functional
from tinynn.converter import TFLiteConverter
from tinynn.graph.modifier import is_dw_conv, l2_norm
from tinynn.prune.oneshot_pruner import OneShotChannelPruner as OneShotChannelPrunerOld
from tinynn.prune.oneshot_pruner import OneShotChannelPruner as OneShotChannelPrunerNew
from tinynn.util.util import import_from_path, get_logger
# Absolute path of the directory that contains this file.
# NOTE(review): not referenced anywhere else in the visible code.
CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
# Module-level logger; speed_test reports its stage timings through it.
log = get_logger(__name__)
def get_topk(lst, k, offset=0):
    """Return the positions of the ``k`` smallest values in ``lst``, in ascending order.

    Despite the name, this selects the *smallest* entries. Ties are broken
    by position (earlier index first) because ``sorted`` is stable.

    Args:
        lst: Any finite iterable of mutually comparable values.
        k: Number of entries to keep. It is applied as a slice bound, so a
            value larger than the input keeps everything.
        offset: Constant added to every returned index.

    Returns:
        A sorted list of ``index + offset`` integers.
    """
    ranked = sorted(enumerate(lst), key=lambda pair: pair[1])
    return sorted(idx + offset for idx, _ in ranked[:k])
def get_rd_lst(length):
    """Draw ``length`` distinct integers from ``range(10000)`` in random order.

    The drawn values are printed before being returned.

    Raises:
        ValueError: propagated from ``random.sample`` when ``length`` is
            negative or exceeds 10000.
    """
    population = range(10000)
    picked = random.sample(population, length)
    # sample() already returns the values in random order. The extra shuffle
    # is kept so that a given seed still yields the same values as before.
    random.shuffle(picked)
    print(picked)
    return picked
def init_conv_by_list(conv, ch_value):
    """Overwrite each output slice of ``conv.weight`` with a constant.

    Slice ``i`` along the first weight dimension (all remaining elements)
    is set to ``ch_value[i]``. The write goes through ``.data``, so the
    module's parameter is modified in place. The same helper is applied to
    both ``nn.Conv2d`` and ``nn.Linear`` layers in this file.

    Raises:
        AssertionError: if ``len(ch_value)`` differs from the size of the
            weight's first dimension.
    """
    weight = conv.weight.data
    assert len(ch_value) == weight.shape[0]
    for channel, value in enumerate(ch_value):
        weight[channel, :] = value
def module_init(model: nn.Module, init_dict=None):
    """Give every ``nn.Linear`` / ``nn.Conv2d`` submodule per-channel constant weights.

    For each matching submodule, the per-output-channel values are taken
    from ``init_dict[name]`` when ``init_dict`` is truthy and contains the
    submodule's ``named_modules()`` name. Otherwise a fresh ``get_rd_lst``
    draw is used, sized to the layer's output width. Other submodule types
    are left untouched.

    Args:
        model: Network whose layers are initialised in place.
        init_dict: Optional mapping from submodule name to a list of values.

    Returns:
        Dict mapping each initialised submodule name to the values written.
        The dict is also printed.
    """
    overrides = init_dict or {}
    chosen = {}
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            width = layer.out_features
        elif isinstance(layer, nn.Conv2d):
            width = layer.out_channels
        else:
            continue
        values = overrides[name] if name in overrides else get_rd_lst(width)
        init_conv_by_list(layer, values)
        chosen[name] = values
    print(chosen)
    return chosen
def speed_test(model, dummy_input):
    """Time the main one-shot channel-pruning stages on ``model`` and smoke-test the result.

    Logs elapsed seconds for pruner construction, ``register_mask`` and
    ``apply_mask``. It then exports the pruned graph to ``out/new_model.py``
    and ``out/new_model.pth`` (relative to the current working directory),
    imports the generated module, instantiates it, and runs one forward pass
    on ``dummy_input``.

    Args:
        model: Network to prune; ``model.eval()`` is called on it in place.
        dummy_input: Example input passed to the pruner (presumably for graph
            tracing; TODO confirm) and to the generated model's forward pass.
    """
    with torch.no_grad():
        model.eval()

        # perf_counter is monotonic and high-resolution. time.time() follows
        # the wall clock and can jump, which distorts short measurements.
        st = time.perf_counter()
        pruner_new = OneShotChannelPrunerNew(model, dummy_input, {"sparsity": 0.5, "metrics": "l2_norm"})
        log.info(f"[SPEED TEST][Pruner Init] {time.perf_counter() - st}")

        st = time.perf_counter()
        pruner_new.register_mask()
        log.info(f"[SPEED TEST][Register Mask] {time.perf_counter() - st}")

        st = time.perf_counter()
        pruner_new.apply_mask()
        log.info(f"[SPEED TEST][Apply Mask] {time.perf_counter() - st}")

        # The generated files go into a CWD-relative "out" directory; create it
        # so a fresh checkout does not fail on a missing output folder.
        os.makedirs('out', exist_ok=True)
        pruner_new.graph.generate_code('out/new_model.py', 'out/new_model.pth', 'new_model')
        new_model_pruned = import_from_path('out.new_model', "out/new_model.py", "new_model")()
        new_model_pruned(dummy_input)
class ModifierForwardTester(unittest.TestCase):
    """Run ``speed_test`` on a selection of torchvision classification models.

    Every model is built without pretrained weights and fed one random
    input of shape (1, 3, 224, 224).
    """

    _INPUT_SHAPE = (1, 3, 224, 224)

    def _run(self, model):
        # The random input is drawn after the model has been built, which
        # keeps the global RNG consumption order of the per-test code.
        speed_test(model, torch.randn(self._INPUT_SHAPE))

    def test_mbv2(self):
        self._run(torchvision.models.mobilenet_v2(pretrained=False))

    def test_mbv3(self):
        self._run(torchvision.models.mobilenet_v3_small(pretrained=False))

    def test_mbv3_large(self):
        self._run(torchvision.models.mobilenet_v3_large(pretrained=False))

    def test_vgg16(self):
        self._run(torchvision.models.vgg16(pretrained=False))

    def test_googlenet(self):
        self._run(torchvision.models.googlenet(pretrained=False))

    def test_shufflenet(self):
        self._run(torchvision.models.shufflenet_v2_x0_5(pretrained=False))

    def test_resnet18(self):
        net = torchvision.models.resnet18(pretrained=False)
        # Replace the default init with per-channel constant weights first.
        module_init(net)
        self._run(net)

    def test_densenet121(self):
        self._run(torchvision.models.densenet121(pretrained=False))
# Allow running this test module directly as a script (unittest CLI runner).
if __name__ == '__main__':
    unittest.main()