From 66a10c5b3b85f514e80c21052256f957aa5869c3 Mon Sep 17 00:00:00 2001
From: Shreya Shankar <ss.shankar505@gmail.com>
Date: Sun, 26 Jan 2025 08:12:12 -0800
Subject: [PATCH] fix: input context length logic (#293)

---
 docetl/operations/utils/llm.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/docetl/operations/utils/llm.py b/docetl/operations/utils/llm.py
index 484d0981..e5bc7944 100644
--- a/docetl/operations/utils/llm.py
+++ b/docetl/operations/utils/llm.py
@@ -71,9 +71,18 @@ def truncate_messages(
     messages: List[Dict[str, str]], model: str, from_agent: bool = False
 ) -> List[Dict[str, str]]:
     """Truncate messages to fit within model's context length."""
-    model_input_context_length = model_cost.get(model.split("/")[-1], {}).get(
-        "max_input_tokens", 8192
-    )
+    model_cost_info = model_cost.get(model, {})
+    if not model_cost_info:
+        # Try stripping the first part before the /
+        split_model = model.split("/")
+        if len(split_model) > 1:
+            model_cost_info = model_cost.get("/".join(split_model[1:]), {})
+
+    if not model_cost_info:
+        model_cost_info = model_cost.get(model.split("/")[-1], {})
+
+    model_input_context_length = model_cost_info.get("max_input_tokens", 8192)
+
     total_tokens = sum(count_tokens(json.dumps(msg), model) for msg in messages)
 
     if total_tokens <= model_input_context_length - 100:
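
Note: for context, the patch replaces a single lookup on the bare model name with a three-step fallback: exact match on the full model string, then the string with its first provider prefix stripped, then the last path segment. Below is a minimal standalone sketch (not part of the patch) of that lookup order; the `model_cost` contents and the `resolve_max_input_tokens` helper name are illustrative, with real values coming from litellm's `model_cost` table.

```python
# Illustrative model_cost mapping; litellm supplies the real one.
model_cost = {
    "azure/gpt-4o": {"max_input_tokens": 128000},
    "gpt-4o": {"max_input_tokens": 128000},
}

def resolve_max_input_tokens(model: str) -> int:
    # 1) Exact match on the full model string, e.g. "azure/gpt-4o".
    info = model_cost.get(model, {})
    if not info:
        # 2) Drop the first path component (a proxy/provider prefix),
        #    e.g. "myproxy/azure/gpt-4o" -> "azure/gpt-4o".
        parts = model.split("/")
        if len(parts) > 1:
            info = model_cost.get("/".join(parts[1:]), {})
    if not info:
        # 3) Fall back to the bare model name, e.g. "gpt-4o",
        #    which was the only lookup the old code performed.
        info = model_cost.get(model.split("/")[-1], {})
    # Default mirrors the patch: 8192 tokens when the model is unknown.
    return info.get("max_input_tokens", 8192)

assert resolve_max_input_tokens("azure/gpt-4o") == 128000
assert resolve_max_input_tokens("myproxy/azure/gpt-4o") == 128000
assert resolve_max_input_tokens("unknown-model") == 8192
```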