diff --git a/asserts/__init__.py b/asserts/__init__.py index 1c65ecf..410e8dc 100644 --- a/asserts/__init__.py +++ b/asserts/__init__.py @@ -20,12 +20,14 @@ """ +from __future__ import annotations + import re import sys from datetime import datetime, timedelta from json import loads as json_loads -from typing import Set -from warnings import catch_warnings +from typing import Any, Callable, Set +from warnings import WarningMessage, catch_warnings def fail(msg=None): @@ -864,7 +866,7 @@ def assert_datetime_about_now_utc(actual, msg_fmt="{msg}"): fail(msg_fmt.format(msg=msg, actual=actual, now=now)) -class AssertRaisesContext(object): +class AssertRaisesContext: """A context manager to test for exceptions with certain properties. When the context is left and no exception has been raised, an @@ -906,7 +908,7 @@ def __init__(self, exception, msg_fmt="{msg}"): self._exc_type = exception self._exc_val = None self._exception_name = getattr(exception, "__name__", str(exception)) - self._tests = [] + self._tests: list[Callable[[Any], object]] = [] def __enter__(self): return self @@ -929,7 +931,7 @@ def format_message(self, default_msg): exc_name=self._exception_name, ) - def add_test(self, cb): + def add_test(self, cb: Callable[[Any], object]) -> None: """Add a test callback. This callback is called after determining that the right exception @@ -1188,9 +1190,11 @@ class AssertWarnsContext(object): def __init__(self, warning_class, msg_fmt="{msg}"): self._warning_class = warning_class self._msg_fmt = msg_fmt - self._warning_context = None + self._warning_context: catch_warnings[list[WarningMessage]] | None = ( + None + ) self._warnings = [] - self._tests = [] + self._tests: list[Callable[[Warning], bool]] = [] def __enter__(self): self._warning_context = catch_warnings(record=True) @@ -1198,6 +1202,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): + assert self._warning_context is not None self._warning_context.__exit__(exc_type, exc_val, exc_tb) if not any(self._is_expected_warning(w) for w in self._warnings): fail(self.format_message()) @@ -1210,12 +1215,12 @@ def format_message(self): exc_name=self._warning_class.__name__, ) - def _is_expected_warning(self, warning): + def _is_expected_warning(self, warning) -> bool: if not issubclass(warning.category, self._warning_class): return False return all(test(warning) for test in self._tests) - def add_test(self, cb): + def add_test(self, cb: Callable[[Warning], bool]) -> None: """Add a test callback. This callback is called after determining that the right warning diff --git a/test_asserts.py b/test_asserts.py index d66d61b..52d505a 100644 --- a/test_asserts.py +++ b/test_asserts.py @@ -1269,8 +1269,9 @@ def extra_test(warning): def test_assert_warns__add_test_not_called(self): called = Box(False) - def extra_test(_): + def extra_test(_: Warning) -> bool: called.value = True + return False with assert_raises(AssertionError): with assert_warns(UserWarning) as context: @@ -1342,10 +1343,7 @@ def test_assert_warns_regex__not_issued__default_message(self): pass def test_assert_warns_regex__not_issued__custom_message(self): - expected = ( - "no ImportWarning matching 'abc' issued;ImportWarning;" - "ImportWarning;abc" - ) + expected = "no ImportWarning matching 'abc' issued;ImportWarning;ImportWarning;abc" with _assert_raises_assertion(expected): msg_fmt = "{msg};{exc_type.__name__};{exc_name};{pattern}" with assert_warns_regex(ImportWarning, r"abc", msg_fmt=msg_fmt):