diff --git a/src/ragas/cost.py b/src/ragas/cost.py
index 0703a1f51..144f66a12 100644
--- a/src/ragas/cost.py
+++ b/src/ragas/cost.py
@@ -100,6 +100,33 @@ def get_token_usage_for_anthropic(
     return TokenUsage(input_tokens=0, output_tokens=0)
 
 
+def get_token_usage_for_bedrock(
+    llm_result: t.Union[LLMResult, ChatResult],
+) -> TokenUsage:
+    token_usages = []
+    for gs in llm_result.generations:
+        for g in gs:
+            if isinstance(g, ChatGeneration):
+                if g.message.response_metadata != {}:
+                    token_usages.append(
+                        TokenUsage(
+                            input_tokens=get_from_dict(
+                                g.message.response_metadata,
+                                "usage.prompt_tokens",
+                                0,
+                            ),
+                            output_tokens=get_from_dict(
+                                g.message.response_metadata,
+                                "usage.completion_tokens",
+                                0,
+                            ),
+                        )
+                    )
+
+        return sum(token_usages, TokenUsage(input_tokens=0, output_tokens=0))
+    return TokenUsage(input_tokens=0, output_tokens=0)
+
+
 class CostCallbackHandler(BaseCallbackHandler):
     def __init__(self, token_usage_parser: TokenUsageParser):
         self.token_usage_parser = token_usage_parser
diff --git a/src/ragas/integrations/swarm.py b/src/ragas/integrations/swarm.py
index 56110ace0..e7e574e5f 100644
--- a/src/ragas/integrations/swarm.py
+++ b/src/ragas/integrations/swarm.py
@@ -1,5 +1,6 @@
 import json
 from typing import Any, Dict, List, Union
+
 from ragas.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
 
 
diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index a30d4da5a..3955ee6fe 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -182,7 +182,7 @@ def is_finished(self, response: LLMResult) -> bool:
                 elif resp_message.response_metadata.get("stop_reason") is not None:
                     stop_reason = resp_message.response_metadata.get("stop_reason")
                     is_finished_list.append(
-                        stop_reason in ["end_turn", "STOP", "MAX_TOKENS"]
+                        stop_reason in ["end_turn", "stop", "STOP", "MAX_TOKENS"]
                     )
             # default to True
             else:
diff --git a/src/ragas/prompt/pydantic_prompt.py b/src/ragas/prompt/pydantic_prompt.py
index fbf28010f..a7ee6cba1 100644
--- a/src/ragas/prompt/pydantic_prompt.py
+++ b/src/ragas/prompt/pydantic_prompt.py
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
 import copy
+import hashlib
 import json
 import logging
 import os
-import hashlib
-
 import typing as t
 
 from langchain_core.exceptions import OutputParserException
@@ -228,7 +227,7 @@ async def adapt(
         """
         Adapt the prompt to a new language.
         """
-
+
         strings = get_all_strings(self.examples)
         translated_strings = await translate_statements_prompt.generate(
             llm=llm,
@@ -275,7 +274,7 @@ def __str__(self):
             ensure_ascii=False,
         )[1:-1]
         return f"{self.__class__.__name__}({json_str})"
-
+
     def __hash__(self):
         # convert examples to json string for hashing
         examples = []
@@ -284,23 +283,23 @@ def __hash__(self):
             examples.append(
                 (input_model.model_dump_json(), output_model.model_dump_json())
             )
-
+
         # create a SHA-256 hash object
         hasher = hashlib.sha256()
-
+
         # update the hash object with the bytes of each attribute
-        hasher.update(self.name.encode('utf-8'))
-        hasher.update(self.input_model.__name__.encode('utf-8'))
-        hasher.update(self.output_model.__name__.encode('utf-8'))
-        hasher.update(self.instruction.encode('utf-8'))
+        hasher.update(self.name.encode("utf-8"))
+        hasher.update(self.input_model.__name__.encode("utf-8"))
+        hasher.update(self.output_model.__name__.encode("utf-8"))
+        hasher.update(self.instruction.encode("utf-8"))
         for example in examples:
-            hasher.update(example[0].encode('utf-8'))
-            hasher.update(example[1].encode('utf-8'))
-        hasher.update(self.language.encode('utf-8'))
-
+            hasher.update(example[0].encode("utf-8"))
+            hasher.update(example[1].encode("utf-8"))
+        hasher.update(self.language.encode("utf-8"))
+
         # return the integer value of the hash
         return int(hasher.hexdigest(), 16)
-
+
     def __eq__(self, other):
         if not isinstance(other, PydanticPrompt):
             return False
diff --git a/tests/unit/test_cost.py b/tests/unit/test_cost.py
index 772d4d71c..715f28f94 100644
--- a/tests/unit/test_cost.py
+++ b/tests/unit/test_cost.py
@@ -6,6 +6,7 @@
     CostCallbackHandler,
     TokenUsage,
     get_token_usage_for_anthropic,
+    get_token_usage_for_bedrock,
     get_token_usage_for_openai,
 )
 
@@ -62,7 +63,7 @@ def test_token_usage_cost():
     },
 )
 
-athropic_llm_result = LLMResult(
+anthropic_llm_result = LLMResult(
     generations=[
         [
             ChatGeneration(
@@ -82,6 +83,52 @@
     llm_output={},
 )
 
+bedrock_llama_result = LLMResult(
+    generations=[
+        [
+            ChatGeneration(
+                text="Hello, world!",
+                message=AIMessage(
+                    content="Hello, world!",
+                    response_metadata={
+                        "usage": {
+                            "prompt_tokens": 10,
+                            "completion_tokens": 10,
+                            "total_tokens": 20,
+                        },
+                        "stop_reason": "stop",
+                        "model_id": "us.meta.llama3-1-70b-instruct-v1:0",
+                    },
+                ),
+            )
+        ]
+    ],
+    llm_output={},
+)
+
+bedrock_claude_result = LLMResult(
+    generations=[
+        [
+            ChatGeneration(
+                text="Hello, world!",
+                message=AIMessage(
+                    content="Hello, world!",
+                    response_metadata={
+                        "usage": {
+                            "prompt_tokens": 10,
+                            "completion_tokens": 10,
+                            "total_tokens": 20,
+                        },
+                        "stop_reason": "end_turn",
+                        "model_id": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+                    },
+                ),
+            )
+        ]
+    ],
+    llm_output={},
+)
+
 
 def test_parse_llm_results():
     # openai
@@ -89,9 +136,17 @@ def test_parse_llm_results():
     token_usage = get_token_usage_for_openai(openai_llm_result)
     assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)
 
     # anthropic
-    token_usage = get_token_usage_for_anthropic(athropic_llm_result)
+    token_usage = get_token_usage_for_anthropic(anthropic_llm_result)
     assert token_usage == TokenUsage(input_tokens=9, output_tokens=12)
 
+    # Bedrock LLaMa
+    token_usage = get_token_usage_for_bedrock(bedrock_llama_result)
+    assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)
+
+    # Bedrock Claude
+    token_usage = get_token_usage_for_bedrock(bedrock_claude_result)
+    assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)
+
 
 def test_cost_callback_handler():
     cost_cb = CostCallbackHandler(token_usage_parser=get_token_usage_for_openai)