Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update openai sdk #422

Merged
merged 13 commits into from
Nov 10, 2023
40 changes: 27 additions & 13 deletions griptape/drivers/embedding/azure_openai_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,44 @@
from __future__ import annotations

from typing import Optional
from attr import define, field, Factory
from griptape.drivers import OpenAiEmbeddingDriver
from griptape.tokenizers import OpenAiTokenizer
import openai


@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
"""
Attributes:
model: OpenAI embedding model name.
deployment_id: Azure OpenAI deployment ID.
api_base: API URL.
api_type: OpenAI API type. Defaults to 'azure'.
api_version: API version. Defaults to '2023-05-15'.
tokenizer: Optionally provide custom `OpenAiTokenizer`.
azure_deployment: An Azure OpenAi deployment id.
azure_endpoint: An Azure OpenAi endpoint.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_version: An Azure OpenAi API version.
tokenizer: An `OpenAiTokenizer`.
client: An `openai.AzureOpenAI` client.
"""

model: str = field(kw_only=True)
deployment_id: str = field(kw_only=True)
api_base: str = field(kw_only=True)
api_type: str = field(default="azure", kw_only=True)
azure_deployment: str = field(kw_only=True)
azure_endpoint: str = field(kw_only=True)
azure_ad_token: Optional[str] = field(kw_only=True, default=None)
azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
api_version: str = field(default="2023-05-15", kw_only=True)
tokenizer: OpenAiTokenizer = field(
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
)

def _params(self, chunk: list[int] | str) -> dict:
return super()._params(chunk) | {"deployment_id": self.deployment_id}
client: openai.AzureOpenAI = field(
default=Factory(
lambda self: openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
)
)
46 changes: 19 additions & 27 deletions griptape/drivers/embedding/openai_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations
import os
from typing import Optional
import openai
from attr import define, field, Factory
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import OpenAiTokenizer
import openai


