Skip to content

Commit

Permalink
Have EventListener flush batch (#1246)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Oct 14, 2024
1 parent 32ff8ba commit a9943fb
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 45 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `griptape.configs.logging.JsonFormatter` for formatting logs as JSON.
- Request/response debug logging to all Prompt Drivers.
- `BaseEventListenerDriver.flush_events()` to flush batched events from an Event Listener Driver.
- Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing.

### Changed

- **BREAKING**: Removed the `flush` argument from `BaseEventListenerDriver.publish_event`. Use `BaseEventListenerDriver.flush_events()` instead.
- `_DefaultsConfig.logging_config` and `Defaults.drivers_config` are now lazily instantiated.
- `BaseTask.add_parent`/`BaseTask.add_child` now only add the parent/child task to the structure if it is not already present.
- `EventListener.publish_event(..., flush=True)` now calls `BaseEventListenerDriver.flush_events()` after publishing the event.
- `BaseEventListenerDriver` no longer requires a thread lock for batching events.

### Fixed

- Structures not flushing events when not listening for `FinishStructureRunEvent`.

## \[0.33.0\] - 2024-10-09

Expand Down
48 changes: 29 additions & 19 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

import logging
import threading
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin

if TYPE_CHECKING:
Expand All @@ -16,38 +16,48 @@


@define
class BaseEventListenerDriver(FuturesExecutorMixin, ABC):
class BaseEventListenerDriver(FuturesExecutorMixin, ExponentialBackoffMixin, ABC):
batched: bool = field(default=True, kw_only=True)
batch_size: int = field(default=10, kw_only=True)
thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

_batch: list[dict] = field(default=Factory(list), kw_only=True)

@property
def batch(self) -> list[dict]:
return self._batch

def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None:
self.futures_executor.submit(self._safe_try_publish_event, event, flush=flush)
def publish_event(self, event: BaseEvent | dict) -> None:
    """Queue an event for publishing, either individually or as part of a batch.

    Events that are not already dicts are converted with ``to_dict()``.

    When ``batched`` is false, the payload is submitted to the futures
    executor straight away. Otherwise it is appended to the pending batch,
    and once the batch holds at least ``batch_size`` payloads the whole
    batch is submitted and a fresh, empty batch is started.

    Args:
        event: The event, or an already-serialized event payload, to publish.
    """
    payload = event.to_dict() if not isinstance(event, dict) else event

    if not self.batched:
        self.futures_executor.submit(self._safe_publish_event_payload, payload)
        return

    self._batch.append(payload)
    if len(self.batch) < self.batch_size:
        # Batch is not full yet; hold the payload until more arrive or a flush.
        return

    self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch)
    # Rebind rather than clear so the list handed to the executor is untouched.
    self._batch = []

def flush_events(self) -> None:
    """Submit any pending batched payloads for publishing and reset the batch.

    Does nothing when the batch is empty.
    """
    pending = self.batch
    if not pending:
        return

    self.futures_executor.submit(self._safe_publish_event_payload_batch, pending)
    # Start a new list; the submitted one now belongs to the executor task.
    self._batch = []

@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None: ...

@abstractmethod
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: ...

def _safe_try_publish_event(self, event: BaseEvent | dict, *, flush: bool) -> None:
try:
event_payload = event if isinstance(event, dict) else event.to_dict()

if self.batched:
with self.thread_lock:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size or flush:
self.try_publish_event_payload_batch(self.batch)
self._batch = []
return
else:
def _safe_publish_event_payload(self, event_payload: dict) -> None:
    """Publish a single event payload with retries, logging instead of raising.

    Each attempt yielded by ``self.retrying()`` calls
    ``try_publish_event_payload``. If an exception escapes the retry loop it
    is logged and swallowed.

    Args:
        event_payload: The serialized event to publish.
    """
    # A ``for ... else`` clause runs whenever the loop ends without ``break``,
    # i.e. on success, so it cannot be used to detect exhausted retries.
    # NOTE(review): assumes ``self.retrying()`` raises once all attempts have
    # failed (tenacity ``Retrying``-style) - confirm in ExponentialBackoffMixin.
    try:
        for attempt in self.retrying():
            with attempt:
                self.try_publish_event_payload(event_payload)
    except Exception:
        logger.exception("event listener driver failed after all retry attempts")

