Skip to content

Commit

Permalink
Ensure much stronger type safety for consumers
Browse files Browse the repository at this point in the history
  • Loading branch information
cburgdorf committed Sep 12, 2018
1 parent 26ca6f5 commit 9a8086a
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 21 deletions.
1 change: 1 addition & 0 deletions lahja/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .misc import ( # noqa: F401
BaseEvent,
BaseRequestResponseEvent,
BroadcastConfig,
Subscription,
)
32 changes: 23 additions & 9 deletions lahja/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
List,
Optional,
Type,
TypeVar,
cast,
)
import uuid
Expand All @@ -21,6 +22,7 @@
from .misc import (
TRANSPARENT_EVENT,
BaseEvent,
BaseRequestResponseEvent,
BroadcastConfig,
Subscription,
)
Expand Down Expand Up @@ -114,7 +116,9 @@ def broadcast(self, item: BaseEvent, config: Optional[BroadcastConfig] = None) -
item._origin = self.name
self._sending_queue.put_nowait((item, config))

async def request(self, item: BaseEvent) -> BaseEvent:
TResponse = TypeVar('TResponse', bound=BaseEvent)

async def request(self, item: BaseRequestResponseEvent[TResponse]) -> TResponse:
"""
Broadcast an instance of :class:`~lahja.misc.BaseEvent` on the event bus and immediately
wait on an expected answer of type :class:`~lahja.misc.BaseEvent`.
Expand All @@ -129,11 +133,15 @@ async def request(self, item: BaseEvent) -> BaseEvent:

result = await future

return cast(BaseEvent, result)
# We ignore the warning (not error) of returning `Any`. Since `TResponse` is
# nothing we can cast to, I guess we can't do any better.
return result # type: ignore

TSubscribeEvent = TypeVar('TSubscribeEvent', bound=BaseEvent)

def subscribe(self,
event_type: Type[BaseEvent],
handler: Callable[[BaseEvent], None]) -> Subscription:
event_type: Type[TSubscribeEvent],
handler: Callable[[TSubscribeEvent], None]) -> Subscription:
"""
Subscribe to receive updates for any event that matches the specified event type.
A handler is passed as a second argument an :class:`~lahja.misc.Subscription` is returned
Expand All @@ -142,13 +150,17 @@ def subscribe(self,
if event_type not in self._handler:
self._handler[event_type] = []

self._handler[event_type].append(handler)
casted_handler = cast(Callable[[BaseEvent], Any], handler)

self._handler[event_type].append(casted_handler)

return Subscription(lambda: self._handler[event_type].remove(handler))
return Subscription(lambda: self._handler[event_type].remove(casted_handler))

TStreamEvent = TypeVar('TStreamEvent', bound=BaseEvent)

async def stream(self,
event_type: Type[BaseEvent],
max: Optional[int] = None) -> AsyncIterable[BaseEvent]:
event_type: Type[TStreamEvent],
max: Optional[int] = None) -> AsyncIterable[TStreamEvent]:
"""
Stream all events that match the specified event type. This returns an
``AsyncIterable[BaseEvent]`` which can be consumed through an ``async for`` loop.
Expand All @@ -175,7 +187,9 @@ async def stream(self,
if i is not None and i >= cast(int, max):
break

async def wait_for(self, event_type: Type[BaseEvent]) -> BaseEvent: # type: ignore
TWaitForEvent = TypeVar('TWaitForEvent', bound=BaseEvent)

async def wait_for(self, event_type: Type[TWaitForEvent]) -> TWaitForEvent: # type: ignore
"""
Wait for a single instance of an event that matches the specified event type.
"""
Expand Down
9 changes: 9 additions & 0 deletions lahja/misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import ( # noqa: F401
Any,
Callable,
Generic,
Optional,
TypeVar,
)


Expand Down Expand Up @@ -40,6 +42,13 @@ def broadcast_config(self) -> BroadcastConfig:
)


TResponse = TypeVar('TResponse', bound=BaseEvent)


class BaseRequestResponseEvent(BaseEvent, Generic[TResponse]):
pass


class TransparentEvent(BaseEvent):
"""
This event is used to create artificial activity so that code that
Expand Down
46 changes: 35 additions & 11 deletions tests/core/test_basics.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,70 @@
import asyncio
from typing import (
Any,
)

import pytest

from lahja import (
BaseEvent,
BaseRequestResponseEvent,
EventBus,
)


class DummyRequest(BaseEvent):
pass
property_of_dummy_request = None


class DummyResponse(BaseEvent):
pass
property_of_dummy_response = None

def __init__(self, something: Any) -> None:
pass


class DummyRequestPair(BaseRequestResponseEvent[DummyResponse]):
property_of_dummy_request_pair = None


@pytest.mark.asyncio
async def test_request():
async def test_request() -> None:
bus = EventBus()
endpoint = bus.create_endpoint('test')
bus.start()
endpoint.connect()

endpoint.subscribe(
DummyRequest,
lambda ev: endpoint.broadcast(DummyResponse(), ev.broadcast_config())
DummyRequestPair,
lambda ev: endpoint.broadcast(
# Accessing `ev.property_of_dummy_request_pair` here allows us to validate
# mypy has the type information we think it has. We run mypy on the tests.
DummyResponse(ev.property_of_dummy_request_pair), ev.broadcast_config()
)
)

response = await endpoint.request(DummyRequest())
response = await endpoint.request(DummyRequestPair())
# Accessing `ev.property_of_dummy_response` here allows us to validate
# mypy has the type information we think it has. We run mypy on the tests.
print(response.property_of_dummy_response)
assert isinstance(response, DummyResponse)
endpoint.stop()
bus.stop()


@pytest.mark.asyncio
async def test_stream_with_max():
async def test_stream_with_max() -> None:
bus = EventBus()
endpoint = bus.create_endpoint('test')
bus.start()
endpoint.connect()
stream_counter = 0

async def stream_response():
async for _ in endpoint.stream(DummyRequest, max=2): # noqa: F841
async def stream_response() -> None:
async for event in endpoint.stream(DummyRequest, max=2):
# Accessing `ev.property_of_dummy_request` here allows us to validate
# mypy has the type information we think it has. We run mypy on the tests.
print(event.property_of_dummy_request)
nonlocal stream_counter
stream_counter += 1

Expand All @@ -60,15 +81,18 @@ async def stream_response():


@pytest.mark.asyncio
async def test_wait_for():
async def test_wait_for() -> None:
bus = EventBus()
endpoint = bus.create_endpoint('test')
bus.start()
endpoint.connect()
received = None

async def stream_response():
async def stream_response() -> None:
request = await endpoint.wait_for(DummyRequest)
# Accessing `ev.property_of_dummy_request` here allows us to validate
# mypy has the type information we think it has. We run mypy on the tests.
print(request.property_of_dummy_request)
nonlocal received
received = request

Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_import.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@


def test_import():
def test_import() -> None:
import lahja # noqa: F401
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ basepython=python
extras=lint
commands=
mypy lahja --ignore-missing-imports --strict
mypy {toxinidir}/tests/core --follow-imports=silent --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs --disallow-any-generics
flake8 {toxinidir}/lahja {toxinidir}/tests
isort --recursive --check-only --diff {toxinidir}/lahja {toxinidir}/tests

0 comments on commit 9a8086a

Please sign in to comment.