trace.py
import torch
import coremltools as ct
from torchvision.models import mobilenet_v2
from coremltools.models.neural_network import quantization_utils
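# Load the class labels, one symbol name per line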
with open('class_names.txt') as f:
    class_labels = f.read().splitlines()
# Initialise a MobileNetV2 with one output per class; weights are loaded below
mnet = mobilenet_v2(pretrained=False, progress=False,
                    num_classes=len(class_labels))
mnet.load_state_dict(torch.load('mobilenet.bin',
                                map_location=torch.device('cpu')))
# Add a softmax on top
model = torch.nn.Sequential(mnet, torch.nn.Softmax(dim=1))
model.eval()
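# Dummy input (batch, channels, height, width) used to record the graph during tracing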
rand = torch.rand(1, 3, 200, 300)
traced_model = torch.jit.trace(model, rand)
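# Convert to Core ML: bias/scale map 8-bit pixel values to roughly [-1, 1],
# and ClassifierConfig attaches the class labels to the output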
ctmodel = ct.convert(traced_model,
                     inputs=[ct.ImageType(name="drawing", shape=rand.shape,
                                          bias=[-1, -1, -1], scale=1/127)],
                     classifier_config=ct.ClassifierConfig(class_labels))
spec = ctmodel.get_spec()
# Rename the converter's auto-generated output ('649') to something sensible
ct.utils.rename_feature(spec, '649', 'classLabelProbs')
ctmodel = ct.models.MLModel(spec)
# Set feature descriptions (these show up as comments in Xcode)
ctmodel.input_description["drawing"] = "Input drawing to be classified"
ctmodel.output_description["classLabel"] = "Most likely symbol"
ctmodel.output_description["classLabelProbs"] = "Probability scores for each symbol"
# Set model author name
ctmodel.author = "Venkata S Govindarajan"
# Set the license of the model
ctmodel.license = "MIT License"
# Set a short description for the Xcode UI
ctmodel.short_description = ("Detects the most likely LaTeX mathematical symbol "
                             "corresponding to a drawing.")
# Set a version for the model
ctmodel.version = "0.95"
# Save model
ctmodel.save("deTeX.mlmodel")
# Quantise weights to FP16, roughly halving the model size without (supposedly)
# affecting accuracy
ctmodel_fp16 = quantization_utils.quantize_weights(ctmodel, nbits=16)
ctmodel_fp16.save("deTeX16.mlmodel")
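# 8-bit linear weight quantisation for an even smaller model (may cost some accuracy)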
ctmodel_fp8 = quantization_utils.quantize_weights(ctmodel, nbits=8)
ctmodel_fp8.save("deTeX8.mlmodel")