-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_weights.py
88 lines (77 loc) · 2.72 KB
/
convert_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import os
import megengine as mge
import numpy as np
import torch
import torch.nn as nn
from models.light_cnn import LightCNN_9Layers, LightCNN_29Layers, LightCNN_29Layers_v2
from models.torch_models import LightCNN_9Layers as torch_LightCNN_9Layers
from models.torch_models import LightCNN_29Layers as torch_LightCNN_29Layers
from models.torch_models import LightCNN_29Layers_v2 as torch_LightCNN_29Layers_v2
# Maps the ``--model`` CLI value to a (MegEngine class, PyTorch class) pair
# describing the same LightCNN architecture in both frameworks. ``main`` uses
# element [0] to build the MegEngine model and element [1] for the PyTorch one.
MODEL_MAPPER = {
    key: (mge_cls, torch_cls)
    for key, mge_cls, torch_cls in (
        ('9', LightCNN_9Layers, torch_LightCNN_9Layers),
        ('29', LightCNN_29Layers, torch_LightCNN_29Layers),
        ('29v2', LightCNN_29Layers_v2, torch_LightCNN_29Layers_v2),
    )
}
def get_atttr_by_name(torch_module, k):
    """Return the object that owns the dotted state-dict entry ``k``.

    For a key such as ``"features.0.conv.weight"`` every dotted component
    except the last one (the tensor name) is traversed, and the resulting
    sub-module (here ``torch_module.features[0].conv``) is returned. A key
    with a single component returns ``getattr(torch_module, k)`` directly.

    Args:
        torch_module: Root object (typically a ``torch.nn.Module``) to walk.
        k: Dotted state-dict key.

    Returns:
        The object reached after walking all but the last component of ``k``.

    Raises:
        AttributeError: if the first component is not an attribute of
            ``torch_module``.
        Other exceptions (e.g. ValueError, IndexError, TypeError) if an
        intermediate component is neither an attribute nor a valid integer
        index of its parent. Exceptions raised by attribute access itself
        (other than AttributeError) propagate unchanged.
    """
    names = k.split('.')
    owner = getattr(torch_module, names[0])
    # The final component is the tensor name ("weight", "bias", ...), so only
    # the intermediate components are traversed.
    for name in names[1:-1]:
        try:
            owner = getattr(owner, name)
        except AttributeError:
            # Fall back to integer indexing for containers (e.g. plain lists)
            # whose children are not reachable as attributes.
            owner = owner[int(name)]
    return owner
def convert(torch_model, torch_dict):
    """Convert a PyTorch state dict into a dict of numpy arrays for MegEngine.

    Keys are kept unchanged, and ``num_batches_tracked`` entries are dropped.
    Tensors owned by an ``nn.Conv2d`` are reshaped as follows:

    * grouped-conv weights ``(out, in_per_group, kh, kw)`` become the 5-D
      layout ``(groups, out // groups, in_per_group, kh, kw)``;
    * conv biases ``(out,)`` become ``(1, out, 1, 1)``.

    Args:
        torch_model: PyTorch model whose module tree matches ``torch_dict``;
            used to find the module that owns each key.
        torch_dict: Mapping of dotted parameter/buffer names to tensors.

    Returns:
        dict mapping each retained key to a ``numpy.ndarray``.
    """
    new_dict = {}
    for k, v in torch_dict.items():
        # Compare against the final key component only, so module names that
        # merely contain "weight"/"bias" are not misclassified.
        param_name = k.rsplit('.', 1)[-1]
        if param_name == "num_batches_tracked":
            continue

        # detach()/cpu() make .numpy() safe for grad-tracking or GPU tensors.
        data = v.detach().cpu().numpy()
        sub_module = get_atttr_by_name(torch_model, k)
        is_conv = isinstance(sub_module, nn.Conv2d)
        groups = sub_module.groups if is_conv else 1

        if param_name == "weight" and groups > 1:
            # Split the output-channel axis into (groups, out // groups).
            out_ch, in_ch, h, w = data.shape
            data = data.reshape(groups, out_ch // groups, in_ch, h, w)
        elif param_name == "bias" and is_conv:
            # Conv bias is broadcast as a (1, C, 1, 1) tensor.
            data = data.reshape(1, -1, 1, 1)

        new_dict[k] = data
    return new_dict
def main(torch_name, torch_path, out_dir='pretrained'):
    """Convert a PyTorch LightCNN checkpoint and save it as MegEngine weights.

    Args:
        torch_name: Key into ``MODEL_MAPPER`` ('9', '29' or '29v2').
        torch_path: Path to the PyTorch checkpoint. It may either wrap the
            weights under a ``'state_dict'`` key or be a bare state dict.
        out_dir: Directory that receives ``<torch_name>.pkl``; created if
            missing. Defaults to ``'pretrained'``.

    Raises:
        KeyError: if ``torch_name`` is not in ``MODEL_MAPPER``.
    """
    mge_cls, torch_cls = MODEL_MAPPER[torch_name]

    # NOTE(review): torch.load unpickles the file -- only convert checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(torch_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    # Strip only a *leading* "module." (presumably added by nn.DataParallel
    # during training); str.replace would also mangle matches mid-key.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }

    # Loading into the PyTorch model checks the keys and provides the module
    # tree that convert() inspects to detect convolutions and their groups.
    torch_model = torch_cls()
    torch_model.load_state_dict(state_dict)
    new_dict = convert(torch_model, state_dict)

    # Sanity check: the converted weights must load into the MegEngine model
    # before anything is written to disk.
    mge_cls().load_state_dict(new_dict)

    os.makedirs(out_dir, exist_ok=True)
    mge.save(new_dict, os.path.join(out_dir, torch_name + '.pkl'))
if __name__ == "__main__":
    # Command-line entry point: parse the model variant and checkpoint path,
    # then run the conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default='29',
        # Reject unknown variants up front instead of failing later with a
        # KeyError inside main().
        choices=list(MODEL_MAPPER),
        help=f"which model to convert from torch to megengine, optional: {list(MODEL_MAPPER.keys())}",
    )
    parser.add_argument(
        "-c",
        "--ckpt",
        type=str,
        default="./LightCNN_29Layers_checkpoint.pth.tar",
        help="path to torch checkpoint",
    )
    args = parser.parse_args()
    main(args.model, args.ckpt)