Skip to content

Commit

Permalink
Add support for Gemini through VertexAPI (#40)
Browse files Browse the repository at this point in the history
* Add support for Gemini through VertexAPI

* Bump
  • Loading branch information
mahaloz authored Jun 29, 2024
1 parent 5cf7666 commit 73cbe8d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 141 deletions.
2 changes: 1 addition & 1 deletion dailalib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.0.0"
__version__ = "3.1.0"

from .api import AIAPI, LiteLLMAIAPI
from libbs.api import DecompilerInterface
Expand Down
23 changes: 16 additions & 7 deletions dailalib/api/litellm/litellm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ class LiteLLMAIAPI(AIAPI):
DEFAULT_MODEL = "gpt-4o"
MODEL_TO_TOKENS = {
"gpt-4-turbo": 128_000,
"gpt-4": 8000,
"gpt-4o": 8000,
"gpt-3.5-turbo": 4096,
"gpt-4": 8_000,
"gpt-4o": 8_000,
"gpt-3.5-turbo": 4_096,
"claude-2": 200_000,
"vertex_ai_beta/gemini-pro": 12_288,
}

# replacement strings for API calls
def __init__(self, api_key: Optional[str] = None, model: str = DEFAULT_MODEL, prompts: Optional[list] = None, **kwargs):
super().__init__(**kwargs)
self._api_key = None
self._openai_client: OpenAI = None
# default to openai api key if not provided
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.model = model
Expand Down Expand Up @@ -108,6 +108,8 @@ def api_key(self):
return os.getenv("OPENAI_API_KEY", None)
elif "claude" in self.model:
return os.getenv("ANTHROPIC_API_KEY", None)
elif "vertex" in self.model:
return self._api_key
else:
return None

Expand All @@ -121,7 +123,14 @@ def api_key(self, value):
os.environ["ANTHROPIC_API_KEY"] = self._api_key

def ask_api_key(self, *args, **kwargs):
self.api_key = self._dec_interface.gui_ask_for_string("Enter you AI API Key:", title="DAILA")
api_key_or_path = self._dec_interface.gui_ask_for_string("Enter you AI API Key or Creds Path:", title="DAILA")
if "/" in api_key_or_path or "\\" in api_key_or_path:
# treat as path
with open(api_key_or_path, "r") as f:
api_key = f.read().strip()
else:
api_key = api_key_or_path
self.api_key = api_key

def ask_prompt_style(self):
if self._dec_interface is not None:
Expand All @@ -133,7 +142,7 @@ def ask_prompt_style(self):

p_style = self._dec_interface.gui_ask_for_choice(
"What prompting style would you like to use?",
ALL_STYLES,
style_choices,
title="DAILA"
)
self.prompt_style = p_style
Expand All @@ -147,7 +156,7 @@ def ask_model(self):

model = self._dec_interface.gui_ask_for_choice(
"What LLM model would you like to use?",
list(LiteLLMAIAPI.MODEL_TO_TOKENS.keys()),
model_choices,
title="DAILA"
)
self.model = model
Expand Down
133 changes: 0 additions & 133 deletions dailalib/api/litellm/prompts.py

This file was deleted.

0 comments on commit 73cbe8d

Please sign in to comment.