Preliminary support for local LLM endpoints (tested on Ollama) #71

Merged · 2 commits · Dec 5, 2024
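
This PR points DAILA's LiteLLM backend at any OpenAI-compatible endpoint, such as a local Ollama server, instead of a hosted API. Below is a minimal sketch of the resulting workflow, assuming Ollama's OpenAI-compatible API at its default address; the import path, class name, URL, and model name are illustrative, while the `custom_endpoint` and `custom_model` keyword arguments come from the diff below:

```python
# Hedged sketch of the new flow, assuming a local Ollama server exposing
# its OpenAI-compatible API at the default address. The import path and
# the concrete values are illustrative, not part of this PR.
from dailalib.api.litellm.litellm_api import LiteLLMAIAPI

api = LiteLLMAIAPI(
    custom_endpoint="http://localhost:11434/v1",  # local OpenAI-compatible server
    custom_model="openai/mistral",                # "openai/" prefix routes through litellm's OpenAI client
)
answer, cost = api.query_model("What does this decompiled function do? ...")
print(answer)  # cost is reported as 0 for custom endpoints (see query_model below)
```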
2 changes: 2 additions & 0 deletions dailalib/__init__.py

@@ -30,6 +30,8 @@ def create_plugin(*args, **kwargs):
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_api_key"] = ("Update API key...", litellm_api.ask_api_key)
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style)
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_model"] = ("Change model...", litellm_api.ask_model)
+    gui_ctx_menu_actions["DAILA/LLM/Settings/update_custom_url"] = ("Set Custom OpenAI Endpoint...", litellm_api.ask_custom_endpoint)
+    gui_ctx_menu_actions["DAILA/LLM/Settings/update_custom_model"] = ("Set Custom OpenAI Model...", litellm_api.ask_custom_model)

     #
     # VarModel API (local variable renaming)
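
As the diff above shows, each `gui_ctx_menu_actions` entry pairs a menu path with a `(display label, callback)` tuple. A hedged sketch of how one of the new actions resolves; only the dictionary access follows the diff, the surrounding plugin machinery is assumed:

```python
# Illustrative only: look up one of the new actions and invoke its callback,
# following the (label, callback) convention visible in the diff above.
label, callback = gui_ctx_menu_actions["DAILA/LLM/Settings/update_custom_url"]
print(label)   # "Set Custom OpenAI Endpoint..."
callback()     # opens the ask_custom_endpoint dialog added in litellm_api.py
```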
37 changes: 34 additions & 3 deletions dailalib/api/litellm/litellm_api.py

@@ -36,6 +36,8 @@ def __init__(
         fit_to_tokens: bool = False,
         chat_use_ctx: bool = True,
         chat_event_callbacks: Optional[dict] = None,
+        custom_endpoint: Optional[str] = None,
+        custom_model: Optional[str] = None,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -47,6 +49,8 @@
         self.fit_to_tokens = fit_to_tokens
         self.chat_use_ctx = chat_use_ctx
         self.chat_event_callbacks = chat_event_callbacks or {"send": None, "receive": None}
+        self.custom_endpoint = custom_endpoint
+        self.custom_model = custom_model

         # delay prompt import
         from .prompts import PROMPTS
@@ -79,24 +83,29 @@ def query_model(
         # delay import because litellm attempts to query the server on import to collect cost information.
         from litellm import completion

-        if not self.api_key:
+        if not self.api_key and not self.custom_endpoint:
             raise ValueError(f"Model API key is not set. Please set it before querying the model {self.model}")

-        prompt_model = model or self.model
+        prompt_model = (model or self.model) if not self.custom_endpoint else self.custom_model
         response = completion(
             model=prompt_model,
             messages=[
                 {"role": "user", "content": prompt}
             ],
             max_tokens=max_tokens,
             timeout=60,
+            api_base=self.custom_endpoint if self.custom_endpoint else None,  # use the custom endpoint if set
+            api_key=self.api_key if not self.custom_endpoint else "dummy"  # in most cases a custom endpoint needs no real API key
         )
         # get the answer
         try:
             answer = response.choices[0].message.content
         except (KeyError, IndexError) as e:
             answer = None

+        if self.custom_endpoint:
+            return answer, 0
+
         # get the estimated cost
         try:
             prompt_tokens = response.usage.prompt_tokens
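
For reference, when a custom endpoint is configured the call above reduces to a plain litellm completion against the local server. A standalone sketch under that assumption; the URL and model name are examples, while `api_base` and `api_key` are real litellm parameters:

```python
# Standalone approximation of what query_model sends when custom_endpoint
# is set; the concrete values are examples, not taken from this PR.
from litellm import completion

response = completion(
    model="openai/mistral",                # self.custom_model
    messages=[{"role": "user", "content": "Explain this function: ..."}],
    max_tokens=512,
    timeout=60,
    api_base="http://localhost:11434/v1",  # self.custom_endpoint
    api_key="dummy",                       # placeholder; most local servers ignore it
)
print(response.choices[0].message.content)
```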
@@ -189,7 +198,7 @@ def api_key(self, value):
             os.environ["ANTHROPIC_API_KEY"] = self._api_key
         elif "gemini/gemini" in self.model:
             os.environ["GEMINI_API_KEY"] = self._api_key
-        elif "perplexity" in self.model:
+        elif "perplexity" in self.model:
             os.environ["PERPLEXITY_API_KEY"] = self._api_key

     def ask_api_key(self, *args, **kwargs):
@@ -202,6 +211,28 @@ def ask_api_key(self, *args, **kwargs):
             api_key = api_key_or_path
         self.api_key = api_key

+    def ask_custom_endpoint(self, *args, **kwargs):
+        custom_endpoint = self._dec_interface.gui_ask_for_string("Enter your custom OpenAI endpoint:", title="DAILA")
+        if not custom_endpoint.strip():
+            self.custom_endpoint = None
+            self._dec_interface.info(f"Custom endpoint disabled, defaulting to online API")
+            return
+        if not (custom_endpoint.lower().startswith("http://") or custom_endpoint.lower().startswith("https://")):
+            self.custom_endpoint = None
+            self._dec_interface.error("Invalid endpoint format")
+            return
+        self.custom_endpoint = custom_endpoint.strip()
+        self._dec_interface.info(f"Custom endpoint set to {self.custom_endpoint}")
+
+    def ask_custom_model(self, *args, **kwargs):
+        custom_model = self._dec_interface.gui_ask_for_string("Enter your custom OpenAI model name:", title="DAILA")
+        if not custom_model.strip():
+            self.custom_model = None
+            self._dec_interface.info(f"Custom model selection cleared")
+            return
+        self.custom_model = "openai/" + custom_model.strip()
+        self._dec_interface.info(f"Custom model set to {self.custom_model}")

     def _set_prompt_style(self, prompt_style):
         self.prompt_style = prompt_style
         global active_prompt_style
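
Taken together, the two new dialogs normalize user input before it reaches `query_model`: the endpoint must carry an explicit scheme, and the model name is stored with litellm's `openai/` provider prefix. A small sketch of that normalization, with illustrative inputs:

```python
# Illustrative inputs showing the normalization the two dialogs perform.
endpoint_in = "http://localhost:11434/v1"
model_in = "mistral"

# ask_custom_endpoint accepts only explicit http:// or https:// URLs
assert endpoint_in.lower().startswith(("http://", "https://"))
custom_endpoint = endpoint_in.strip()

# ask_custom_model prepends litellm's provider prefix before storing
custom_model = "openai/" + model_in.strip()
assert custom_model == "openai/mistral"
```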