diff --git a/app/src/chat_engine.py b/app/src/chat_engine.py
index 2dfbda4e..41403545 100644
--- a/app/src/chat_engine.py
+++ b/app/src/chat_engine.py
@@ -86,7 +86,9 @@ def __init__(self) -> None:
         super().__init__()
 
     @abstractmethod
-    def on_message(self, question: str, chat_history: Optional[ChatHistory]) -> OnMessageResult:
+    def on_message(
+        self, question: str, chat_history: Optional[ChatHistory] = None
+    ) -> OnMessageResult:
         pass
 
 
@@ -131,7 +133,9 @@ class BaseEngine(ChatEngineInterface):
 
     formatting_config = FormattingConfig()
 
-    def on_message(self, question: str, chat_history: Optional[ChatHistory]) -> OnMessageResult:
+    def on_message(
+        self, question: str, chat_history: Optional[ChatHistory] = None
+    ) -> OnMessageResult:
         attributes = analyze_message(self.llm, self.system_prompt_1, question, MessageAttributes)
 
         if attributes.needs_context:
@@ -299,7 +303,7 @@ class ImagineLaEngine(BaseEngine):
 - LGBTQ resources: https://dpss.lacounty.gov/en/rights/rights/sogie.html
 
 If the user's question is related to any of the following policy updates listed below, set canned_response to empty string and \
-set alert_message to one or more of the following text based on the user's question:
+set alert_message to one or more of the following text based on the user's question:
 - Benefits application website: "YourBenefitsNow(YBN) no longer exists. Instead people use [benefitscal.com](https://benefitscal.com/) to apply for and manage \
 CalWorks, CalFresh, General Relief and Medi-Cal applications and documents. People can also apply for Medi-Cal and health insurance at coveredca.com."
 - Medicaid for immigrants: "Since January 1, 2024, a new law in California will allow adults ages 26 through 49 to qualify for full-scope Medi-Cal, \
@@ -331,7 +335,9 @@ class ImagineLaEngine(BaseEngine):
 
 {PROMPT}"""
 
-    def on_message(self, question: str, chat_history: Optional[ChatHistory]) -> OnMessageResult:
+    def on_message(
+        self, question: str, chat_history: Optional[ChatHistory] = None
+    ) -> OnMessageResult:
         attributes = analyze_message(
             self.llm, self.system_prompt_1, question, response_format=ImagineLA_MessageAttributes
         )
diff --git a/app/tests/src/test_chat_engine.py b/app/tests/src/test_chat_engine.py
index 3f91c35d..10938a33 100644
--- a/app/tests/src/test_chat_engine.py
+++ b/app/tests/src/test_chat_engine.py
@@ -1,5 +1,5 @@
 from src import chat_engine
-from src.chat_engine import CaEddWebEngine, ImagineLaEngine
+from src.chat_engine import CaEddWebEngine, ImagineLA_MessageAttributes, ImagineLaEngine
 
 
 def test_available_engines():
@@ -33,3 +33,58 @@ def test_create_engine_Imagine_LA():
         "WIC",
         "Covered California",
     ]
+
+
+def test_on_message_Imagine_LA_canned_response(monkeypatch):
+    def mock_analyze_message(llm: str, system_prompt: str, message: str, response_format):
+        return ImagineLA_MessageAttributes(
+            needs_context=True,
+            translated_message="",
+            canned_response="This is a canned response",
+            alert_message="",
+        )
+
+    monkeypatch.setattr(chat_engine, "analyze_message", mock_analyze_message)
+
+    engine = chat_engine.create_engine("imagine-la")
+    result = engine.on_message("What is AI?")
+    assert result.response == "This is a canned response"
+    assert result.attributes.alert_message == ""
+
+
+def test_on_message_Imagine_LA_alert_message(monkeypatch):
+    def mock_analyze_message(llm: str, system_prompt: str, message: str, response_format):
+        return ImagineLA_MessageAttributes(
+            needs_context=True,
+            translated_message="",
+            canned_response="",
+            alert_message="Some alert message",
+        )
+
+    monkeypatch.setattr(chat_engine, "analyze_message", mock_analyze_message)
+
+    def mock_generate(
+        llm: str,
+        system_prompt: str,
+        query: str,
+        context_text=None,
+        chat_history=None,
+    ) -> str:
+        return "This is a generated response"
+
+    monkeypatch.setattr(chat_engine, "generate", mock_generate)
+
+    def mock_retrieve_with_scores(
+        query: str,
+        retrieval_k: int,
+        retrieval_k_min_score: float,
+        **filters,
+    ):
+        return []
+
+    monkeypatch.setattr(chat_engine, "retrieve_with_scores", mock_retrieve_with_scores)
+
+    engine = chat_engine.create_engine("imagine-la")
+    result = engine.on_message("What is AI?")
+    assert result.attributes.alert_message.startswith("**Policy update**: ")
+    assert result.attributes.alert_message.endswith("\n\nThe rest of this answer may be outdated.")
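
For context on the signature change: giving `chat_history` a default of `None` is what lets the new tests call `engine.on_message("What is AI?")` with a single argument. Below is a minimal sketch of the calling pattern, assuming `ChatHistory` can be imported from `src.chat_engine`; the `ask()` wrapper and the sample questions are illustrative only, while `create_engine`, `on_message`, the `"imagine-la"` engine id, and `result.response` come from the diff above.

```python
from typing import Optional

from src import chat_engine
from src.chat_engine import ChatHistory  # assumption: ChatHistory is exposed by this module


def ask(question: str, chat_history: Optional[ChatHistory] = None) -> str:
    """Hypothetical wrapper showing that chat_history may now be omitted at the call site."""
    engine = chat_engine.create_engine("imagine-la")
    result = engine.on_message(question, chat_history)
    return result.response


# Both call styles are valid against the updated signature:
#   ask("How do I apply for CalFresh?")
#   ask("How do I apply for CalFresh?", chat_history=earlier_turns)
```

Because the parameter keeps its position and type, existing callers that always pass a history are unaffected; only the no-history case becomes shorter.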