diff --git a/CHANGELOG.md b/CHANGELOG.md index cd06878b2..a4cc6e672 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,11 +11,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `griptape.configs.logging.JsonFormatter` for formatting logs as JSON. - Request/response debug logging to all Prompt Drivers. +- `BaseEventListener.flush_events()` to flush events from an Event Listener. +- Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing. ### Changed +- **BREAKING**: `BaseEventListener.publish_event` `flush` argument. Use `BaseEventListener.flush_events()` instead. - `_DefaultsConfig.logging_config` and `Defaults.drivers_config` are now lazily instantiated. - `BaseTask.add_parent`/`BaseTask.add_child` now only add the parent/child task to the structure if it is not already present. +- `BaseEventListener.flush_events()` to flush events from an Event Listener. +- `BaseEventListener` no longer requires a thread lock for batching events. + +### Fixed + +- Structures not flushing events when not listening for `FinishStructureRunEvent`. ## \[0.33.0\] - 2024-10-09 diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py index f9cb55dc9..b14a4fa40 100644 --- a/griptape/drivers/event_listener/base_event_listener_driver.py +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -1,12 +1,12 @@ from __future__ import annotations import logging -import threading from abc import ABC, abstractmethod from typing import TYPE_CHECKING from attrs import Factory, define, field +from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin if TYPE_CHECKING: @@ -16,10 +16,9 @@ @define -class BaseEventListenerDriver(FuturesExecutorMixin, ABC): +class BaseEventListenerDriver(FuturesExecutorMixin, ExponentialBackoffMixin, ABC): batched: bool = field(default=True, kw_only=True) batch_size: int = field(default=10, kw_only=True) - thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock())) _batch: list[dict] = field(default=Factory(list), kw_only=True) @@ -27,8 +26,21 @@ class BaseEventListenerDriver(FuturesExecutorMixin, ABC): def batch(self) -> list[dict]: return self._batch - def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None: - self.futures_executor.submit(self._safe_try_publish_event, event, flush=flush) + def publish_event(self, event: BaseEvent | dict) -> None: + event_payload = event if isinstance(event, dict) else event.to_dict() + + if self.batched: + self._batch.append(event_payload) + if len(self.batch) >= self.batch_size: + self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch) + self._batch = [] + else: + self.futures_executor.submit(self._safe_publish_event_payload, event_payload) + + def flush_events(self) -> None: + if self.batch: + self.futures_executor.submit(self._safe_publish_event_payload_batch, self.batch) + self._batch = [] @abstractmethod def try_publish_event_payload(self, event_payload: dict) -> None: ... @@ -36,18 +48,16 @@ def try_publish_event_payload(self, event_payload: dict) -> None: ... @abstractmethod def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: ... - def _safe_try_publish_event(self, event: BaseEvent | dict, *, flush: bool) -> None: - try: - event_payload = event if isinstance(event, dict) else event.to_dict() - - if self.batched: - with self.thread_lock: - self._batch.append(event_payload) - if len(self.batch) >= self.batch_size or flush: - self.try_publish_event_payload_batch(self.batch) - self._batch = [] - return - else: + def _safe_publish_event_payload(self, event_payload: dict) -> None: + for attempt in self.retrying(): + with attempt: self.try_publish_event_payload(event_payload) - except Exception as e: - logger.error(e) + else: + logger.error("event listener driver failed after all retry attempts") + + def _safe_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + for attempt in self.retrying(): + with attempt: + self.try_publish_event_payload_batch(event_payload_batch) + else: + logger.error("event listener driver failed after all retry attempts") diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py index 3e06eaa88..f48d469fa 100644 --- a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -42,7 +42,7 @@ def validate_run_id(self, _: Attribute, structure_run_id: str) -> None: "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).", ) - def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None: + def publish_event(self, event: BaseEvent | dict) -> None: from griptape.observability.observability import Observability event_payload = event.to_dict() if isinstance(event, BaseEvent) else event @@ -51,7 +51,7 @@ def publish_event(self, event: BaseEvent | dict, *, flush: bool = False) -> None if span_id is not None: event_payload["span_id"] = span_id - super().publish_event(event_payload, flush=flush) + super().publish_event(event_payload) def try_publish_event_payload(self, event_payload: dict) -> None: self._post_event(self._get_event_request(event_payload)) diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index 1fad4a1de..e785cd782 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -39,6 +39,9 @@ def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: event_payload = self.handler(event) if self.driver is not None: if event_payload is not None and isinstance(event_payload, dict): - self.driver.publish_event(event_payload, flush=flush) + self.driver.publish_event(event_payload) else: - self.driver.publish_event(event, flush=flush) + self.driver.publish_event(event) + + if self.driver is not None and flush: + self.driver.flush_events() diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py index 114778f72..928706c5f 100644 --- a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py @@ -5,32 +5,52 @@ class TestBaseEventListenerDriver: - def test_publish_event(self): + def test_publish_event_no_batched(self): executor = MagicMock() executor.__enter__.return_value = executor - driver = MockEventListenerDriver(futures_executor_fn=lambda: executor) + driver = MockEventListenerDriver(batched=False, futures_executor=executor) + mock_event_payload = MockEvent().to_dict() - driver.publish_event(MockEvent().to_dict()) + driver.publish_event(mock_event_payload) - executor.submit.assert_called_once() + executor.submit.assert_called_once_with(driver._safe_publish_event_payload, mock_event_payload) - def test__safe_try_publish_event(self): - driver = MockEventListenerDriver(batched=False) + def test_publish_event_yes_batched(self): + executor = MagicMock() + executor.__enter__.return_value = executor + driver = MockEventListenerDriver(batched=True, futures_executor=executor) + mock_event_payload = MockEvent().to_dict() - for _ in range(4): - driver._safe_try_publish_event(MockEvent().to_dict(), flush=False) - assert len(driver.batch) == 0 + # Publish 9 events to fill the batch + mock_event_payloads = [mock_event_payload for _ in range(0, 9)] + for mock_event_payload in mock_event_payloads: + driver.publish_event(mock_event_payload) - def test__safe_try_publish_event_batch(self): - driver = MockEventListenerDriver(batched=True) + assert len(driver._batch) == 9 + executor.submit.assert_not_called() - for _ in range(0, 3): - driver._safe_try_publish_event(MockEvent().to_dict(), flush=False) - assert len(driver.batch) == 3 + # Publish the 10th event to trigger the batch publish + driver.publish_event(mock_event_payload) - def test__safe_try_publish_event_batch_flush(self): - driver = MockEventListenerDriver(batched=True) + assert len(driver._batch) == 0 + executor.submit.assert_called_once_with( + driver._safe_publish_event_payload_batch, [*mock_event_payloads, mock_event_payload] + ) + + def test_flush_events(self): + executor = MagicMock() + executor.__enter__.return_value = executor + driver = MockEventListenerDriver(batched=True, futures_executor=executor) + driver.try_publish_event_payload_batch = MagicMock(side_effect=driver.try_publish_event_payload) + + driver.flush_events() + driver.try_publish_event_payload_batch.assert_not_called() + assert driver.batch == [] + mock_event_payloads = [MockEvent().to_dict() for _ in range(0, 3)] + for mock_event_payload in mock_event_payloads: + driver.publish_event(mock_event_payload) + assert len(driver.batch) == 3 - for _ in range(0, 3): - driver._safe_try_publish_event(MockEvent().to_dict(), flush=True) + driver.flush_events() + executor.submit.assert_called_once_with(driver._safe_publish_event_payload_batch, mock_event_payloads) assert len(driver.batch) == 0 diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py index 441589774..472f249cf 100644 --- a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py @@ -1,4 +1,5 @@ import os +import time from unittest.mock import MagicMock, Mock import pytest @@ -45,8 +46,10 @@ def test_init(self, driver): def test_publish_event_without_span_id(self, mock_post, driver): event = MockEvent() - driver.publish_event(event, flush=True) + driver.publish_event(event) + driver.flush_events() + time.sleep(1) # Happens asynchronously, so need to wait for it to finish mock_post.assert_called_with( url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events", json=[driver._get_event_request(event.to_dict())], @@ -59,8 +62,10 @@ def test_publish_event_with_span_id(self, mock_post, driver): observability_driver.get_span_id.return_value = "test" with Observability(observability_driver=observability_driver): - driver.publish_event(event, flush=True) + driver.publish_event(event) + driver.flush_events() + time.sleep(1) # Happens asynchronously, so need to wait for it to finish mock_post.assert_called_with( url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events", json=[driver._get_event_request({**event.to_dict(), "span_id": "test"})], @@ -71,6 +76,7 @@ def test_try_publish_event_payload(self, mock_post, driver): event = MockEvent() driver.try_publish_event_payload(event.to_dict()) + time.sleep(1) # Happens asynchronously, so need to wait for it to finish mock_post.assert_called_once_with( url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events", json=driver._get_event_request(event.to_dict()), @@ -82,6 +88,7 @@ def try_publish_event_payload_batch(self, mock_post, driver): event = MockEvent() driver.try_publish_event_payload(event.to_dict()) + time.sleep(1) # Happens asynchronously, so need to wait for it to finish mock_post.assert_called_with( url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events", json=driver._get_event_request(event.to_dict()), diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index f35bc5416..00d6c4cba 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -19,6 +19,7 @@ from griptape.structures import Pipeline from griptape.tasks import ActionsSubtask, ToolkitTask from tests.mocks.mock_event import MockEvent +from tests.mocks.mock_event_listener_driver import MockEventListenerDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -121,7 +122,7 @@ def event_handler(_: BaseEvent) -> None: event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent]) event_listener.publish_event(mock_event) - mock_event_listener_driver.publish_event.assert_called_once_with(mock_event, flush=False) + mock_event_listener_driver.publish_event.assert_called_once_with(mock_event) def test_publish_transformed_event(self): mock_event_listener_driver = Mock() @@ -134,7 +135,7 @@ def event_handler(event: BaseEvent): event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent]) event_listener.publish_event(mock_event) - mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False) + mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}) def test_context_manager(self): e1 = EventListener() @@ -153,3 +154,26 @@ def test_context_manager_multiple(self): assert EventBus.event_listeners == [e1, e2, e3] assert EventBus.event_listeners == [e1] + + def test_publish_event_yes_flush(self): + mock_event_listener_driver = MockEventListenerDriver() + mock_event_listener_driver.flush_events = Mock(side_effect=mock_event_listener_driver.flush_events) + + event_listener = EventListener(driver=mock_event_listener_driver, event_types=[MockEvent]) + event_listener.publish_event(MockEvent(), flush=True) + + mock_event_listener_driver.flush_events.assert_called_once() + assert mock_event_listener_driver.batch == [] + + def test_publish_event_no_flush(self): + mock_event_listener_driver = MockEventListenerDriver() + mock_event_listener_driver.flush_events = Mock(side_effect=mock_event_listener_driver.flush_events) + + event_listener = EventListener(driver=mock_event_listener_driver, event_types=[MockEvent]) + mock_event = MockEvent() + event_listener.publish_event(mock_event, flush=False) + + mock_event_listener_driver.flush_events.assert_not_called() + assert mock_event_listener_driver.batch == [ + mock_event.to_dict(), + ]