diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 39e84a7ed841..ad9920b5062f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,3 +2,4 @@ /.github/* @Chia-Network/actions-reviewers /PRETTY_GOOD_PRACTICES.md @altendky @Chia-Network/required-reviewers /pylintrc @altendky @Chia-Network/required-reviewers +/tests/ether.py @altendky @Chia-Network/required-reviewers diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 41fd35c86db3..b94aa610b227 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -124,4 +124,4 @@ jobs: - name: Add benchmark results to workflow summary if: always() run: | - python -m tests.process_benchmarks --xml junit-data/benchmarks.xml --markdown --link-prefix ${{ github.event.repository.html_url }}/blob/${{ github.sha }}/ --link-line-separator \#L >> "$GITHUB_STEP_SUMMARY" + python -m tests.process_junit --type benchmark --xml junit-data/benchmarks.xml --markdown --link-prefix ${{ github.event.repository.html_url }}/blob/${{ github.sha }}/ --link-line-separator \#L >> "$GITHUB_STEP_SUMMARY" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ac4c6920c887..c334cb3bc034 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -142,14 +142,6 @@ jobs: path: junit-data/* if-no-files-found: error - - name: Publish JUnit results - if: always() - uses: actions/upload-artifact@v4 - with: - name: junit-results - path: junit-results/* - if-no-files-found: error - - name: Download Coverage uses: actions/download-artifact@v4 with: @@ -169,6 +161,20 @@ jobs: - uses: chia-network/actions/activate-venv@main + - name: Add time out assert results to workflow summary + if: always() + run: | + python -m tests.process_junit --limit 50 --type time_out_assert --xml junit-results/junit.xml --markdown --link-prefix ${{ github.event.repository.html_url }}/blob/${{ github.sha }}/ --link-line-separator \#L >> "$GITHUB_STEP_SUMMARY" + python -m tests.process_junit --type time_out_assert --xml junit-results/junit.xml --markdown --link-prefix ${{ github.event.repository.html_url }}/blob/${{ github.sha }}/ --link-line-separator \#L >> junit-results/time_out_assert.md + + - name: Publish JUnit results + if: always() + uses: actions/upload-artifact@v4 + with: + name: junit-results + path: junit-results/* + if-no-files-found: error + - name: Coverage Processing run: | coverage combine --rcfile=.coveragerc --data-file=coverage-reports/.coverage coverage-data/ diff --git a/chia/util/misc.py b/chia/util/misc.py index 7afa1dec3d3d..b5110916409c 100644 --- a/chia/util/misc.py +++ b/chia/util/misc.py @@ -8,6 +8,7 @@ import signal import sys from dataclasses import dataclass +from inspect import getframeinfo, stack from pathlib import Path from types import FrameType from typing import ( @@ -19,10 +20,12 @@ ContextManager, Dict, Generic, + Iterable, Iterator, List, Optional, Sequence, + Tuple, TypeVar, Union, final, @@ -421,3 +424,17 @@ def available_logical_cores() -> int: return count return len(psutil.Process().cpu_affinity()) + + +def caller_file_and_line(distance: int = 1, relative_to: Iterable[Path] = ()) -> Tuple[str, int]: + caller = getframeinfo(stack()[distance + 1][0]) + + caller_path = Path(caller.filename) + options: List[str] = [caller_path.as_posix()] + for path in relative_to: + try: + options.append(caller_path.relative_to(path).as_posix()) + except ValueError: + pass + + return min(options, key=len), caller.lineno diff --git a/tests/conftest.py b/tests/conftest.py 
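The new chia.util.misc.caller_file_and_line above replaces the test-only helper that previously lived in tests/util/misc.py and adds the relative_to shortening. A minimal usage sketch, assuming a checkout where chia is importable; the where_am_i helper is hypothetical:

import pathlib

import chia
from chia.util.misc import caller_file_and_line


def where_am_i() -> str:
    # With the default distance=1 the helper reports the caller of this
    # function, not this function itself: stack()[0] is the helper,
    # stack()[1] is this frame, stack()[2] is our caller.
    # relative_to trims the reported path when the caller lives under the
    # given roots; min(..., key=len) keeps the shortest candidate, so the
    # absolute posix path is only a fallback.
    path, line = caller_file_and_line(
        relative_to=(pathlib.Path(chia.__file__).parent.parent,),
    )
    return f"{path}:{line}"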
index f5166d3b7341..0faad3154036 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import dataclasses import datetime import functools +import json import math import multiprocessing import os @@ -20,7 +21,9 @@ # TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469 from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +import tests from chia.clvm.spend_sim import CostLogger from chia.consensus.constants import ConsensusConstants from chia.full_node.full_node import FullNode @@ -66,10 +69,11 @@ from chia.util.task_timing import start_task_instrumentation, stop_task_instrumentation from chia.wallet.wallet_node import WalletNode from chia.wallet.wallet_node_api import WalletNodeAPI +from tests import ether from tests.core.data_layer.util import ChiaRoot from tests.core.node_height import node_height_at_least from tests.simulation.test_simulation import test_constants_modified -from tests.util.misc import BenchmarkRunner, GcMode, RecordingWebServer, _AssertRuntime, measure_overhead +from tests.util.misc import BenchmarkRunner, GcMode, RecordingWebServer, TestId, _AssertRuntime, measure_overhead from tests.util.setup_nodes import ( OldSimulatorsAndWallets, SimulatorsAndWallets, @@ -91,6 +95,20 @@ from tests.util.setup_nodes import setup_farmer_multi_harvester +@pytest.fixture(name="ether_setup", autouse=True) +def ether_setup_fixture(request: SubRequest, record_property: Callable[[str, object], None]) -> Iterator[None]: + with MonkeyPatch.context() as monkeypatch_context: + monkeypatch_context.setattr(ether, "record_property", record_property) + monkeypatch_context.setattr(ether, "test_id", TestId.create(node=request.node)) + yield + + +@pytest.fixture(autouse=True) +def ether_test_id_property_fixture(ether_setup: None, record_property: Callable[[str, object], None]) -> None: + assert ether.test_id is not None, "ether.test_id is None, did you forget to use the ether_setup fixture?" 
+ record_property("test_id", json.dumps(ether.test_id.marshal(), ensure_ascii=True, sort_keys=True)) + + def make_old_setup_simulators_and_wallets(new: SimulatorsAndWallets) -> OldSimulatorsAndWallets: return ( [simulator.peer_api for simulator in new.simulators], @@ -131,16 +149,12 @@ def benchmark_runner_overhead_fixture() -> float: @pytest.fixture(name="benchmark_runner") def benchmark_runner_fixture( - request: SubRequest, benchmark_runner_overhead: float, - record_property: Callable[[str, object], None], benchmark_repeat: int, ) -> BenchmarkRunner: - label = request.node.name return BenchmarkRunner( - label=label, + test_id=ether.test_id, overhead=benchmark_runner_overhead, - record_property=record_property, ) @@ -434,6 +448,13 @@ def pytest_addoption(parser: pytest.Parser): type=int, help=f"The number of times to run each benchmark, default {default_repeats}.", ) + group.addoption( + "--time-out-assert-repeats", + action="store", + default=default_repeats, + type=int, + help=f"The number of times to run each test with time out asserts, default {default_repeats}.", + ) def pytest_configure(config): @@ -459,6 +480,22 @@ def benchmark_repeat_fixture() -> int: globals()[benchmark_repeat_fixture.__name__] = benchmark_repeat_fixture + time_out_assert_repeats = config.getoption("--time-out-assert-repeats") + if time_out_assert_repeats != 1: + + @pytest.fixture( + name="time_out_assert_repeat", + autouse=True, + params=[ + pytest.param(repeat, id=f"time_out_assert_repeat{repeat:03d}") + for repeat in range(time_out_assert_repeats) + ], + ) + def time_out_assert_repeat_fixture(request: SubRequest) -> int: + return request.param + + globals()[time_out_assert_repeat_fixture.__name__] = time_out_assert_repeat_fixture + def pytest_collection_modifyitems(session, config: pytest.Config, items: List[pytest.Function]): # https://github.com/pytest-dev/pytest/issues/3730#issuecomment-567142496 diff --git a/tests/ether.py b/tests/ether.py new file mode 100644 index 000000000000..f1ecb3efb415 --- /dev/null +++ b/tests/ether.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Optional + +if TYPE_CHECKING: + from tests.util.misc import TestId + +# NOTE: Do not just put any useful thing here. This is specifically for making +# fixture values globally available during tests. In _most_ cases fixtures +# should be directly requested using normal mechanisms. Very little should +# be put here. + +# NOTE: When using this module do not import the attributes directly. Rather, import +# something like `from tests import ether`. Importing attributes directly will +# result in you likely getting the default `None` values since they are not +# populated until tests are running. + +record_property: Optional[Callable[[str, object], None]] = None +test_id: Optional[TestId] = None diff --git a/tests/process_benchmarks.py b/tests/process_benchmarks.py deleted file mode 100644 index 6a570b94b583..000000000000 --- a/tests/process_benchmarks.py +++ /dev/null @@ -1,240 +0,0 @@ -from __future__ import annotations - -import json -import random -import re -from collections import defaultdict -from dataclasses import dataclass, field -from pathlib import Path -from statistics import StatisticsError, mean, stdev -from typing import Any, Dict, List, Set, TextIO, Tuple, final - -import click -import lxml.etree - - -@final -@dataclass(frozen=True, order=True) -class Result: - file_path: Path - test_path: Tuple[str, ...] 
- label: str - line: int = field(compare=False) - durations: Tuple[float, ...] = field(compare=False) - limit: float = field(compare=False) - - def marshal(self) -> Dict[str, Any]: - return { - "file_path": self.file_path.as_posix(), - "test_path": self.test_path, - "label": self.label, - "duration": { - "all": self.durations, - "min": min(self.durations), - "max": max(self.durations), - "mean": mean(self.durations), - }, - } - - def link(self, prefix: str, line_separator: str) -> str: - return f"{prefix}{self.file_path.as_posix()}{line_separator}{self.line}" - - -def sub(matchobj: re.Match[str]) -> str: - result = "" - - if matchobj.group("start") == "[": - result += "[" - - if matchobj.group("start") == matchobj.group("end") == "-": - result += "-" - - if matchobj.group("end") == "]": - result += "]" - - return result - - -@click.command(context_settings={"help_option_names": ["-h", "--help"]}) -@click.option( - "--xml", - "xml_file", - required=True, - type=click.File(), - help="The benchmarks JUnit XML results file", -) -@click.option( - "--link-prefix", - default="", - help="Prefix for output links such as for web links instead of IDE links", - show_default=True, -) -@click.option( - "--link-line-separator", - default=":", - help="The separator between the path and the line number, such as : for local links and #L on GitHub", - show_default=True, -) -@click.option( - "--output", - default="-", - type=click.File(mode="w", encoding="utf-8", lazy=True, atomic=True), - help="Output file, - for stdout", - show_default=True, -) -# TODO: anything but this pattern for output types -@click.option( - "--markdown/--no-markdown", - help="Use markdown as output format", - show_default=True, -) -@click.option( - "--percent-margin", - default=15, - type=int, - help="Highlight results with maximums within this percent of the limit", - show_default=True, -) -@click.option( - "--randomoji/--determimoji", - help="🍿", - show_default=True, -) -def main( - xml_file: TextIO, - link_prefix: str, - link_line_separator: str, - output: TextIO, - markdown: bool, - percent_margin: int, - randomoji: bool, -) -> None: - tree = lxml.etree.parse(xml_file) - root = tree.getroot() - benchmarks = root.find("testsuite[@name='benchmarks']") - - # raw_durations: defaultdict[Tuple[str, ...], List[Result]] = defaultdict(list) - - cases_by_test_path: defaultdict[Tuple[str, ...], List[lxml.etree.Element]] = defaultdict(list) - for case in benchmarks.findall("testcase"): - raw_name = case.attrib["name"] - name = re.sub(r"(?P[-\[])benchmark_repeat\d{3}(?P[-\])])", sub, raw_name) - # TODO: seems to duplicate the class and function name, though not the parametrizations - test_path = ( - *case.attrib["classname"].split("."), - name, - ) - cases_by_test_path[test_path].append(case) - - results: List[Result] = [] - for test_path, cases in cases_by_test_path.items(): - labels: Set[str] = set() - for case in cases: - properties = case.find("properties") - labels.update(property.attrib["name"].partition(":")[2] for property in properties) - - for label in labels: - query = "properties/property[@name='{property}:{label}']" - - durations = [ - float(property.attrib["value"]) - for case in cases - for property in case.xpath(query.format(label=label, property="duration")) - ] - - a_case = cases[0] - - file_path: Path - [file_path] = [ - Path(property.attrib["value"]) for property in a_case.xpath(query.format(label=label, property="path")) - ] - - line: int - [line] = [ - int(property.attrib["value"]) for property in 
a_case.xpath(query.format(label=label, property="line")) - ] - - limit: float - [limit] = [ - float(property.attrib["value"]) - for property in a_case.xpath(query.format(label=label, property="limit")) - ] - - results.append( - Result( - file_path=file_path, - test_path=test_path, - line=line, - label=label, - durations=tuple(durations), - limit=limit, - ) - ) - - if not markdown: - for result in results: - link = result.link(prefix=link_prefix, line_separator=link_line_separator) - dumped = json.dumps(result.marshal()) - output.write(f"{link} {dumped}\n") - else: - output.write("| Test | 🍿 | Mean | Max | 3σ | Limit | Percent |\n") - output.write("| --- | --- | --- | --- | --- | --- | --- |\n") - for result in sorted(results): - link_url = result.link(prefix=link_prefix, line_separator=link_line_separator) - - mean_str = "-" - three_sigma_str = "-" - if len(result.durations) > 1: - durations_mean = mean(result.durations) - mean_str = f"{durations_mean:.3f} s" - - try: - three_sigma_str = f"{durations_mean + 3 * stdev(result.durations):.3f} s" - except StatisticsError: - pass - - durations_max = max(result.durations) - max_str = f"{durations_max:.3f} s" - - limit_str = f"{result.limit:.3f} s" - - percent = 100 * durations_max / result.limit - if percent >= 100: - # intentionally biasing towards 🍄 - choices = "🍄🍄🍎🍅" # 🌶️🍉🍒🍓 - elif percent >= (100 - percent_margin): - choices = "🍋🍌" # 🍍🌽 - else: - choices = "🫛🍈🍏🍐🥝🥒🥬🥦" - - marker: str - if randomoji: - marker = random.choice(choices) - else: - marker = choices[0] - - percent_str = f"{percent:.0f} %" - - test_path_str = ".".join(result.test_path[1:]) - - test_link_text: str - if result.label == "": - test_link_text = f"`{test_path_str}`" - else: - test_link_text = f"`{test_path_str}` - {result.label}" - - output.write( - f"| [{test_link_text}]({link_url})" - + f" | {marker}" - + f" | {mean_str}" - + f" | {max_str}" - + f" | {three_sigma_str}" - + f" | {limit_str}" - + f" | {percent_str}" - + " |\n" - ) - - -if __name__ == "__main__": - # pylint: disable = no-value-for-parameter - main() diff --git a/tests/process_junit.py b/tests/process_junit.py new file mode 100644 index 000000000000..13b5fc6c558c --- /dev/null +++ b/tests/process_junit.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +import dataclasses +import json +import random +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from statistics import StatisticsError, mean, stdev +from typing import Any, Dict, List, Optional, TextIO, Tuple, Type, final + +import click +import lxml.etree + +from tests.util.misc import BenchmarkData, DataTypeProtocol, TestId +from tests.util.time_out_assert import TimeOutAssertData + +supported_data_types: List[Type[DataTypeProtocol]] = [TimeOutAssertData, BenchmarkData] +supported_data_types_by_tag: Dict[str, Type[DataTypeProtocol]] = {cls.tag: cls for cls in supported_data_types} + + +@final +@dataclass(frozen=True, order=True) +class Result: + file_path: Path + test_path: Tuple[str, ...] + ids: Tuple[str, ...] + label: str + line: int = field(compare=False) + durations: Tuple[float, ...] 
= field(compare=False) + limit: float = field(compare=False) + + def marshal(self) -> Dict[str, Any]: + return { + "file_path": self.file_path.as_posix(), + "test_path": self.test_path, + "label": self.label, + "duration": { + "all": self.durations, + "min": min(self.durations), + "max": max(self.durations), + "mean": mean(self.durations), + }, + } + + def link(self, prefix: str, line_separator: str) -> str: + return f"{prefix}{self.file_path.as_posix()}{line_separator}{self.line}" + + +@final +@dataclasses.dataclass(frozen=True) +class EventId: + test_id: TestId + tag: str + line: int + path: Path + label: str + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "--xml", + "xml_file", + required=True, + type=click.File(), + help="The benchmarks JUnit XML results file", +) +@click.option( + "--link-prefix", + default="", + help="Prefix for output links such as for web links instead of IDE links", + show_default=True, +) +@click.option( + "--link-line-separator", + default=":", + help="The separator between the path and the line number, such as : for local links and #L on GitHub", + show_default=True, +) +@click.option( + "--output", + default="-", + type=click.File(mode="w", encoding="utf-8", lazy=True, atomic=True), + help="Output file, - for stdout", + show_default=True, +) +# TODO: anything but this pattern for output types +@click.option( + "--markdown/--no-markdown", + help="Use markdown as output format", + show_default=True, +) +@click.option( + "--percent-margin", + default=15, + type=int, + help="Highlight results with maximums within this percent of the limit", + show_default=True, +) +@click.option( + "--randomoji/--determimoji", + help="🍿", + show_default=True, +) +# TODO: subcommands? +@click.option( + "--type", + "tag", + type=click.Choice([cls.tag for cls in supported_data_types]), + help="The type of data to process", + required=True, + show_default=True, +) +@click.option( + "--limit", + "result_count_limit", + type=int, + help="Limit the number of results to output.", +) +def main( + xml_file: TextIO, + link_prefix: str, + link_line_separator: str, + output: TextIO, + markdown: bool, + percent_margin: int, + randomoji: bool, + tag: str, + result_count_limit: Optional[int], +) -> None: + data_type = supported_data_types_by_tag[tag] + + tree = lxml.etree.parse(xml_file) + root = tree.getroot() + + cases_by_test_id: defaultdict[TestId, List[lxml.etree.Element]] = defaultdict(list) + for suite in root.findall("testsuite"): + for case in suite.findall("testcase"): + if case.find("skipped") is not None: + continue + test_id_property = case.find("properties/property[@name='test_id']") + test_id = TestId.unmarshal(json.loads(test_id_property.attrib["value"])) + test_id = dataclasses.replace( + test_id, ids=tuple(id for id in test_id.ids if not id.startswith(f"{data_type.tag}_repeat")) + ) + cases_by_test_id[test_id].append(case) + + data_by_event_id: defaultdict[EventId, List[DataTypeProtocol]] = defaultdict(list) + for test_id, cases in cases_by_test_id.items(): + for case in cases: + for property in case.findall(f"properties/property[@name='{tag}']"): + tag = property.attrib["name"] + data = supported_data_types_by_tag[tag].unmarshal(json.loads(property.attrib["value"])) + event_id = EventId(test_id=test_id, tag=tag, line=data.line, path=data.path, label=data.label) + data_by_event_id[event_id].append(data) + + results: List[Result] = [] + for event_id, datas in data_by_event_id.items(): + [limit] = {data.limit for data in datas} + 
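main() above collapses the per-repeat parametrizations into one grouping key by stripping the tag_repeat### id from each TestId, and the single-element set destructuring then asserts that every repeat of a call site reported the same limit. A rough standalone sketch of those two steps; the id values are made up:

# Ids as they might appear on a repeated benchmark case,
# e.g. from test_sync[two_nodes-benchmark_repeat002].
ids = ("two_nodes", "benchmark_repeat002")
tag = "benchmark"

# Drop the repeat id so all repeats of the same test collapse onto one key.
grouping_ids = tuple(id for id in ids if not id.startswith(f"{tag}_repeat"))
assert grouping_ids == ("two_nodes",)

# Every repeat of one call site is expected to carry the same limit; a second
# distinct value would make this destructuring raise ValueError.
[limit] = {1.0, 1.0, 1.0}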
results.append( + Result( + file_path=event_id.path, + test_path=event_id.test_id.test_path, + ids=event_id.test_id.ids, + line=event_id.line, + durations=tuple(data.duration for data in datas), + limit=limit, + label=event_id.label, + ) + ) + + if result_count_limit is not None: + results = sorted(results, key=lambda result: max(result.durations) / result.limit, reverse=True) + results = results[:result_count_limit] + + handlers = { + BenchmarkData.tag: output_benchmark, + TimeOutAssertData.tag: output_time_out_assert, + } + handler = handlers[data_type.tag] + handler( + link_line_separator=link_line_separator, + link_prefix=link_prefix, + markdown=markdown, + output=output, + percent_margin=percent_margin, + randomoji=randomoji, + results=results, + ) + + +def output_benchmark( + link_line_separator: str, + link_prefix: str, + markdown: bool, + output: TextIO, + percent_margin: int, + randomoji: bool, + results: List[Result], +) -> None: + if not markdown: + for result in sorted(results): + link = result.link(prefix=link_prefix, line_separator=link_line_separator) + dumped = json.dumps(result.marshal()) + output.write(f"{link} {dumped}\n") + else: + output.write("# Benchmark Metrics\n\n") + + output.write("| Test | 🍿 | Mean | Max | 3σ | Limit | Percent |\n") + output.write("| --- | --- | --- | --- | --- | --- | --- |\n") + for result in sorted(results): + link_url = result.link(prefix=link_prefix, line_separator=link_line_separator) + + mean_str = "-" + three_sigma_str = "-" + if len(result.durations) > 1: + durations_mean = mean(result.durations) + mean_str = f"{durations_mean:.3f} s" + + try: + three_sigma_str = f"{durations_mean + 3 * stdev(result.durations):.3f} s" + except StatisticsError: + pass + + durations_max = max(result.durations) + max_str = f"{durations_max:.3f} s" + + limit_str = f"{result.limit:.3f} s" + + percent = 100 * durations_max / result.limit + if percent >= 100: + # intentionally biasing towards 🍄 + choices = "🍄🍄🍎🍅" # 🌶️🍉🍒🍓 + elif percent >= (100 - percent_margin): + choices = "🍋🍌" # 🍍🌽 + else: + choices = "🫛🍈🍏🍐🥝🥒🥬🥦" + + marker: str + if randomoji: + marker = random.choice(choices) + else: + marker = choices[0] + + percent_str = f"{percent:.0f} %" + + test_path_str = ".".join(result.test_path[1:]) + if len(result.ids) > 0: + test_path_str += f"[{'-'.join(result.ids)}]" + + test_link_text: str + if result.label == "": + test_link_text = f"`{test_path_str}`" + else: + test_link_text = f"`{test_path_str}` - {result.label}" + + output.write( + f"| [{test_link_text}]({link_url})" + + f" | {marker}" + + f" | {mean_str}" + + f" | {max_str}" + + f" | {three_sigma_str}" + + f" | {limit_str}" + + f" | {percent_str}" + + " |\n" + ) + + +def output_time_out_assert( + link_line_separator: str, + link_prefix: str, + markdown: bool, + output: TextIO, + percent_margin: int, + randomoji: bool, + results: List[Result], +) -> None: + if not markdown: + for result in sorted(results): + link = result.link(prefix=link_prefix, line_separator=link_line_separator) + dumped = json.dumps(result.marshal()) + output.write(f"{link} {dumped}\n") + else: + output.write("# Time Out Assert Metrics\n\n") + + output.write("| Test | 🍿 | Mean | Max | 3σ | Limit | Percent |\n") + output.write("| --- | --- | --- | --- | --- | --- | --- |\n") + for result in sorted(results): + link_url = result.link(prefix=link_prefix, line_separator=link_line_separator) + + mean_str = "-" + three_sigma_str = "-" + if len(result.durations) > 1: + durations_mean = mean(result.durations) + mean_str = 
f"{durations_mean:.3f} s" + + try: + three_sigma_str = f"{durations_mean + 3 * stdev(result.durations):.3f} s" + except StatisticsError: + pass + + durations_max = max(result.durations) + max_str = f"{durations_max:.3f} s" + + limit_str = f"{result.limit:.3f} s" + + percent = 100 * durations_max / result.limit + if percent >= 100: + # intentionally biasing towards 🍄 + choices = "🍄🍄🍎🍅" # 🌶️🍉🍒🍓 + elif percent >= (100 - percent_margin): + choices = "🍋🍌" # 🍍🌽 + else: + choices = "🫛🍈🍏🍐🥝🥒🥬🥦" + + marker: str + if randomoji: + marker = random.choice(choices) + else: + marker = choices[0] + + percent_str = f"{percent:.0f} %" + + test_path_str = ".".join(result.test_path[1:]) + if len(result.ids) > 0: + test_path_str += f"[{'-'.join(result.ids)}]" + + test_link_text: str + if result.label == "": + # TODO: but could be in different files too + test_link_text = f"`{test_path_str}` - {result.line}" + else: + test_link_text = f"`{test_path_str}` - {result.label}" + + output.write( + f"| [{test_link_text}]({link_url})" + + f" | {marker}" + + f" | {mean_str}" + + f" | {max_str}" + + f" | {three_sigma_str}" + + f" | {limit_str}" + + f" | {percent_str}" + + " |\n" + ) + + +if __name__ == "__main__": + # pylint: disable = no-value-for-parameter + main() diff --git a/tests/util/misc.py b/tests/util/misc.py index 06854c31c50b..0156ad29409d 100644 --- a/tests/util/misc.py +++ b/tests/util/misc.py @@ -5,6 +5,7 @@ import enum import functools import gc +import json import logging import os import pathlib @@ -13,28 +14,52 @@ import sys from concurrent.futures import Future from dataclasses import dataclass, field -from inspect import getframeinfo, stack +from pathlib import Path from statistics import mean from textwrap import dedent from time import thread_time from types import TracebackType -from typing import Any, Awaitable, Callable, Collection, Dict, Iterator, List, Optional, TextIO, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + ClassVar, + Collection, + Dict, + Iterator, + List, + Optional, + Protocol, + TextIO, + Tuple, + Type, + TypeVar, + Union, + cast, + final, +) import aiohttp import pytest + +# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469 +from _pytest.nodes import Node from aiohttp import web from chia_rs import Coin -from typing_extensions import Protocol, final import chia +import tests from chia.full_node.mempool import Mempool from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.condition_opcodes import ConditionOpcode from chia.util.hash import std_hash from chia.util.ints import uint16, uint32, uint64 +from chia.util.misc import caller_file_and_line from chia.util.network import WebServer from chia.wallet.util.compute_hints import HintedCoin from chia.wallet.wallet_node import WalletNode +from tests import ether from tests.core.data_layer.util import ChiaRoot @@ -70,11 +95,6 @@ def manage_gc(mode: GcMode) -> Iterator[None]: gc.disable() -def caller_file_and_line(distance: int = 1) -> Tuple[str, int]: - caller = getframeinfo(stack()[distance + 1][0]) - return caller.filename, caller.lineno - - @dataclasses.dataclass(frozen=True) class RuntimeResults: start: float @@ -180,7 +200,12 @@ def measure_runtime( overhead: Optional[float] = None, print_results: bool = True, ) -> Iterator[Future[RuntimeResults]]: - entry_file, entry_line = caller_file_and_line() + entry_file, entry_line = caller_file_and_line( + relative_to=( + pathlib.Path(chia.__file__).parent.parent, + 
pathlib.Path(tests.__file__).parent.parent, + ) + ) results_future: Future[RuntimeResults] = Future() @@ -210,6 +235,43 @@ def measure_runtime( print(results.block(label=label)) +@final +@dataclasses.dataclass(frozen=True) +class BenchmarkData: + if TYPE_CHECKING: + _protocol_check: ClassVar[DataTypeProtocol] = cast("BenchmarkData", None) + + tag: ClassVar[str] = "benchmark" + + duration: float + path: pathlib.Path + line: int + limit: float + + label: str + + __match_args__: ClassVar[Tuple[str, ...]] = () + + @classmethod + def unmarshal(cls, marshalled: Dict[str, Any]) -> BenchmarkData: + return cls( + duration=marshalled["duration"], + path=pathlib.Path(marshalled["path"]), + line=int(marshalled["line"]), + limit=marshalled["limit"], + label=marshalled["label"], + ) + + def marshal(self) -> Dict[str, Any]: + return { + "duration": self.duration, + "path": self.path.as_posix(), + "line": self.line, + "limit": self.limit, + "label": self.label, + } + + @final @dataclasses.dataclass class _AssertRuntime: @@ -236,6 +298,7 @@ class _AssertRuntime: # https://github.com/pytest-dev/pytest/issues/2057 seconds: float + # TODO: Optional? label: str = "" clock: Callable[[], float] = thread_time gc_mode: GcMode = GcMode.disable @@ -247,10 +310,14 @@ class _AssertRuntime: runtime_manager: Optional[contextlib.AbstractContextManager[Future[RuntimeResults]]] = None runtime_results_callable: Optional[Future[RuntimeResults]] = None enable_assertion: bool = True - record_property: Optional[Callable[[str, object], None]] = None def __enter__(self) -> Future[AssertRuntimeResults]: - self.entry_file, self.entry_line = caller_file_and_line() + self.entry_file, self.entry_line = caller_file_and_line( + relative_to=( + pathlib.Path(chia.__file__).parent.parent, + pathlib.Path(tests.__file__).parent.parent, + ) + ) self.runtime_manager = measure_runtime( clock=self.clock, gc_mode=self.gc_mode, overhead=self.overhead, print_results=False @@ -290,16 +357,19 @@ def __exit__( if self.print: print(results.block(label=self.label)) - if self.record_property is not None: - self.record_property(f"duration:{self.label}", results.duration) - - relative_path_str = ( - pathlib.Path(results.entry_file).relative_to(pathlib.Path(chia.__file__).parent.parent).as_posix() + if ether.record_property is not None: + data = BenchmarkData( + duration=results.duration, + path=pathlib.Path(self.entry_file), + line=self.entry_line, + limit=self.seconds, + label=self.label, ) - self.record_property(f"path:{self.label}", relative_path_str) - self.record_property(f"line:{self.label}", results.entry_line) - self.record_property(f"limit:{self.label}", self.seconds) + ether.record_property( # pylint: disable=E1102 + data.tag, + json.dumps(data.marshal(), ensure_ascii=True, sort_keys=True), + ) if exc_type is None and self.enable_assertion: __tracebackhide__ = True @@ -310,15 +380,13 @@ def __exit__( @dataclasses.dataclass class BenchmarkRunner: enable_assertion: bool = True - label: Optional[str] = None + test_id: Optional[TestId] = None overhead: Optional[float] = None - record_property: Optional[Callable[[str, object], None]] = None @functools.wraps(_AssertRuntime) def assert_runtime(self, *args: Any, **kwargs: Any) -> _AssertRuntime: kwargs.setdefault("enable_assertion", self.enable_assertion) kwargs.setdefault("overhead", self.overhead) - kwargs.setdefault("record_property", self.record_property) return _AssertRuntime(*args, **kwargs) @@ -484,3 +552,83 @@ async def handler(self, request: web.Request) -> web.Response: async def 
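With the change above, _AssertRuntime no longer needs record_property injected; it emits a marshalled BenchmarkData through tests.ether whenever a test is running. A minimal sketch of the round trip that process_junit.py later performs on the stored JUnit property value; the path and numbers are illustrative:

import json
import pathlib

from tests.util.misc import BenchmarkData

original = BenchmarkData(
    duration=0.123,
    path=pathlib.Path("tests/core/test_example.py"),  # hypothetical path
    line=42,
    limit=1.0,
    label="sync",
)

# This string is what gets stored as the value of the "benchmark" property.
stored = json.dumps(original.marshal(), ensure_ascii=True, sort_keys=True)

# process_junit.py reverses it with the matching unmarshal classmethod.
assert BenchmarkData.unmarshal(json.loads(stored)) == original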
await_closed(self) -> None: self.web_server.close() await self.web_server.await_closed() + + +@final +@dataclasses.dataclass(frozen=True) +class TestId: + platform: str + test_path: Tuple[str, ...] + ids: Tuple[str, ...] + + @classmethod + def create(cls, node: Node, platform: str = sys.platform) -> TestId: + test_path: List[str] = [] + temp_node = node + while True: + name: str + if isinstance(temp_node, pytest.Function): + name = temp_node.originalname + elif isinstance(temp_node, pytest.Package): + # must check before pytest.Module since Package is a subclass + name = temp_node.name + elif isinstance(temp_node, pytest.Module): + name = temp_node.name[:-3] + else: + name = temp_node.name + test_path.insert(0, name) + if isinstance(temp_node.parent, pytest.Session) or temp_node.parent is None: + break + temp_node = temp_node.parent + + # TODO: can we avoid parsing the id's etc from the node name? + test_name, delimiter, rest = node.name.partition("[") + ids: Tuple[str, ...] + if delimiter == "": + ids = () + else: + ids = tuple(rest.rstrip("]").split("-")) + + return cls( + platform=platform, + test_path=tuple(test_path), + ids=ids, + ) + + @classmethod + def unmarshal(cls, marshalled: Dict[str, Any]) -> TestId: + return cls( + platform=marshalled["platform"], + test_path=tuple(marshalled["test_path"]), + ids=tuple(marshalled["ids"]), + ) + + def marshal(self) -> Dict[str, Any]: + return { + "platform": self.platform, + "test_path": self.test_path, + "ids": self.ids, + } + + +T = TypeVar("T") + + +@dataclasses.dataclass(frozen=True) +class DataTypeProtocol(Protocol): + tag: ClassVar[str] + + line: int + path: Path + label: str + duration: float + limit: float + + __match_args__: ClassVar[Tuple[str, ...]] = () + + @classmethod + def unmarshal(cls: Type[T], marshalled: Dict[str, Any]) -> T: + ... + + def marshal(self) -> Dict[str, Any]: + ... 
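TestId.create above walks up the pytest node hierarchy to build a stable dotted test path and then parses the parametrization ids back out of the node name (see the TODO about avoiding that parsing). A rough illustration of the id-parsing step on its own; the example name is invented:

# For a node named "test_sync[two_nodes-benchmark_repeat001]" create() splits
# off the bracketed part and dash-splits it into individual ids.
name = "test_sync[two_nodes-benchmark_repeat001]"

test_name, delimiter, rest = name.partition("[")
ids = () if delimiter == "" else tuple(rest.rstrip("]").split("-"))

assert test_name == "test_sync"
assert ids == ("two_nodes", "benchmark_repeat001")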
diff --git a/tests/util/time_out_assert.py b/tests/util/time_out_assert.py index 27515a767d74..4252d35756fc 100644 --- a/tests/util/time_out_assert.py +++ b/tests/util/time_out_assert.py @@ -1,39 +1,130 @@ from __future__ import annotations import asyncio +import dataclasses +import json import logging +import pathlib import time -from typing import Callable +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Tuple, cast, final +import chia +import tests from chia.protocols.protocol_message_types import ProtocolMessageTypes +from chia.util.misc import caller_file_and_line from chia.util.timing import adjusted_timeout +from tests import ether +from tests.util.misc import DataTypeProtocol log = logging.getLogger(__name__) -async def time_out_assert_custom_interval(timeout: float, interval, function, value=True, *args, **kwargs): +@final +@dataclasses.dataclass(frozen=True) +class TimeOutAssertData: + if TYPE_CHECKING: + _protocol_check: ClassVar[DataTypeProtocol] = cast("TimeOutAssertData", None) + + tag: ClassVar[str] = "time_out_assert" + + duration: float + path: pathlib.Path + line: int + limit: float + timed_out: bool + + label: str = "" + + __match_args__: ClassVar[Tuple[str, ...]] = () + + @classmethod + def unmarshal(cls, marshalled: Dict[str, Any]) -> TimeOutAssertData: + return cls( + duration=marshalled["duration"], + path=pathlib.Path(marshalled["path"]), + line=int(marshalled["line"]), + limit=marshalled["limit"], + timed_out=marshalled["timed_out"], + ) + + def marshal(self) -> Dict[str, Any]: + return { + "duration": self.duration, + "path": self.path.as_posix(), + "line": self.line, + "limit": self.limit, + "timed_out": self.timed_out, + } + + +async def time_out_assert_custom_interval( + timeout: float, interval, function, value=True, *args, stack_distance=0, **kwargs +): __tracebackhide__ = True + entry_file, entry_line = caller_file_and_line( + distance=stack_distance + 1, + relative_to=( + pathlib.Path(chia.__file__).parent.parent, + pathlib.Path(tests.__file__).parent.parent, + ), + ) + timeout = adjusted_timeout(timeout=timeout) - start = time.time() - while time.time() - start < timeout: - if asyncio.iscoroutinefunction(function): - f_res = await function(*args, **kwargs) - else: - f_res = function(*args, **kwargs) - if value == f_res: - return None - await asyncio.sleep(interval) - assert False, f"Timed assertion timed out after {timeout} seconds: expected {value!r}, got {f_res!r}" + start = time.monotonic() + duration = 0.0 + timed_out = False + try: + while True: + if asyncio.iscoroutinefunction(function): + f_res = await function(*args, **kwargs) + else: + f_res = function(*args, **kwargs) + + if value == f_res: + return None + + now = time.monotonic() + duration = now - start + + if duration > timeout: + timed_out = True + assert False, f"Timed assertion timed out after {timeout} seconds: expected {value!r}, got {f_res!r}" + + await asyncio.sleep(min(interval, timeout - duration)) + finally: + if ether.record_property is not None: + data = TimeOutAssertData( + duration=duration, + path=pathlib.Path(entry_file), + line=entry_line, + limit=timeout, + timed_out=timed_out, + ) + + ether.record_property( # pylint: disable=E1102 + data.tag, + json.dumps(data.marshal(), ensure_ascii=True, sort_keys=True), + ) async def time_out_assert(timeout: int, function, value=True, *args, **kwargs): __tracebackhide__ = True - await time_out_assert_custom_interval(timeout, 0.05, function, value, *args, **kwargs) + await time_out_assert_custom_interval( + timeout, + 
0.05, + function, + value, + *args, + **kwargs, + stack_distance=1, + ) async def time_out_assert_not_none(timeout: float, function, *args, **kwargs): + # TODO: rework to leverage time_out_assert_custom_interval() such as by allowing + # value to be a callable __tracebackhide__ = True timeout = adjusted_timeout(timeout=timeout)
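The reworked time_out_assert_custom_interval above now uses a monotonic clock and records a TimeOutAssertData event in its finally block, so a duration is reported whether the condition was met or the assertion timed out. A hedged usage sketch in test form; the node fixture, the node_synced predicate, the 30 second limit, and the anyio marker are illustrative assumptions:

import pytest

from tests.util.time_out_assert import time_out_assert


@pytest.mark.anyio
async def test_node_eventually_syncs(node) -> None:
    async def node_synced() -> bool:
        return await node.is_synced()

    # On success the elapsed time is recorded against this call site's file
    # and line; on timeout the same record is emitted with timed_out=True
    # before the assertion failure propagates.
    await time_out_assert(30, node_synced, True)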