Skip to content

Commit

Permalink
Add token parser for Bedrock & fix anthropic typo (#1851)
Browse files Browse the repository at this point in the history
I am currently using Ragas with Bedrock and had to create new token
usage parsers. This PR responds to
[Issue-1151](#1151)

Thanks for this package it's nice 😄
  • Loading branch information
michaelromagne authored Jan 18, 2025
1 parent b3c768b commit b8c6be2
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 18 deletions.
27 changes: 27 additions & 0 deletions src/ragas/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,33 @@ def get_token_usage_for_anthropic(
return TokenUsage(input_tokens=0, output_tokens=0)


def get_token_usage_for_bedrock(
llm_result: t.Union[LLMResult, ChatResult],
) -> TokenUsage:
token_usages = []
for gs in llm_result.generations:
for g in gs:
if isinstance(g, ChatGeneration):
if g.message.response_metadata != {}:
token_usages.append(
TokenUsage(
input_tokens=get_from_dict(
g.message.response_metadata,
"usage.prompt_tokens",
0,
),
output_tokens=get_from_dict(
g.message.response_metadata,
"usage.completion_tokens",
0,
),
)
)

return sum(token_usages, TokenUsage(input_tokens=0, output_tokens=0))
return TokenUsage(input_tokens=0, output_tokens=0)


class CostCallbackHandler(BaseCallbackHandler):
def __init__(self, token_usage_parser: TokenUsageParser):
self.token_usage_parser = token_usage_parser
Expand Down
1 change: 1 addition & 0 deletions src/ragas/integrations/swarm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Any, Dict, List, Union

from ragas.messages import AIMessage, HumanMessage, ToolCall, ToolMessage


Expand Down
2 changes: 1 addition & 1 deletion src/ragas/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def is_finished(self, response: LLMResult) -> bool:
elif resp_message.response_metadata.get("stop_reason") is not None:
stop_reason = resp_message.response_metadata.get("stop_reason")
is_finished_list.append(
stop_reason in ["end_turn", "STOP", "MAX_TOKENS"]
stop_reason in ["end_turn", "stop", "STOP", "MAX_TOKENS"]
)
# default to True
else:
Expand Down
29 changes: 14 additions & 15 deletions src/ragas/prompt/pydantic_prompt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import copy
import hashlib
import json
import logging
import os
import hashlib

import typing as t

from langchain_core.exceptions import OutputParserException
Expand Down Expand Up @@ -228,7 +227,7 @@ async def adapt(
"""
Adapt the prompt to a new language.
"""

strings = get_all_strings(self.examples)
translated_strings = await translate_statements_prompt.generate(
llm=llm,
Expand Down Expand Up @@ -275,7 +274,7 @@ def __str__(self):
ensure_ascii=False,
)[1:-1]
return f"{self.__class__.__name__}({json_str})"

def __hash__(self):
# convert examples to json string for hashing
examples = []
Expand All @@ -284,23 +283,23 @@ def __hash__(self):
examples.append(
(input_model.model_dump_json(), output_model.model_dump_json())
)

# create a SHA-256 hash object
hasher = hashlib.sha256()

# update the hash object with the bytes of each attribute
hasher.update(self.name.encode('utf-8'))
hasher.update(self.input_model.__name__.encode('utf-8'))
hasher.update(self.output_model.__name__.encode('utf-8'))
hasher.update(self.instruction.encode('utf-8'))
hasher.update(self.name.encode("utf-8"))
hasher.update(self.input_model.__name__.encode("utf-8"))
hasher.update(self.output_model.__name__.encode("utf-8"))
hasher.update(self.instruction.encode("utf-8"))
for example in examples:
hasher.update(example[0].encode('utf-8'))
hasher.update(example[1].encode('utf-8'))
hasher.update(self.language.encode('utf-8'))
hasher.update(example[0].encode("utf-8"))
hasher.update(example[1].encode("utf-8"))
hasher.update(self.language.encode("utf-8"))

# return the integer value of the hash
return int(hasher.hexdigest(), 16)

def __eq__(self, other):
if not isinstance(other, PydanticPrompt):
return False
Expand Down
59 changes: 57 additions & 2 deletions tests/unit/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
CostCallbackHandler,
TokenUsage,
get_token_usage_for_anthropic,
get_token_usage_for_bedrock,
get_token_usage_for_openai,
)

Expand Down Expand Up @@ -62,7 +63,7 @@ def test_token_usage_cost():
},
)

athropic_llm_result = LLMResult(
anthropic_llm_result = LLMResult(
generations=[
[
ChatGeneration(
Expand All @@ -82,16 +83,70 @@ def test_token_usage_cost():
llm_output={},
)

bedrock_llama_result = LLMResult(
generations=[
[
ChatGeneration(
text="Hello, world!",
message=AIMessage(
content="Hello, world!",
response_metadata={
"usage": {
"prompt_tokens": 10,
"completion_tokens": 10,
"total_tokens": 20,
},
"stop_reason": "stop",
"model_id": "us.meta.llama3-1-70b-instruct-v1:0",
},
),
)
]
],
llm_output={},
)

bedrock_claude_result = LLMResult(
generations=[
[
ChatGeneration(
text="Hello, world!",
message=AIMessage(
content="Hello, world!",
response_metadata={
"usage": {
"prompt_tokens": 10,
"completion_tokens": 10,
"total_tokens": 20,
},
"stop_reason": "end_turn",
"model_id": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
},
),
)
]
],
llm_output={},
)


def test_parse_llm_results():
# openai
token_usage = get_token_usage_for_openai(openai_llm_result)
assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)

# anthropic
token_usage = get_token_usage_for_anthropic(athropic_llm_result)
token_usage = get_token_usage_for_anthropic(anthropic_llm_result)
assert token_usage == TokenUsage(input_tokens=9, output_tokens=12)

# Bedrock LLaMa
token_usage = get_token_usage_for_bedrock(bedrock_llama_result)
assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)

# Bedrock Claude
token_usage = get_token_usage_for_bedrock(bedrock_claude_result)
assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)


def test_cost_callback_handler():
cost_cb = CostCallbackHandler(token_usage_parser=get_token_usage_for_openai)
Expand Down

0 comments on commit b8c6be2

Please sign in to comment.