-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_models_windows.py
45 lines (39 loc) · 1.43 KB
/
test_models_windows.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import sys
from ts.metrics.metrics_store import MetricsStore
from ts.torch_handler.base_handler import BaseHandler
from uuid import uuid4
from pprint import pprint
class ModelContext:
    """Minimal stand-in for the context object a TorchServe handler receives.

    Provides only what this script hands to the handler: ``manifest``,
    ``system_properties``, ``explain``, ``metrics`` and
    ``get_request_header``.

    Args:
        model_dir: Directory stored as ``system_properties['model_dir']``.
            Defaults to ``DEFAULT_MODEL_DIR``, a placeholder that must be
            edited (or overridden here) before the handler can load a model.
        explain: Value reported for the ``'explain'`` request header.
    """

    # Placeholder model directory. Every backslash is doubled so none is
    # parsed as an escape sequence (a bare '\s' is an invalid escape and
    # warns on Python 3.12+). Runtime value is unchanged.
    DEFAULT_MODEL_DIR = '<ADD COMPLETE PATH HERE>\\share_folder\\model-store\\ptclassifier'

    def __init__(self, model_dir=None, explain=False):
        self.manifest = {
            'model': {
                'modelName': 'ptclassifier',
                'serializedFile': 'traced_pt_classifer.pt',
                'modelFile': 'model_ph.py'
            }
        }
        self.system_properties = {
            'model_dir': self.DEFAULT_MODEL_DIR if model_dir is None else model_dir
        }
        self.explain = explain
        # A fresh uuid4 is passed as the first MetricsStore argument, with the
        # model name from the manifest as the second.
        self.metrics = MetricsStore(uuid4(), self.manifest['model']['modelName'])

    def get_request_header(self, idx, exp):
        """Return the value of header ``exp`` for request ``idx``.

        Only the ``'explain'`` header is modelled; it returns ``self.explain``.
        Any other header name yields ``False``. ``idx`` is ignored.
        """
        if exp == 'explain':
            return self.explain
        return False
def main(argv=None, iterations=1000):
    """Benchmark a TransformersSeqClassifierHandler outside of TorchServe.

    Selects the handler implementation from the first argument and
    initializes it with a ``ModelContext``. It then sends the same
    single-item request through ``handle`` ``iterations`` times and prints
    every metric accumulated in the context's metrics store.

    Args:
        argv: Arguments excluding the program name; defaults to
            ``sys.argv[1:]``. ``argv[0] == 'fast'`` selects the handler from
            the ``ptclassifier`` package; any other value selects the one
            from ``ptclassifiernotr``.
        iterations: How many times the request is handled (default 1000).

    Raises:
        SystemExit: if no mode argument is supplied.
    """
    if argv is None:
        argv = sys.argv[1:]
    if not argv:
        # Fail with a usage message rather than an IndexError traceback.
        sys.exit("usage: test_models_windows.py MODE  "
                 "(MODE 'fast' -> ptclassifier, anything else -> ptclassifiernotr)")

    # Imported lazily so only the selected package has to be importable.
    if argv[0] == 'fast':
        from ptclassifier.TransformerSeqClassificationHandler import (
            TransformersSeqClassifierHandler as Classifier,
        )
    else:
        from ptclassifiernotr.TransformerSeqClassificationHandler import (
            TransformersSeqClassifierHandler as Classifier,
        )

    ctxt = ModelContext()
    handler = Classifier()
    handler.initialize(ctxt)

    data = [{'data': 'To be or not to be, that is the question.'}]
    for _ in range(iterations):
        handler.handle(data, ctxt)

    # Each stored metric is expected to expose name, value and unit.
    for metric in ctxt.metrics.store:
        print(f'{metric.name}: {metric.value} {metric.unit}')
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, never on import.
    main()