From 477cce321fe3fe4c8c40196e098666f3f27ce5b4 Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Wed, 29 Jan 2025 19:41:09 -0500 Subject: [PATCH] Fix llms (#2003) * iwp * add in api_base --------- Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> --- src/crewai/llm.py | 5 ++++- src/crewai/utilities/llm_utils.py | 17 +++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 98b0bc8553..ef8746fd59 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -133,6 +133,7 @@ def __init__( logprobs: Optional[int] = None, top_logprobs: Optional[int] = None, base_url: Optional[str] = None, + api_base: Optional[str] = None, api_version: Optional[str] = None, api_key: Optional[str] = None, callbacks: List[Any] = [], @@ -152,6 +153,7 @@ def __init__( self.logprobs = logprobs self.top_logprobs = top_logprobs self.base_url = base_url + self.api_base = api_base self.api_version = api_version self.api_key = api_key self.callbacks = callbacks @@ -232,7 +234,8 @@ def call( "seed": self.seed, "logprobs": self.logprobs, "top_logprobs": self.top_logprobs, - "api_base": self.base_url, + "api_base": self.api_base, + "base_url": self.base_url, "api_version": self.api_version, "api_key": self.api_key, "stream": False, diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 13230edf6c..c774a71fbe 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -53,6 +53,7 @@ def create_llm( timeout: Optional[float] = getattr(llm_value, "timeout", None) api_key: Optional[str] = getattr(llm_value, "api_key", None) base_url: Optional[str] = getattr(llm_value, "base_url", None) + api_base: Optional[str] = getattr(llm_value, "api_base", None) created_llm = LLM( model=model, @@ -62,6 +63,7 @@ def create_llm( timeout=timeout, api_key=api_key, base_url=base_url, + api_base=api_base, ) return created_llm except Exception as e: @@ -101,8 +103,18 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: callbacks: List[Any] = [] # Optional base URL from env - api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL") - if api_base: + base_url = ( + os.environ.get("BASE_URL") + or os.environ.get("OPENAI_API_BASE") + or os.environ.get("OPENAI_BASE_URL") + ) + + api_base = os.environ.get("API_BASE") or os.environ.get("AZURE_API_BASE") + + # Synchronize base_url and api_base if one is populated and the other is not + if base_url and not api_base: + api_base = base_url + elif api_base and not base_url: base_url = api_base # Initialize llm_params dictionary @@ -115,6 +127,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: "timeout": timeout, "api_key": api_key, "base_url": base_url, + "api_base": api_base, "api_version": api_version, "presence_penalty": presence_penalty, "frequency_penalty": frequency_penalty,