Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch transaction queue completion to a new ValuedEvent #17305

Merged
merged 2 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chia/full_node/full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ async def _handle_one_transaction(self, entry: TransactionQueueEntry) -> None:
peer = entry.peer
try:
inc_status, err = await self.add_transaction(entry.transaction, entry.spend_name, peer, entry.test)
entry.done.set_result((inc_status, err))
entry.done.set((inc_status, err))
except asyncio.CancelledError:
error_stack = traceback.format_exc()
self.log.debug(f"Cancelling _handle_one_transaction, closing: {error_stack}")
Expand Down
2 changes: 1 addition & 1 deletion chia/full_node/full_node_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ async def send_transaction(
await self.full_node.transaction_queue.put(queue_entry, peer_id=None, high_priority=True)
try:
with anyio.fail_after(delay=45):
status, error = await queue_entry.done
status, error = await queue_entry.done.wait()
except TimeoutError:
response = wallet_protocol.TransactionAck(spend_name, uint8(MempoolInclusionStatus.PENDING), None)
else:
Expand Down
6 changes: 3 additions & 3 deletions chia/types/transaction_queue_entry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Optional, Tuple

Expand All @@ -9,6 +8,7 @@
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.types.spend_bundle import SpendBundle
from chia.util.errors import Err
from chia.util.misc import ValuedEvent


@dataclass(frozen=True)
Expand All @@ -22,7 +22,7 @@ class TransactionQueueEntry:
spend_name: bytes32
peer: Optional[WSChiaConnection] = field(compare=False)
test: bool = field(compare=False)
done: asyncio.Future[Tuple[MempoolInclusionStatus, Optional[Err]]] = field(
default_factory=asyncio.Future,
done: ValuedEvent[Tuple[MempoolInclusionStatus, Optional[Err]]] = field(
default_factory=ValuedEvent,
compare=False,
)
25 changes: 25 additions & 0 deletions chia/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Any,
AsyncContextManager,
AsyncIterator,
ClassVar,
Collection,
ContextManager,
Dict,
Expand Down Expand Up @@ -374,3 +375,27 @@ async def split_async_manager(manager: AsyncContextManager[object], object: T) -
yield split
finally:
await split.exit(if_needed=True)


class ValuedEventSentinel:
pass


@dataclasses.dataclass
class ValuedEvent(Generic[T]):
_value_sentinel: ClassVar[ValuedEventSentinel] = ValuedEventSentinel()

_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event)
_value: Union[ValuedEventSentinel, T] = _value_sentinel

def set(self, value: T) -> None:
if not isinstance(self._value, ValuedEventSentinel):
raise Exception("Value already set")
self._value = value
self._event.set()

async def wait(self) -> T:
await self._event.wait()
if isinstance(self._value, ValuedEventSentinel):
raise Exception("Value not set despite event being set")
return self._value
107 changes: 106 additions & 1 deletion tests/util/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
from __future__ import annotations

import contextlib
from typing import AsyncIterator, Iterator, List
from typing import AsyncIterator, Iterator, List, Optional, TypeVar

import anyio
import pytest

from chia.util.errors import InvalidPathError
from chia.util.misc import (
SplitAsyncManager,
SplitManager,
ValuedEvent,
format_bytes,
format_minutes,
split_async_manager,
split_manager,
to_batches,
validate_directory_writable,
)
from chia.util.timing import adjusted_timeout, backoff_times

T = TypeVar("T")


class TestMisc:
Expand Down Expand Up @@ -306,3 +311,103 @@ async def test_split_async_manager_raises_on_exit_without_entry() -> None:

with pytest.raises(Exception, match="^not yet entered$"):
await split.exit()


async def wait_for_valued_event_waiters(
event: ValuedEvent[T],
count: int,
timeout: float = 10,
) -> None:
with anyio.fail_after(delay=adjusted_timeout(timeout)):
for delay in backoff_times():
# ignoring the type since i'm hacking into the private attribute
# hopefully this is ok for testing and if it becomes invalid we
# will end up with an exception and can adjust then
if len(event._event._waiters) >= count: # type: ignore[attr-defined]
return
await anyio.sleep(delay)


@pytest.mark.anyio
async def test_valued_event_wait_already_set() -> None:
valued_event = ValuedEvent[int]()
value = 37
valued_event.set(value)

with anyio.fail_after(adjusted_timeout(10)):
result = await valued_event.wait()

assert result == value


@pytest.mark.anyio
async def test_valued_event_wait_not_yet_set() -> None:
valued_event = ValuedEvent[int]()
value = 37
result: Optional[int] = None

async def wait(valued_event: ValuedEvent[int]) -> None:
nonlocal result
result = await valued_event.wait()

with anyio.fail_after(adjusted_timeout(10)):
async with anyio.create_task_group() as task_group:
task_group.start_soon(wait, valued_event)
await wait_for_valued_event_waiters(event=valued_event, count=1)
valued_event.set(value)

assert result == value


@pytest.mark.anyio
async def test_valued_event_wait_blocks_when_not_set() -> None:
valued_event = ValuedEvent[int]()
with pytest.raises(TimeoutError):
# if we could just process until there are no pending events, that would be great
with anyio.fail_after(adjusted_timeout(1)):
await valued_event.wait()


@pytest.mark.anyio
async def test_valued_event_multiple_waits_all_get_values() -> None:
results: List[int] = []
valued_event = ValuedEvent[int]()
value = 37
task_count = 10

async def wait_and_append() -> None:
results.append(await valued_event.wait())

async with anyio.create_task_group() as task_group:
for i in range(task_count):
task_group.start_soon(wait_and_append, name=f"wait_and_append_{i}")

await wait_for_valued_event_waiters(event=valued_event, count=task_count)
valued_event.set(value)

assert results == [value] * task_count


@pytest.mark.anyio
async def test_valued_event_set_again_raises_and_does_not_change_value() -> None:
valued_event = ValuedEvent[int]()
value = 37
valued_event.set(value)

with pytest.raises(Exception, match="^Value already set$"):
valued_event.set(value + 1)

with anyio.fail_after(adjusted_timeout(10)):
result = await valued_event.wait()

assert result == value


@pytest.mark.anyio
async def test_valued_event_wait_raises_if_not_set() -> None:
valued_event = ValuedEvent[int]()
valued_event._event.set()

with pytest.raises(Exception, match="^Value not set despite event being set$"):
with anyio.fail_after(adjusted_timeout(10)):
await valued_event.wait()
Loading