From 66a10c5b3b85f514e80c21052256f957aa5869c3 Mon Sep 17 00:00:00 2001
From: Shreya Shankar <ss.shankar505@gmail.com>
Date: Sun, 26 Jan 2025 08:12:12 -0800
Subject: [PATCH] fix: input context length logic (#293)

---
 docetl/operations/utils/llm.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/docetl/operations/utils/llm.py b/docetl/operations/utils/llm.py
index 484d0981..e5bc7944 100644
--- a/docetl/operations/utils/llm.py
+++ b/docetl/operations/utils/llm.py
@@ -71,9 +71,18 @@ def truncate_messages(
     messages: List[Dict[str, str]], model: str, from_agent: bool = False
 ) -> List[Dict[str, str]]:
     """Truncate messages to fit within model's context length."""
-    model_input_context_length = model_cost.get(model.split("/")[-1], {}).get(
-        "max_input_tokens", 8192
-    )
+    model_cost_info = model_cost.get(model, {})
+    if not model_cost_info:
+        # Try stripping the first part before the /
+        split_model = model.split("/")
+        if len(split_model) > 1:
+            model_cost_info = model_cost.get("/".join(split_model[1:]), {})
+
+    if not model_cost_info:
+        model_cost_info = model_cost.get(model.split("/")[-1], {})
+
+    model_input_context_length = model_cost_info.get("max_input_tokens", 8192)
+
     total_tokens = sum(count_tokens(json.dumps(msg), model) for msg in messages)
 
     if total_tokens <= model_input_context_length - 100:
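
Note: for context, the patch replaces a single lookup on the bare model name with a three-step fallback: exact match on the full model string, then the string with its first provider prefix stripped, then the last path segment. Below is a minimal standalone sketch (not part of the patch) of that lookup order; the `model_cost` contents and the `resolve_max_input_tokens` helper name are illustrative, with real values coming from litellm's `model_cost` table.

```python
# Illustrative model_cost mapping; litellm supplies the real one.
model_cost = {
    "azure/gpt-4o": {"max_input_tokens": 128000},
    "gpt-4o": {"max_input_tokens": 128000},
}

def resolve_max_input_tokens(model: str) -> int:
    # 1) Exact match on the full model string, e.g. "azure/gpt-4o".
    info = model_cost.get(model, {})
    if not info:
        # 2) Drop the first path component (a proxy/provider prefix),
        #    e.g. "myproxy/azure/gpt-4o" -> "azure/gpt-4o".
        parts = model.split("/")
        if len(parts) > 1:
            info = model_cost.get("/".join(parts[1:]), {})
    if not info:
        # 3) Fall back to the bare model name, e.g. "gpt-4o",
        #    which was the only lookup the old code performed.
        info = model_cost.get(model.split("/")[-1], {})
    # Default mirrors the patch: 8192 tokens when the model is unknown.
    return info.get("max_input_tokens", 8192)

assert resolve_max_input_tokens("azure/gpt-4o") == 128000
assert resolve_max_input_tokens("myproxy/azure/gpt-4o") == 128000
assert resolve_max_input_tokens("unknown-model") == 8192
```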