diff --git a/model-optimizer/telemetry/utils/sender.py b/model-optimizer/telemetry/utils/sender.py index d2338c6e03cec4..481984ce7a6411 100644 --- a/model-optimizer/telemetry/utils/sender.py +++ b/model-optimizer/telemetry/utils/sender.py @@ -34,13 +34,18 @@ def _future_callback(future): with self._lock: self.queue_size -= 1 + free_space = False with self._lock: if self.queue_size < MAX_QUEUE_SIZE: - fut = self.executor.submit(backend.send, message) - fut.add_done_callback(_future_callback) + free_space = True self.queue_size += 1 else: pass # dropping a message because the queue is full + # to avoid dead lock we should not add callback inside the "with self._lock" block because it will be executed + # immediately if the fut is available + if free_space: + fut = self.executor.submit(backend.send, message) + fut.add_done_callback(_future_callback) def force_shutdown(self, timeout: float): """ @@ -53,11 +58,14 @@ def force_shutdown(self, timeout: float): :return: None """ try: + need_sleep = False with self._lock: if self.queue_size > 0: - sleep(timeout) - self.executor.shutdown(wait=False) - self.executor._threads.clear() - futures.thread._threads_queues.clear() + need_sleep = True + if need_sleep: + sleep(timeout) + self.executor.shutdown(wait=False) + self.executor._threads.clear() + futures.thread._threads_queues.clear() except Exception: pass diff --git a/model-optimizer/telemetry/utils/sender_test.py b/model-optimizer/telemetry/utils/sender_test.py new file mode 100644 index 00000000000000..45214e8f7e3c81 --- /dev/null +++ b/model-optimizer/telemetry/utils/sender_test.py @@ -0,0 +1,81 @@ +""" + Copyright (C) 2018-2021 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest +import time + +from telemetry.utils.sender import TelemetrySender + + +class FakeTelemetryBackend: + def send(self, param): + pass + + +class FakeTelemetryBackendWithSleep: + def send(self, param): + time.sleep(1) + + +class TelemetrySenderStress(unittest.TestCase): + def test_stress(self): + """ + Stress tests to schedule a lot of threads which works super fast (do nothing) with sending messages." + """ + tm = TelemetrySender() + fake_backend = FakeTelemetryBackend() + for _ in range(1000000): + tm.send(fake_backend, None) + + def test_check_shutdown(self): + """ + Stress test to schedule many threads taking 1 second and then ask to force shutdown. Make sure that the elapsed + time is small. + """ + tm = TelemetrySender() + fake_backend = FakeTelemetryBackendWithSleep() + # schedule many requests which just wait 1 second + for _ in range(100000): + tm.send(fake_backend, None) + + start_time = time.time() + # ask to shutdown with timeout of 1 second + tm.force_shutdown(1) + while len(tm.executor._threads): + pass + # check that no more than 3 seconds spent + self.assertTrue(time.time() - start_time < 3) + + def test_check_shutdown_negative(self): + """ + Test to check that without forcing shutdown total execution time is expected. + """ + tm = TelemetrySender(1) # only one worker thread + fake_backend = FakeTelemetryBackendWithSleep() + start_time = time.time() + # schedule 5 requests which totally should work more than 4 seconds + for _ in range(5): + tm.send(fake_backend, None) + + try: + # wait until all threads finish their work. We use internal ThreadPoolExecutor attribute _work_queue to make + # sure that all workers completed their work, so the whole code is wrapped to try/except to avoid exceptions + # if internal implementation is changed in the future + while tm.executor._work_queue.qsize(): + pass + self.assertTrue(time.time() - start_time > 4.0) + except: + pass