Process OpenVINO in thread (#8)
dkurt authored Nov 29, 2020
1 parent a9ecbdd commit 3cfa6c5
Showing 1 changed file with 30 additions and 16 deletions.
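
The commit turns the estimator into a producer/consumer pair: a background Thread runs the OpenVINO inference loop and pushes finished records into a queue.Queue, while __next__ pulls from that queue until the worker is done. Below is a minimal, self-contained sketch of that handoff with hypothetical names, not the DeepVariant code itself; unlike the commit, which has __next__ check isAlive()/empty(), the sketch uses the common sentinel variant to signal completion and avoid blocking on an empty queue after the worker exits.

    from queue import Queue
    from threading import Thread

    _DONE = object()  # sentinel marking the end of the stream

    class ThreadedResults:
      """Illustrative only: produce results on a worker thread, consume by iterating."""

      def __init__(self, items):
        self.items = items
        self.results = Queue()
        self.worker = Thread(target=self._process)

      def _process(self):
        for x in self.items:
          self.results.put(x * x)   # stand-in for running inference on one sample
        self.results.put(_DONE)     # tell the consumer there is nothing more to read

      def __iter__(self):
        self.worker.start()
        return self

      def __next__(self):
        res = self.results.get()    # blocks until the worker publishes something
        if res is _DONE:
          raise StopIteration
        return res

    print(list(ThreadedResults(range(5))))  # [0, 1, 4, 9, 16]
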
deepvariant/openvino_estimator.py: 30 additions & 16 deletions
@@ -1,4 +1,6 @@
 import tensorflow as tf
+from threading import Thread
+from queue import Queue
 
 try:
   from openvino.inference_engine import IECore, StatusCode
@@ -16,24 +18,25 @@ def __init__(self, model_xml, model_bin, *, input_fn, model):
     self.tf_sess = tf.compat.v1.Session()
     self.input_fn = input_fn
     self.model = model
-    self.outputs = []
+    self.outputs = {}
+    self.results = Queue()
+    self.process_thread = Thread(target=self._process)
+    self.features = tf.compat.v1.data.make_one_shot_iterator(self.input_fn({'batch_size': 64})).get_next()
 
 
-  def __iter__(self):
+  def _process(self):
     # Read input data
-    features = tf.compat.v1.data.make_one_shot_iterator(
-        self.input_fn({'batch_size': 64})).get_next()
-    self.images = features['image']
-    self.variant = features['variant']
-    self.alt_allele_indices = features['alt_allele_indices']
-    self.iter_id = 0
+    self.images = self.features['image']
+    self.variant = self.features['variant']
+    self.alt_allele_indices = self.features['alt_allele_indices']
 
     try:
       # List that maps infer requests to index of processed input.
       # -1 means that request has not been started yet.
       infer_request_input_id = [-1] * len(self.exec_net.requests)
 
       inp_id = 0
+      iter_id = 0
       while True:
         # Get next input
         inp, variant, alt_allele_indices = self.tf_sess.run([self.images, self.variant, self.alt_allele_indices])
@@ -58,14 +61,22 @@ def __iter__(self):
 
           # Start this request on new data
           infer_request_input_id[infer_request_id] = inp_id
-          inp_id += 1
-          self.outputs.append({
+          self.outputs[inp_id] = {
               'probabilities': None,
               'variant': variant[i],
               'alt_allele_indices': alt_allele_indices[i]
-          })
+          }
+          inp_id += 1
           request.async_infer({'input': inp[i:i+1].transpose(0, 3, 1, 2)})
 
+        while self.outputs:
+          if not self.outputs[iter_id]['probabilities'] is None:
+            self.results.put(self.outputs.pop(iter_id))
+            iter_id += 1
+          else:
+            break
+
+
     except (StopIteration, tf.errors.OutOfRangeError):
       # Copy rest of outputs
       status = self.exec_net.wait()
@@ -74,15 +85,18 @@ def __iter__(self):
       for infer_request_id, out_id in enumerate(infer_request_input_id):
         if not self.outputs[out_id]['probabilities']:
           request = self.exec_net.requests[infer_request_id]
-          self.outputs[out_id]['probabilities'] = request.output_blobs['InceptionV3/Predictions/Softmax'].buffer.reshape(-1)
+          res = self.outputs[out_id]
+          res['probabilities'] = request.output_blobs['InceptionV3/Predictions/Softmax'].buffer.reshape(-1)
+          self.results.put(res)
 
+
+  def __iter__(self):
+    self.process_thread.start()
     return self
 
 
   def __next__(self):
-    if self.iter_id < len(self.outputs):
-      res = self.outputs[self.iter_id]
-      self.iter_id += 1
-      return res
+    if self.process_thread.isAlive() or not self.results.empty():
+      return self.results.get()
     else:
       raise StopIteration
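
For reference, the inference side of the loop uses the legacy openvino.inference_engine API that the file already imports: samples are submitted with request.async_infer() on an NCHW-transposed blob, probabilities are read back from request.output_blobs[...].buffer, and exec_net.wait() drains whatever is still in flight. A stripped-down sketch of that round-robin pattern is below; the model paths, batch contents, and pileup shape are placeholders rather than DeepVariant's actual files, and get_idle_request_id()/wait(num_requests=1) from the same legacy API are assumed for request scheduling.

    import numpy as np
    from openvino.inference_engine import IECore

    # Placeholder model files; DeepVariant converts its Inception v3 checkpoint to IR.
    ie = IECore()
    net = ie.read_network('model.xml', 'model.bin')
    exec_net = ie.load_network(net, 'CPU', num_requests=4)

    images = np.zeros((8, 100, 221, 6), np.float32)    # placeholder NHWC pileup batch
    request_to_sample = [-1] * len(exec_net.requests)  # which sample each request holds
    probs = [None] * len(images)

    for i, image in enumerate(images):
      # Pick a free infer request, waiting for one if all of them are busy.
      request_id = exec_net.get_idle_request_id()
      if request_id < 0:
        exec_net.wait(num_requests=1)
        request_id = exec_net.get_idle_request_id()
      request = exec_net.requests[request_id]

      # Harvest the result of the sample this request processed previously.
      prev = request_to_sample[request_id]
      if prev != -1:
        probs[prev] = request.output_blobs[
            'InceptionV3/Predictions/Softmax'].buffer.reshape(-1)

      # Submit the new sample asynchronously (NHWC -> NCHW, batch of one).
      request_to_sample[request_id] = i
      request.async_infer({'input': image[None].transpose(0, 3, 1, 2)})

    # Drain the requests that are still running.
    exec_net.wait()
    for request_id, i in enumerate(request_to_sample):
      if i != -1 and probs[i] is None:
        probs[i] = exec_net.requests[request_id].output_blobs[
            'InceptionV3/Predictions/Softmax'].buffer.reshape(-1)
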

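The data still comes from DeepVariant's TF1-style input_fn: __init__ now builds the one-shot iterator once, and the worker thread repeatedly calls tf_sess.run() on its get_next() tensors until tf.errors.OutOfRangeError ends the stream. A self-contained sketch of that pattern follows, with a made-up input_fn standing in for the real pipeline.

    import numpy as np
    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()  # the snippet below uses v1 graph/session mode

    def input_fn(params):
      # Made-up stand-in for DeepVariant's call_variants input pipeline.
      data = {'image': np.zeros((10, 100, 221, 6), np.float32),
              'variant': np.arange(10, dtype=np.int64),
              'alt_allele_indices': np.zeros(10, np.int64)}
      return tf.data.Dataset.from_tensor_slices(data).batch(params['batch_size'])

    features = tf.compat.v1.data.make_one_shot_iterator(
        input_fn({'batch_size': 4})).get_next()
    sess = tf.compat.v1.Session()
    while True:
      try:
        batch = sess.run(features)        # one dict of numpy arrays per batch
        print(batch['image'].shape)       # (4, 100, 221, 6) for full batches
      except tf.errors.OutOfRangeError:   # raised when the dataset is exhausted
        break
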