@define
Expand All @@ -13,49 +12,42 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
Attributes:
model: OpenAI embedding model name. Defaults to `text-embedding-ada-002`.
dimensions: Vector dimensions. Defaults to `1536`.
api_type: OpenAI API type, for example 'open_ai' or 'azure'. Defaults to 'open_ai'.
api_version: API version. Defaults to 'OPENAI_API_VERSION' environment variable.
api_base: API URL. Defaults to OpenAI's v1 API URL.
base_url: API URL. Defaults to OpenAI's v1 API URL.
api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
tokenizer: Optionally provide custom `OpenAiTokenizer`.
client: Optionally provide custom `openai.OpenAI` client.
azure_deployment: An Azure OpenAi deployment id.
azure_endpoint: An Azure OpenAi endpoint.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_version: An Azure OpenAi API version.
"""

DEFAULT_MODEL = "text-embedding-ada-002"
DEFAULT_DIMENSIONS = 1536

model: str = field(default=DEFAULT_MODEL, kw_only=True)
dimensions: int = field(default=DEFAULT_DIMENSIONS, kw_only=True)
api_type: str = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True)
api_base: str = field(default=openai.api_base, kw_only=True)
api_key: Optional[str] = field(default=Factory(lambda: os.environ.get("OPENAI_API_KEY")), kw_only=True)
organization: Optional[str] = field(default=openai.organization, kw_only=True)
base_url: str = field(default=None, kw_only=True)
api_key: Optional[str] = field(default=None, kw_only=True)
organization: Optional[str] = field(default=None, kw_only=True)
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
)
)
tokenizer: OpenAiTokenizer = field(
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
)

def __attrs_post_init__(self) -> None:
openai.api_type = self.api_type
openai.api_version = self.api_version
openai.api_base = self.api_base
openai.api_key = self.api_key
openai.organization = self.organization

def try_embed_chunk(self, chunk: str) -> list[float]:
# Address a performance issue in older ada models
# https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
if self.model.endswith("001"):
chunk = chunk.replace("\n", " ")
return openai.Embedding.create(**self._params(chunk))["data"][0]["embedding"]
return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

def _params(self, chunk: str) -> dict:
return {
"input": chunk,
"model": self.model,
"api_key": self.api_key,
"organization": self.organization,
"api_version": self.api_version,
"api_base": self.api_base,
"api_type": self.api_type,
}
return {"input": chunk, "model": self.model}
41 changes: 30 additions & 11 deletions griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,45 @@
from attr import define, field, Factory
from typing import Optional
from griptape.utils import PromptStack
from griptape.drivers import OpenAiChatPromptDriver
from griptape.tokenizers import OpenAiTokenizer
import openai


@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
"""
Attributes:
api_base: API URL.
deployment_id: Azure OpenAI deployment ID.
model: OpenAI model name.
azure_deployment: An Azure OpenAi deployment id.
azure_endpoint: An Azure OpenAi endpoint.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_version: An Azure OpenAi API version.
client: An `openai.AzureOpenAI` client.
"""

api_base: str = field(kw_only=True)
model: str = field(kw_only=True)
deployment_id: str = field(kw_only=True)
api_type: str = field(default="azure", kw_only=True)
azure_deployment: str = field(kw_only=True)
azure_endpoint: str = field(kw_only=True)
azure_ad_token: Optional[str] = field(kw_only=True, default=None)
azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
api_version: str = field(default="2023-05-15", kw_only=True)
tokenizer: OpenAiTokenizer = field(
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
client: openai.AzureOpenAI = field(
default=Factory(
lambda self: openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
)
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
return super()._base_params(prompt_stack) | {"deployment_id": self.deployment_id}
params = super()._base_params(prompt_stack)
# TODO: Add `seed` parameter once Azure supports it.
del params["seed"]

return params
38 changes: 27 additions & 11 deletions griptape/drivers/prompt/azure_openai_completion_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,35 @@
from typing import Optional
from attr import define, field, Factory
from griptape.utils import PromptStack
from griptape.drivers import OpenAiCompletionPromptDriver
from griptape.tokenizers import OpenAiTokenizer
import openai


@define
class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
api_base: str = field(kw_only=True)
model: str = field(kw_only=True)
deployment_id: str = field(kw_only=True)
api_type: str = field(default="azure", kw_only=True)
"""
Attributes:
azure_deployment: An Azure OpenAi deployment id.
azure_endpoint: An Azure OpenAi endpoint.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_version: An Azure OpenAi API version.
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True)
azure_endpoint: str = field(kw_only=True)
azure_ad_token: Optional[str] = field(kw_only=True, default=None)
azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
api_version: str = field(default="2023-05-15", kw_only=True)
tokenizer: OpenAiTokenizer = field(
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
client: openai.AzureOpenAI = field(
default=Factory(
lambda self: openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
),
takes_self=True,
)
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
return super()._base_params(prompt_stack) | {"deployment_id": self.deployment_id}
15 changes: 14 additions & 1 deletion griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@

@define
class BasePromptDriver(ExponentialBackoffMixin, ABC):
"""Base class for Prompt Drivers.

Attributes:
temperature: The temperature to use for the completion.
        max_tokens: The maximum number of tokens to generate. If not specified, the value will be automatically determined based on the tokenizer.
structure: An optional `Structure` to publish events to.
prompt_stack_to_string: A function that converts a `PromptStack` to a string.
ignored_exception_types: A tuple of exception types to ignore.
model: The model name.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
"""

temperature: float = field(default=0.1, kw_only=True)
max_tokens: Optional[int] = field(default=None, kw_only=True)
structure: Optional[Structure] = field(default=None, kw_only=True)
Expand All @@ -25,7 +38,7 @@ class BasePromptDriver(ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True)

def max_output_tokens(self, text: str) -> int:
def max_output_tokens(self, text: str | list) -> int:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated, just fixing pyright errors.

tokens_left = self.tokenizer.count_tokens_left(text)

if self.max_tokens:
Expand Down
Loading