Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yoomlam committed Jan 30, 2025
1 parent 47ec416 commit d9ae6f4
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 5 deletions.
14 changes: 10 additions & 4 deletions app/src/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def __init__(self) -> None:
super().__init__()

@abstractmethod
def on_message(self, question: str, chat_history: Optional[ChatHistory]) -> OnMessageResult:
def on_message(
self, question: str, chat_history: Optional[ChatHistory] = None
) -> OnMessageResult:
pass


Expand Down Expand Up @@ -131,7 +133,9 @@ class BaseEngine(ChatEngineInterface):

formatting_config = FormattingConfig()

def on_message(self, question: str, chat_history: Optional[ChatHistory]) -> OnMessageResult:
def on_message(
self, question: str, chat_history: Optional[ChatHistory] = None
) -> OnMessageResult:
attributes = analyze_message(self.llm, self.system_prompt_1, question, MessageAttributes)

if attributes.needs_context:
Expand Down Expand Up @@ -299,7 +303,7 @@ class ImagineLaEngine(BaseEngine):
- LGBTQ resources: https://dpss.lacounty.gov/en/rights/rights/sogie.html
If the user's question is related to any of the following policy updates listed below, set canned_response to empty string and \
set alert_message to one or more of the following text based on the user's question:
set alert_message to one or more of the following text based on the user's question:
- Benefits application website: "YourBenefitsNow(YBN) no longer exists. Instead people use [benefitscal.com](https://benefitscal.com/) to apply for and manage \
CalWorks, CalFresh, General Relief and Medi-Cal applications and documents. People can also apply for Medi-Cal and health insurance at coveredca.com."
- Medicaid for immigrants: "Since January 1, 2024, a new law in California will allow adults ages 26 through 49 to qualify for full-scope Medi-Cal, \
Expand Down Expand Up @@ -331,7 +335,9 @@ class ImagineLaEngine(BaseEngine):
{PROMPT}"""

def on_message(self, question: str, chat_history: Optional[ChatHistory]) -> OnMessageResult:
def on_message(
self, question: str, chat_history: Optional[ChatHistory] = None
) -> OnMessageResult:
attributes = analyze_message(
self.llm, self.system_prompt_1, question, response_format=ImagineLA_MessageAttributes
)
Expand Down
57 changes: 56 additions & 1 deletion app/tests/src/test_chat_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src import chat_engine
from src.chat_engine import CaEddWebEngine, ImagineLaEngine
from src.chat_engine import CaEddWebEngine, ImagineLA_MessageAttributes, ImagineLaEngine


def test_available_engines():
Expand Down Expand Up @@ -33,3 +33,58 @@ def test_create_engine_Imagine_LA():
"WIC",
"Covered California",
]


def test_on_message_Imagine_LA_canned_response(monkeypatch):
    """When the message analyzer returns a canned response, on_message should
    echo it verbatim as the response and leave the alert message blank."""
    canned_attributes = ImagineLA_MessageAttributes(
        needs_context=True,
        translated_message="",
        canned_response="This is a canned response",
        alert_message="",
    )

    def mock_analyze_message(llm: str, system_prompt: str, message: str, response_format):
        # Ignore all inputs; always report the pre-built canned attributes.
        return canned_attributes

    monkeypatch.setattr(chat_engine, "analyze_message", mock_analyze_message)

    engine = chat_engine.create_engine("imagine-la")
    result = engine.on_message("What is AI?")
    assert result.response == "This is a canned response"
    assert result.attributes.alert_message == ""


def test_on_message_Imagine_LA_alert_message(monkeypatch):
    """When the analyzer sets an alert_message (and no canned response),
    on_message should wrap it in the standard policy-update framing:
    a "**Policy update**: " prefix and an "answer may be outdated" suffix."""

    def mock_analyze_message(llm: str, system_prompt: str, message: str, response_format):
        # Force the alert-message code path: no canned response, alert set.
        return ImagineLA_MessageAttributes(
            needs_context=True,
            translated_message="",
            canned_response="",
            alert_message="Some alert message",
        )

    monkeypatch.setattr(chat_engine, "analyze_message", mock_analyze_message)

    def mock_generate(
        llm: str,
        system_prompt: str,
        query: str,
        context_text=None,  # PEP 8: no spaces around '=' for keyword defaults
        chat_history=None,
    ) -> str:
        # The generated answer itself is irrelevant to the alert assertions.
        return "This is a generated response"

    monkeypatch.setattr(chat_engine, "generate", mock_generate)

    def mock_retrieve_with_scores(
        query: str,
        retrieval_k: int,
        retrieval_k_min_score: float,
        **filters,
    ):
        # No retrieved chunks are needed to exercise the alert-message path.
        return []

    monkeypatch.setattr(chat_engine, "retrieve_with_scores", mock_retrieve_with_scores)

    engine = chat_engine.create_engine("imagine-la")
    result = engine.on_message("What is AI?")
    assert result.attributes.alert_message.startswith("**Policy update**: ")
    assert result.attributes.alert_message.endswith("\n\nThe rest of this answer may be outdated.")

0 comments on commit d9ae6f4

Please sign in to comment.