-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnn.py
36 lines (26 loc) · 1.08 KB
/
cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from resnet import *
import pickle
import os
def resModel(args, device):  # resnet18
    """Build a ResNet-18 and optionally initialise it from a checkpoint.

    The network is created with
    ``resnet18(end2end=False, pretrained=False, num_class=args.num_class)``
    and moved to ``device``. When ``args.pretrained_backbone_path`` is set,
    tensors from that checkpoint are copied into the model. The exceptions are
    ``fc.weight``, ``fc.bias``, ``feature.weight`` and ``feature.bias``, which
    keep the freshly built model's values.

    Args:
        args: Namespace-like object. Must provide ``num_class``. May provide
            ``pretrained_backbone_path``; a missing or falsy value means
            "do not load a checkpoint".
        device: Target passed to ``.to()`` and used as ``map_location`` when
            loading the checkpoint.

    Returns:
        The constructed model, on ``device``.

    Raises:
        RuntimeError: propagated from ``load_state_dict`` if a copied tensor's
            shape does not match the model parameter of the same name.
    """
    model = resnet18(end2end=False, pretrained=False, num_class=args.num_class).to(device)

    checkpoint_path = getattr(args, 'pretrained_backbone_path', None)
    if not checkpoint_path:
        print('No pretrained resnet18 model built.')
        return model

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # The weights are usually wrapped under 'state_dict'. Fall back to treating
    # the loaded object itself as the state dict when that key is absent.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        pretrained_state_dict = checkpoint['state_dict']
    else:
        pretrained_state_dict = checkpoint

    # Head layers are not copied. Presumably this lets them match
    # args.num_class instead of the checkpoint's output size — TODO confirm.
    excluded_keys = {'fc.weight', 'fc.bias', 'feature.weight', 'feature.bias'}

    model_state_dict = model.state_dict()
    unmatched_keys = []
    for key, tensor in pretrained_state_dict.items():
        if key in excluded_keys:
            continue
        if key not in model_state_dict:
            # strict=False would silently ignore these. Record them so a naming
            # or prefix mismatch (nothing actually loaded) is visible.
            unmatched_keys.append(key)
            continue
        model_state_dict[key] = tensor

    model.load_state_dict(model_state_dict, strict=False)
    if unmatched_keys:
        print('Warning: %d checkpoint keys not found in model (e.g. %s)'
              % (len(unmatched_keys), unmatched_keys[:5]))
    print('Model loaded from Msceleb pretrained')
    return model