def _safe_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Publish a batch of event payloads with retries, logging instead of raising.

    Each attempt yielded by ``self.retrying()`` calls
    ``try_publish_event_payload_batch``. If an exception escapes the retry
    loop it is logged and swallowed.

    Args:
        event_payload_batch: The serialized events to publish together.
    """
    # A ``for ... else`` clause runs whenever the loop ends without ``break``,
    # i.e. on success, so it cannot be used to detect exhausted retries.
    # NOTE(review): assumes ``self.retrying()`` raises once all attempts have
    # failed (tenacity ``Retrying``-style) - confirm in ExponentialBackoffMixin.
    try:
        for attempt in self.retrying():
            with attempt:
                self.try_publish_event_payload_batch(event_payload_batch)
    except Exception:
        logger.exception("event listener driver failed after all retry attempts")
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def validate_run_id(self, _: Attribute, structure_run_id: str) -> None:
"structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).",
)

def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None:
def publish_event(self, event: BaseEvent | dict) -> None:
from griptape.observability.observability import Observability

event_payload = event.to_dict() if isinstance(event, BaseEvent) else event
Expand All @@ -51,7 +51,7 @@ def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None
if span_id is not None:
event_payload["span_id"] = span_id

super().publish_event(event_payload, flush=flush)
super().publish_event(event_payload)

def try_publish_event_payload(self, event_payload: dict) -> None:
self._post_event(self._get_event_request(event_payload))
Expand Down
7 changes: 5 additions & 2 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
event_payload = self.handler(event)
if self.driver is not None:
if event_payload is not None and isinstance(event_payload, dict):
self.driver.publish_event(event_payload, flush=flush)
self.driver.publish_event(event_payload)
else:
self.driver.publish_event(event, flush=flush)
self.driver.publish_event(event)

if self.driver is not None and flush:
self.driver.flush_events()
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,52 @@


class TestBaseEventListenerDriver:
def test_publish_event(self):
def test_publish_event_no_batched(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(futures_executor_fn=lambda: executor)
driver = MockEventListenerDriver(batched=False, futures_executor=executor)
mock_event_payload = MockEvent().to_dict()

driver.publish_event(MockEvent().to_dict())
driver.publish_event(mock_event_payload)

executor.submit.assert_called_once()
executor.submit.assert_called_once_with(driver._safe_publish_event_payload, mock_event_payload)

def test__safe_try_publish_event(self):
driver = MockEventListenerDriver(batched=False)
def test_publish_event_yes_batched(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=True, futures_executor=executor)
mock_event_payload = MockEvent().to_dict()

for _ in range(4):
driver._safe_try_publish_event(MockEvent().to_dict(), flush=False)
assert len(driver.batch) == 0
# Publish 9 events to fill the batch
mock_event_payloads = [mock_event_payload for _ in range(0, 9)]
for mock_event_payload in mock_event_payloads:
driver.publish_event(mock_event_payload)

def test__safe_try_publish_event_batch(self):
driver = MockEventListenerDriver(batched=True)
assert len(driver._batch) == 9
executor.submit.assert_not_called()

for _ in range(0, 3):
driver._safe_try_publish_event(MockEvent().to_dict(), flush=False)
assert len(driver.batch) == 3
# Publish the 10th event to trigger the batch publish
driver.publish_event(mock_event_payload)

def test__safe_try_publish_event_batch_flush(self):
driver = MockEventListenerDriver(batched=True)
assert len(driver._batch) == 0
executor.submit.assert_called_once_with(
driver._safe_publish_event_payload_batch, [*mock_event_payloads, mock_event_payload]
)

def test_flush_events(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(batched=True, futures_executor=executor)
driver.try_publish_event_payload_batch = MagicMock(side_effect=driver.try_publish_event_payload)

driver.flush_events()
driver.try_publish_event_payload_batch.assert_not_called()
assert driver.batch == []
mock_event_payloads = [MockEvent().to_dict() for _ in range(0, 3)]
for mock_event_payload in mock_event_payloads:
driver.publish_event(mock_event_payload)
assert len(driver.batch) == 3

for _ in range(0, 3):
driver._safe_try_publish_event(MockEvent().to_dict(), flush=True)
driver.flush_events()
executor.submit.assert_called_once_with(driver._safe_publish_event_payload_batch, mock_event_payloads)
assert len(driver.batch) == 0
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import time
from unittest.mock import MagicMock, Mock

import pytest
Expand Down Expand Up @@ -45,8 +46,10 @@ def test_init(self, driver):

def test_publish_event_without_span_id(self, mock_post, driver):
event = MockEvent()
driver.publish_event(event, flush=True)
driver.publish_event(event)
driver.flush_events()

time.sleep(1) # Happens asynchronously, so need to wait for it to finish
mock_post.assert_called_with(
url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events",
json=[driver._get_event_request(event.to_dict())],
Expand All @@ -59,8 +62,10 @@ def test_publish_event_with_span_id(self, mock_post, driver):
observability_driver.get_span_id.return_value = "test"

with Observability(observability_driver=observability_driver):
driver.publish_event(event, flush=True)
driver.publish_event(event)
driver.flush_events()

time.sleep(1) # Happens asynchronously, so need to wait for it to finish
mock_post.assert_called_with(
url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events",
json=[driver._get_event_request({**event.to_dict(), "span_id": "test"})],
Expand All @@ -71,6 +76,7 @@ def test_try_publish_event_payload(self, mock_post, driver):
event = MockEvent()
driver.try_publish_event_payload(event.to_dict())

time.sleep(1) # Happens asynchronously, so need to wait for it to finish
mock_post.assert_called_once_with(
url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events",
json=driver._get_event_request(event.to_dict()),
Expand All @@ -82,6 +88,7 @@ def try_publish_event_payload_batch(self, mock_post, driver):
event = MockEvent()
driver.try_publish_event_payload(event.to_dict())

time.sleep(1) # Happens asynchronously, so need to wait for it to finish
mock_post.assert_called_with(
url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events",
json=driver._get_event_request(event.to_dict()),
Expand Down
28 changes: 26 additions & 2 deletions tests/unit/events/test_event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from griptape.structures import Pipeline
from griptape.tasks import ActionsSubtask, ToolkitTask
from tests.mocks.mock_event import MockEvent
from tests.mocks.mock_event_listener_driver import MockEventListenerDriver
from tests.mocks.mock_prompt_driver import MockPromptDriver
from tests.mocks.mock_tool.tool import MockTool

Expand Down Expand Up @@ -121,7 +122,7 @@ def event_handler(_: BaseEvent) -> None:
event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent])
event_listener.publish_event(mock_event)

mock_event_listener_driver.publish_event.assert_called_once_with(mock_event, flush=False)
mock_event_listener_driver.publish_event.assert_called_once_with(mock_event)

def test_publish_transformed_event(self):
mock_event_listener_driver = Mock()
Expand All @@ -134,7 +135,7 @@ def event_handler(event: BaseEvent):
event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent])
event_listener.publish_event(mock_event)

mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False)
mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()})

def test_context_manager(self):
e1 = EventListener()
Expand All @@ -153,3 +154,26 @@ def test_context_manager_multiple(self):
assert EventBus.event_listeners == [e1, e2, e3]

assert EventBus.event_listeners == [e1]

def test_publish_event_yes_flush(self):
    """Publishing with ``flush=True`` flushes the driver and leaves its batch empty."""
    driver = MockEventListenerDriver()
    # Wrap the real method so the call is recorded while it still runs.
    driver.flush_events = Mock(side_effect=driver.flush_events)
    listener = EventListener(driver=driver, event_types=[MockEvent])

    listener.publish_event(MockEvent(), flush=True)

    driver.flush_events.assert_called_once()
    assert driver.batch == []

def test_publish_event_no_flush(self):
    """Publishing with ``flush=False`` keeps the event in the driver's batch."""
    driver = MockEventListenerDriver()
    # Wrap the real method so any call would be recorded.
    driver.flush_events = Mock(side_effect=driver.flush_events)
    listener = EventListener(driver=driver, event_types=[MockEvent])
    event = MockEvent()

    listener.publish_event(event, flush=False)

    driver.flush_events.assert_not_called()
    assert driver.batch == [event.to_dict()]

0 comments on commit a9943fb

Please sign in to comment.