Feat: record the cost to the user (#65)
mahaloz authored Nov 13, 2024
1 parent 30b6cf6 commit f7210c7
Showing 4 changed files with 58 additions and 17 deletions.
2 changes: 1 addition & 1 deletion dailalib/__init__.py
@@ -1,4 +1,4 @@
__version__ = "3.10.4"
__version__ = "3.11.0"

import os
# stop LiteLLM from querying at all to the remote server
6 changes: 4 additions & 2 deletions dailalib/api/ai_api.py
@@ -118,7 +118,9 @@ def _requires_function(*args, ai_api: "AIAPI" = None, **kwargs):

         return _requires_function
 
-    def on_query(self, query_name, model, prompt_style, function, decompilation):
+    def on_query(self, query_name, model, prompt_style, function, decompilation, **kwargs):
         for func in self.query_callbacks:
-            t = threading.Thread(target=func, args=(query_name, model, prompt_style, function, decompilation))
+            t = threading.Thread(
+                target=func, args=(query_name, model, prompt_style, function, decompilation), kwargs=kwargs
+            )
             t.start()
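For readers following along: anything registered in AIAPI.query_callbacks is now invoked with the extra keyword arguments that on_query forwards (this commit passes total_time and cost from the prompt layer). Below is a minimal, self-contained sketch of that dispatch pattern; the callback name, its body, and the example values are illustrative, not part of DAILA.

import threading

# Hypothetical consumer of the new kwargs (total_time, cost); not part of this commit.
def log_query(query_name, model, prompt_style, function, decompilation, total_time=None, cost=None):
    msg = f"[{query_name}] model={model}"
    if total_time is not None:
        msg += f", took {total_time:.2f}s"
    if cost is not None:
        msg += f", cost ~${cost:.5f}"
    print(msg)

query_callbacks = [log_query]

def on_query(query_name, model, prompt_style, function, decompilation, **kwargs):
    # mirrors the updated AIAPI.on_query: extra kwargs are forwarded to every callback thread
    for func in query_callbacks:
        t = threading.Thread(
            target=func, args=(query_name, model, prompt_style, function, decompilation), kwargs=kwargs
        )
        t.start()

on_query("summarize", "gpt-4o", "few-shot", None, "int main(void) { ... }",
         total_time=3.21, cost=0.00045)

Because each callback runs on its own thread, a slow handler (for example, one that writes usage records to disk) does not block the query path.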
36 changes: 33 additions & 3 deletions dailalib/api/litellm/litellm_api.py
@@ -80,21 +80,31 @@ def query_model(
         if not self.api_key:
             raise ValueError(f"Model API key is not set. Please set it before querying the model {self.model}")
 
+        prompt_model = model or self.model
         response = completion(
-            model=model or self.model,
+            model=prompt_model,
             messages=[
                 {"role": "user", "content": prompt}
             ],
             max_tokens=max_tokens,
             timeout=60,
         )
 
         # get the answer
         try:
             answer = response.choices[0].message.content
         except (KeyError, IndexError) as e:
             answer = None
 
-        return answer
+        # get the estimated cost
+        try:
+            prompt_tokens = response.usage.prompt_tokens
+            completion_tokens = response.usage.completion_tokens
+        except (KeyError, IndexError) as e:
+            prompt_tokens, completion_tokens = None, None
+        cost = self.llm_cost(prompt_model, prompt_tokens, completion_tokens) \
+            if prompt_tokens is not None and completion_tokens is not None else None
+
+        return answer, cost
 
     @staticmethod
     def estimate_token_amount(content: str, model=DEFAULT_MODEL):
@@ -122,6 +132,26 @@ def fit_decompilation_to_token_max(decompilation: str, delta_step=10, model=DEFA

         return LiteLLMAIAPI.fit_decompilation_to_token_max(decompilation, delta_step=delta_step, model=model)
 
+    @staticmethod
+    def llm_cost(model_name: str, prompt_tokens: int, completion_tokens: int) -> float | None:
+        # these are the $ per million tokens
+        COST = {
+            "gpt-4o": {"prompt_price": 2.5, "completion_price": 10},
+            "gpt-4o-mini": {"prompt_price": 0.150, "completion_price": 0.600},
+            "gpt-4-turbo": {"prompt_price": 10, "completion_price": 30},
+            "claude-3.5-sonnet-20240620": {"prompt_price": 3, "completion_price": 15},
+            "gemini/gemini-pro": {"prompt_price": 0.150, "completion_price": 0.600},
+            "vertex_ai_beta/gemini-pro": {"prompt_price": 0.150, "completion_price": 0.600},
+        }
+        if model_name not in COST:
+            return None
+
+        llm_price = COST[model_name]
+        prompt_price = (prompt_tokens / 1000000) * llm_price["prompt_price"]
+        completion_price = (completion_tokens / 1000000) * llm_price["completion_price"]
+
+        return round(prompt_price + completion_price, 5)
+
     #
     # LMM Settings
     #
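As a sanity check on the pricing table, the sketch below reproduces the llm_cost arithmetic outside the class for a hypothetical gpt-4o query; the token counts are invented for illustration.

# Illustrative only: reproduces the llm_cost arithmetic for a made-up gpt-4o query.
PRICE_PER_MILLION = {"gpt-4o": {"prompt_price": 2.5, "completion_price": 10}}

def estimate_cost(model_name, prompt_tokens, completion_tokens):
    price = PRICE_PER_MILLION.get(model_name)
    if price is None:
        return None
    prompt_cost = (prompt_tokens / 1_000_000) * price["prompt_price"]
    completion_cost = (completion_tokens / 1_000_000) * price["completion_price"]
    return round(prompt_cost + completion_cost, 5)

# 1,200 prompt tokens and 300 completion tokens (hypothetical numbers):
# (1200 / 1e6) * 2.5 + (300 / 1e6) * 10 = 0.003 + 0.003 = 0.006 dollars
print(estimate_cost("gpt-4o", 1200, 300))  # 0.006

Models missing from the table simply yield a cost of None, which the prompt layer then omits from its log message.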
31 changes: 20 additions & 11 deletions dailalib/api/litellm/prompts/prompt.py
@@ -2,6 +2,7 @@
 import re
 from typing import Optional, Union, Dict, Callable
 import textwrap
+import time
 
 from ...ai_api import AIAPI
 from ..litellm_api import LiteLLMAIAPI
@@ -90,9 +91,17 @@ def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kw
             self.last_rendered_template = query_text
             ai_api.info(f"Prompting using: model={self.ai_api.model} and style={self.ai_api.prompt_style}")
 
-            ai_api.on_query(self.name, self.ai_api.model, self.ai_api.prompt_style, function, dec_text)
-            response += self.ai_api.query_model(query_text)
-            #ai_api.info(f"Response received from AI: {response}")
+            start_time = time.time()
+            _resp, cost = ai_api.query_model(query_text)
+            response += _resp
+            end_time = time.time()
+            total_time = end_time - start_time
+
+            # callback to handlers of post-query
+            ai_api.on_query(
+                self.name, self.ai_api.model, self.ai_api.prompt_style, function, dec_text, total_time=total_time, cost=cost
+            )
 
             default_response = {} if self._json_response else ""
             if not response:
                 ai_api.warning(f"Response received from AI was empty! AI failed to answer.")
@@ -119,14 +128,14 @@ def _query_model(ai_api=self.ai_api, function=function, dec_text=dec_text, **_kw
             else:
                 response += self._posttext_response if self._pretext_response else ""
 
-            if isinstance(response, dict) or isinstance(response, str):
-                resp_len = len(response)
-                if resp_len:
-                    ai_api.info(f"Response of len={resp_len} received from AI...")
-                else:
-                    ai_api.warning(f"Response recieved from AI, but it was empty! AI failed to answer.")
-            else:
-                ai_api.info("Response received from AI!")
+            resp_len = len(str(response))
+            log_str = f"Response received from AI after {total_time:.2f}s."
+            if cost is not None:
+                log_str += f" Cost: {cost:.3f}."
+            log_str += f" Length: {resp_len}."
+            if not resp_len:
+                log_str += f" AI likely failed to answer coherently."
+            ai_api.info(log_str)
 
             if ai_api.has_decompiler_gui and response:
                 ai_api.info("Updating the decompiler with the AI response...")
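Taken together, the prompt layer now times the query, receives the cost from query_model, and folds both into a single info log line. A rough sketch of the message it assembles, with invented numbers:

# Illustrative reconstruction of the log string built in _query_model; the
# timing, cost, and length values are hypothetical.
total_time = 4.73   # seconds spent waiting on query_model
cost = 0.006        # dollar estimate returned alongside the answer (None if unpriced)
resp_len = 512      # length of the stringified response

log_str = f"Response received from AI after {total_time:.2f}s."
if cost is not None:
    log_str += f" Cost: {cost:.3f}."
log_str += f" Length: {resp_len}."
if not resp_len:
    log_str += " AI likely failed to answer coherently."

print(log_str)  # Response received from AI after 4.73s. Cost: 0.006. Length: 512.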
