-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_flask.py
87 lines (72 loc) · 2.41 KB
/
run_flask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# coding: utf-8
from gevent import pywsgi
from flask import Flask, request
import threading
import torch
import time
import base64
import os
from util.loaders import read_classes, get_color_dict, read_anchors
from models import YOLO
def after_request(response):
    """Add no-cache and permissive CORS headers to an outgoing response.

    Args:
        response: the response object; its ``headers`` mapping is mutated.

    Returns:
        The same ``response`` object, as Flask after-request hooks require.
    """
    header_pairs = (
        ("Cache-Control", "no-cache"),
        ("Access-Control-Allow-Origin", "*"),
    )
    for header_name, header_value in header_pairs:
        response.headers[header_name] = header_value
    return response
# Directory holding uploaded images and their "<name>_pred.jpg" detection
# outputs (both paths are built from it inside upload()).
out_dir = "./data/images"
# Model weights file: handed to YOLO at startup and re-read with torch.load
# whenever the backbone has been released and must be reloaded.
model_path = "./models/yolov3.pth"

app = Flask(__name__)
# Every response gets the no-cache / CORS headers.
app.after_request(after_request)
def _safe_upload_name(filename):
    """Reduce a client-supplied upload name to a bare file name for out_dir.

    Strips all directory components (both ``/`` and ``\\`` separators) so a
    crafted name such as ``../../etc/x`` cannot escape ``out_dir``.

    Returns:
        The sanitized name, or None if nothing usable remains.
    """
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        return None
    return name


@app.route('/detection/', methods=['POST'])
def upload():
    """Run YOLO detection on an uploaded image and return the annotated result.

    Expects a multipart form field named ``image``. The upload is saved in
    ``out_dir`` and the detection output goes to ``<stem>_pred.jpg`` beside it.
    If that output file already exists, it is returned without re-running the
    model, so results are cached per (sanitized) file name.

    Returns:
        The prediction image bytes, base64-encoded. Returns a 400 response
        if the file name is unusable.
    """
    global model_use_time

    # Read the request before taking the lock. A missing field raises here
    # (Flask turns it into a 400) without ever touching model_lock.
    upload_file = request.files["image"]
    print("Receive image:", upload_file.filename)

    name = _safe_upload_name(upload_file.filename)
    if name is None:
        return "Invalid file name.", 400
    in_path = os.path.join(out_dir, name)
    # splitext keeps multi-dot stems distinct ("a.b.jpg" -> "a.b_pred.jpg").
    out_path = os.path.join(out_dir, os.path.splitext(name)[0] + "_pred.jpg")

    # `with` guarantees release even if loading or detection raises.
    # A leaked lock would hang every later request and the memory reaper.
    with model_lock:
        if not os.path.exists(out_path):
            if model.backbone is None:
                # cuda_memory_control sets backbone to None after idle
                # periods; reload it only when we actually need to infer.
                model.backbone = torch.load(model_path).cuda()
            upload_file.save(in_path)
            model.detect_image(
                in_path,
                0.5,  # presumably score / NMS thresholds — confirm in YOLO
                0.5,
                color_dict,
                do_show=False,
                output_path=out_path,
            )
        # Refresh on every request so the idle timer restarts.
        model_use_time = time.time()

    with open(out_path, 'rb') as result_file:
        return base64.b64encode(result_file.read())
def cuda_memory_control(max_time, time_gap):
    """Loop forever, releasing the model backbone after idle periods.

    Every ``time_gap`` seconds, checks whether at least ``max_time`` seconds
    have elapsed since the module-level ``model_use_time``. If so, and the
    backbone is loaded, it sets ``model.backbone = None`` and calls
    ``torch.cuda.empty_cache()``. ``upload`` reloads the backbone on demand.
    Reads the module globals ``model_lock``, ``model`` and ``model_use_time``.

    Args:
        max_time: idle seconds before the backbone is released.
        time_gap: seconds to sleep between checks.
    """
    while True:
        # `with` releases the lock even if empty_cache() raises, so a failure
        # in this thread cannot leave request handlers blocked forever.
        with model_lock:
            idle_for = time.time() - model_use_time
            if idle_for >= max_time and model.backbone is not None:
                model.backbone = None
                torch.cuda.empty_cache()
                print("Cache cleaned.")
        # Sleep outside the lock so requests are not held up.
        time.sleep(time_gap)
def run_server(host='0.0.0.0', port=34560):
    """Serve ``app`` with gevent's WSGI server; blocks until it stops.

    Args:
        host: interface address to bind. The default listens on all interfaces.
        port: TCP port to bind.
    """
    server = pywsgi.WSGIServer((host, port), app, log=app.logger)
    server.serve_forever()
if __name__ == '__main__':
    # These module globals are read by upload() and cuda_memory_control().
    # They exist only when this file is run as a script.
    classes = read_classes("./data/coco-ch.names")
    color_dict = get_color_dict(classes, "./data/colors")
    anchors = read_anchors("./data/anchors")
    model = YOLO(classes,
                 model_load_path=model_path,
                 anchors=anchors,
                 device_ids="0")
    os.makedirs(out_dir, exist_ok=True)
    model_use_time = time.time()
    model_lock = threading.Lock()
    # Release GPU memory after 60 s idle, checking every 30 s. The reaper loops
    # forever, so make it a daemon: it must not keep the process alive once
    # the server thread has exited (e.g. because it failed to bind its port).
    threading.Thread(target=cuda_memory_control, args=(60, 30), daemon=True).start()
    threading.Thread(target=run_server).start()
    print("Server is successfully started.")