Skip to content

Commit

Permalink
Fix the problem of async logging leaving a hanging thread (#10374)
Browse files Browse the repository at this point in the history
Signed-off-by: chenmoneygithub <[email protected]>
  • Loading branch information
chenmoneygithub authored and BenWilson2 committed Nov 14, 2023
1 parent 21b917b commit d017719
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions mlflow/utils/async_logging/async_logging_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ def _at_exit_callback(self) -> None:
try:
# Stop the data processing thread
self._stop_data_logging_thread_event.set()
# Waits till queue is drained.
self._run_data_logging_thread.result()
self._batch_logging_threadpool.shutdown(wait=False)
# Waits till logging queue is drained.
self._batch_logging_thread.join()
self._batch_status_check_threadpool.shutdown(wait=False)
except Exception as e:
_logger.error(f"Encountered error while trying to finish logging: {e}")
Expand Down Expand Up @@ -132,14 +131,10 @@ def __getstate__(self):
del state["_run_data_logging_thread"]
if "_stop_data_logging_thread_event" in state:
del state["_stop_data_logging_thread_event"]
if "_batch_logging_threadpool" in state:
del state["_batch_logging_threadpool"]
if "_batch_logging_thread" in state:
del state["_batch_logging_thread"]
if "_batch_status_check_threadpool" in state:
del state["_batch_status_check_threadpool"]
if "_run_data_logging_thread" in state:
del state["_run_data_logging_thread"]
if "_stop_data_logging_thread_event" in state:
del state["_stop_data_logging_thread_event"]

return state

Expand All @@ -158,7 +153,7 @@ def __setstate__(self, state):
self._queue = Queue()
self._lock = threading.RLock()
self._is_activated = False
self._batch_logging_threadpool = None
self._batch_logging_thread = None
self._batch_status_check_threadpool = None
self._stop_data_logging_thread_event = None

Expand Down Expand Up @@ -193,7 +188,6 @@ def log_batch_async(
)

self._queue.put(batch)

operation_future = self._batch_status_check_threadpool.submit(self._wait_for_batch, batch)
return RunOperations(operation_futures=[operation_future])

Expand All @@ -217,14 +211,17 @@ def activate(self) -> None:
self._stop_data_logging_thread_event = threading.Event()

# Keeping max_workers=1 so that there are no two threads
self._batch_logging_threadpool = ThreadPoolExecutor(max_workers=1)

self._batch_status_check_threadpool = ThreadPoolExecutor(max_workers=10)

self._run_data_logging_thread = self._batch_logging_threadpool.submit(
self._logging_loop
) # concurrent.futures.Future[self._logging_loop]
self._batch_logging_thread = threading.Thread(
target=self._logging_loop,
name="MLflowAsyncLoggingLoop",
daemon=True,
)

self._batch_status_check_threadpool = ThreadPoolExecutor(
max_workers=10,
thread_name_prefix="MLflowAsyncLoggingStatusCheck",
)
self._batch_logging_thread.start()
atexit.register(self._at_exit_callback)

self._is_activated = True

0 comments on commit d017719

Please sign in to comment.