-
Notifications
You must be signed in to change notification settings - Fork 718
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix thread safety of logging while adding handler (#22)
- Loading branch information
Showing
3 changed files
with
90 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,99 @@ | ||
from threading import Thread | ||
from threading import Thread, Barrier | ||
import itertools | ||
from loguru import logger | ||
import time | ||
|
||
|
||
def test_safe(capsys): | ||
first_thread_initialized = False | ||
second_thread_initialized = False | ||
entered = False | ||
output = "" | ||
class NonSafeSink: | ||
def __init__(self): | ||
self.written = "" | ||
|
||
def non_safe_sink(msg): | ||
nonlocal entered | ||
nonlocal output | ||
assert not entered | ||
entered = True | ||
def write(self, message): | ||
self.written += message | ||
time.sleep(1) | ||
entered = False | ||
output += msg | ||
self.written += message | ||
|
||
def first_thread(): | ||
nonlocal first_thread_initialized | ||
first_thread_initialized = True | ||
time.sleep(1) | ||
assert second_thread_initialized | ||
logger.debug("message 1") | ||
|
||
def second_thread(): | ||
nonlocal second_thread_initialized | ||
second_thread_initialized = True | ||
time.sleep(1) | ||
assert first_thread_initialized | ||
def test_safe_logging(): | ||
barrier = Barrier(2) | ||
counter = itertools.count() | ||
|
||
sink = NonSafeSink() | ||
logger.add(sink, format="{message}", catch=False) | ||
|
||
def threaded(): | ||
barrier.wait() | ||
logger.info("{}", next(counter)) | ||
|
||
threads = [Thread(target=threaded) for _ in range(2)] | ||
|
||
for thread in threads: | ||
thread.start() | ||
|
||
for thread in threads: | ||
thread.join() | ||
|
||
assert sink.written == "0\n0\n1\n1\n" | ||
|
||
|
||
def test_safe_adding_while_logging(writer): | ||
barrier = Barrier(2) | ||
counter = itertools.count() | ||
|
||
sink_1 = NonSafeSink() | ||
sink_2 = NonSafeSink() | ||
logger.add(sink_1, format="{message}", catch=False) | ||
logger.add(sink_2, format="-> {message}", catch=False) | ||
|
||
def thread_1(): | ||
barrier.wait() | ||
logger.info("{}", next(counter)) | ||
|
||
def thread_2(): | ||
barrier.wait() | ||
time.sleep(0.5) | ||
logger.debug("message 2") | ||
logger.add(writer, format="{message}", catch=False) | ||
logger.info("{}", next(counter)) | ||
|
||
threads = [Thread(target=thread_1), Thread(target=thread_2)] | ||
|
||
for thread in threads: | ||
thread.start() | ||
|
||
logger.add(non_safe_sink, format="{message}", catch=False) | ||
for thread in threads: | ||
thread.join() | ||
|
||
assert sink_1.written == "0\n0\n1\n1\n" | ||
assert sink_2.written == "-> 0\n-> 0\n-> 1\n-> 1\n" | ||
assert writer.read() == "1\n" | ||
|
||
|
||
def test_safe_removing_while_logging(): | ||
barrier = Barrier(2) | ||
counter = itertools.count() | ||
|
||
sink_1 = NonSafeSink() | ||
sink_2 = NonSafeSink() | ||
a = logger.add(sink_1, format="{message}", catch=False) | ||
b = logger.add(sink_2, format="-> {message}", catch=False) | ||
|
||
def thread_1(): | ||
barrier.wait() | ||
logger.info("{}", next(counter)) | ||
|
||
def thread_2(): | ||
barrier.wait() | ||
time.sleep(0.5) | ||
logger.remove(b) | ||
logger.info("{}", next(counter)) | ||
|
||
threads = [Thread(target=first_thread), Thread(target=second_thread)] | ||
threads = [Thread(target=thread_1), Thread(target=thread_2)] | ||
|
||
for thread in threads: | ||
thread.start() | ||
|
||
for thread in threads: | ||
thread.join() | ||
|
||
out, err = capsys.readouterr() | ||
assert out == err == "" | ||
assert output == "message 1\nmessage 2\n" | ||
assert sink_1.written == "0\n0\n1\n1\n" | ||
assert sink_2.written == "-> 0\n-> 0\n" |