
Commit 1790cfd

Update openai sdk (#422)

* Update openai sdk
* Fix integration tests
* PR fixes
* Remove util methods
* Update type
* Fix deps
* Update poetry lock
* Add back missing pytest-env
* Remove init=False, improve comments
* Resolve poetry.lock conflict

Co-authored-by: Vasily Vasinov <[email protected]>

Parent: b616525

19 files changed: +400 −257 lines
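At a high level, the diffs below move griptape's OpenAI drivers off the openai-python v0 module-level API (global `openai.api_key`, `openai.Embedding.create`, `__attrs_post_init__` configuration) and onto v1 client objects (`openai.OpenAI` / `openai.AzureOpenAI`) stored on each driver. A rough before/after sketch of that pattern, with placeholder values rather than code from this commit:

import openai

# v0 style: configure the module globally and call resource classes; responses are dicts.
# openai.api_key = "sk-..."
# vector = openai.Embedding.create(model="text-embedding-ada-002", input="hello")["data"][0]["embedding"]

# v1 style: build an explicit client and call it; responses are typed objects.
client = openai.OpenAI(api_key="sk-...")  # placeholder key
vector = client.embeddings.create(model="text-embedding-ada-002", input="hello").data[0].embedding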
@@ -1,30 +1,44 @@
 from __future__ import annotations

+from typing import Optional
 from attr import define, field, Factory
 from griptape.drivers import OpenAiEmbeddingDriver
 from griptape.tokenizers import OpenAiTokenizer
+import openai


 @define
 class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
     """
     Attributes:
-        model: OpenAI embedding model name.
-        deployment_id: Azure OpenAI deployment ID.
-        api_base: API URL.
-        api_type: OpenAI API type. Defaults to 'azure'.
-        api_version: API version. Defaults to '2023-05-15'.
-        tokenizer: Optionally provide custom `OpenAiTokenizer`.
+        azure_deployment: An Azure OpenAi deployment id.
+        azure_endpoint: An Azure OpenAi endpoint.
+        azure_ad_token: An optional Azure Active Directory token.
+        azure_ad_token_provider: An optional Azure Active Directory token provider.
+        api_version: An Azure OpenAi API version.
+        tokenizer: An `OpenAiTokenizer`.
+        client: An `openai.AzureOpenAI` client.
     """

-    model: str = field(kw_only=True)
-    deployment_id: str = field(kw_only=True)
-    api_base: str = field(kw_only=True)
-    api_type: str = field(default="azure", kw_only=True)
+    azure_deployment: str = field(kw_only=True)
+    azure_endpoint: str = field(kw_only=True)
+    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
+    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
     api_version: str = field(default="2023-05-15", kw_only=True)
     tokenizer: OpenAiTokenizer = field(
         default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
     )
-
-    def _params(self, chunk: list[int] | str) -> dict:
-        return super()._params(chunk) | {"deployment_id": self.deployment_id}
+    client: openai.AzureOpenAI = field(
+        default=Factory(
+            lambda self: openai.AzureOpenAI(
+                organization=self.organization,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.azure_endpoint,
+                azure_deployment=self.azure_deployment,
+                azure_ad_token=self.azure_ad_token,
+                azure_ad_token_provider=self.azure_ad_token_provider,
+            ),
+            takes_self=True,
+        )
+    )
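A minimal usage sketch for the updated driver (the endpoint, deployment, and environment variable names are placeholders; `model` and `api_key` are fields inherited from `OpenAiEmbeddingDriver`):

import os
from griptape.drivers import AzureOpenAiEmbeddingDriver

driver = AzureOpenAiEmbeddingDriver(
    model="text-embedding-ada-002",
    azure_deployment="text-embedding-ada-002",  # placeholder deployment id
    azure_endpoint="https://example-resource.openai.azure.com/",  # placeholder endpoint
    api_key=os.environ["AZURE_OPENAI_API_KEY"],  # placeholder env var name
)
vector = driver.try_embed_chunk("Hello, world!")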
@@ -1,10 +1,9 @@
 from __future__ import annotations
-import os
 from typing import Optional
-import openai
 from attr import define, field, Factory
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.tokenizers import OpenAiTokenizer
+import openai


 @define
@@ -13,49 +12,42 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
     Attributes:
         model: OpenAI embedding model name. Defaults to `text-embedding-ada-002`.
         dimensions: Vector dimensions. Defaults to `1536`.
-        api_type: OpenAI API type, for example 'open_ai' or 'azure'. Defaults to 'open_ai'.
-        api_version: API version. Defaults to 'OPENAI_API_VERSION' environment variable.
-        api_base: API URL. Defaults to OpenAI's v1 API URL.
+        base_url: API URL. Defaults to OpenAI's v1 API URL.
         api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
         organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
         tokenizer: Optionally provide custom `OpenAiTokenizer`.
+        client: Optionally provide custom `openai.OpenAI` client.
+        azure_deployment: An Azure OpenAi deployment id.
+        azure_endpoint: An Azure OpenAi endpoint.
+        azure_ad_token: An optional Azure Active Directory token.
+        azure_ad_token_provider: An optional Azure Active Directory token provider.
+        api_version: An Azure OpenAi API version.
     """

     DEFAULT_MODEL = "text-embedding-ada-002"
     DEFAULT_DIMENSIONS = 1536

     model: str = field(default=DEFAULT_MODEL, kw_only=True)
     dimensions: int = field(default=DEFAULT_DIMENSIONS, kw_only=True)
-    api_type: str = field(default=openai.api_type, kw_only=True)
-    api_version: Optional[str] = field(default=openai.api_version, kw_only=True)
-    api_base: str = field(default=openai.api_base, kw_only=True)
-    api_key: Optional[str] = field(default=Factory(lambda: os.environ.get("OPENAI_API_KEY")), kw_only=True)
-    organization: Optional[str] = field(default=openai.organization, kw_only=True)
+    base_url: str = field(default=None, kw_only=True)
+    api_key: Optional[str] = field(default=None, kw_only=True)
+    organization: Optional[str] = field(default=None, kw_only=True)
+    client: openai.OpenAI = field(
+        default=Factory(
+            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
+            takes_self=True,
+        )
+    )
     tokenizer: OpenAiTokenizer = field(
         default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
     )

-    def __attrs_post_init__(self) -> None:
-        openai.api_type = self.api_type
-        openai.api_version = self.api_version
-        openai.api_base = self.api_base
-        openai.api_key = self.api_key
-        openai.organization = self.organization
-
     def try_embed_chunk(self, chunk: str) -> list[float]:
         # Address a performance issue in older ada models
         # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
         if self.model.endswith("001"):
             chunk = chunk.replace("\n", " ")
-        return openai.Embedding.create(**self._params(chunk))["data"][0]["embedding"]
+        return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

     def _params(self, chunk: str) -> dict:
-        return {
-            "input": chunk,
-            "model": self.model,
-            "api_key": self.api_key,
-            "organization": self.organization,
-            "api_version": self.api_version,
-            "api_base": self.api_base,
-            "api_type": self.api_type,
-        }
+        return {"input": chunk, "model": self.model}
@@ -1,26 +1,45 @@
 from attr import define, field, Factory
+from typing import Optional
 from griptape.utils import PromptStack
 from griptape.drivers import OpenAiChatPromptDriver
-from griptape.tokenizers import OpenAiTokenizer
+import openai


 @define
 class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
     """
     Attributes:
-        api_base: API URL.
-        deployment_id: Azure OpenAI deployment ID.
-        model: OpenAI model name.
+        azure_deployment: An Azure OpenAi deployment id.
+        azure_endpoint: An Azure OpenAi endpoint.
+        azure_ad_token: An optional Azure Active Directory token.
+        azure_ad_token_provider: An optional Azure Active Directory token provider.
+        api_version: An Azure OpenAi API version.
+        client: An `openai.AzureOpenAI` client.
     """

-    api_base: str = field(kw_only=True)
-    model: str = field(kw_only=True)
-    deployment_id: str = field(kw_only=True)
-    api_type: str = field(default="azure", kw_only=True)
+    azure_deployment: str = field(kw_only=True)
+    azure_endpoint: str = field(kw_only=True)
+    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
+    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
     api_version: str = field(default="2023-05-15", kw_only=True)
-    tokenizer: OpenAiTokenizer = field(
-        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
+    client: openai.AzureOpenAI = field(
+        default=Factory(
+            lambda self: openai.AzureOpenAI(
+                organization=self.organization,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.azure_endpoint,
+                azure_deployment=self.azure_deployment,
+                azure_ad_token=self.azure_ad_token,
+                azure_ad_token_provider=self.azure_ad_token_provider,
+            ),
+            takes_self=True,
+        )
     )

     def _base_params(self, prompt_stack: PromptStack) -> dict:
-        return super()._base_params(prompt_stack) | {"deployment_id": self.deployment_id}
+        params = super()._base_params(prompt_stack)
+        # TODO: Add `seed` parameter once Azure supports it.
+        del params["seed"]
+
+        return params
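A minimal usage sketch wiring the chat driver into a griptape structure (the endpoint, deployment, and env var names are placeholders, and the `Agent` usage is assumed from griptape's public API rather than from this diff):

import os
from griptape.structures import Agent
from griptape.drivers import AzureOpenAiChatPromptDriver

agent = Agent(
    prompt_driver=AzureOpenAiChatPromptDriver(
        model="gpt-35-turbo",             # placeholder model name
        azure_deployment="gpt-35-turbo",  # placeholder deployment id
        azure_endpoint="https://example-resource.openai.azure.com/",  # placeholder endpoint
        api_key=os.environ["AZURE_OPENAI_API_KEY"],                   # placeholder env var name
    )
)
agent.run("Write me a haiku about embeddings.")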

griptape/drivers/prompt/azure_openai_completion_prompt_driver.py (+27 −11)

@@ -1,19 +1,35 @@
+from typing import Optional
 from attr import define, field, Factory
-from griptape.utils import PromptStack
 from griptape.drivers import OpenAiCompletionPromptDriver
-from griptape.tokenizers import OpenAiTokenizer
+import openai


 @define
 class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
-    api_base: str = field(kw_only=True)
-    model: str = field(kw_only=True)
-    deployment_id: str = field(kw_only=True)
-    api_type: str = field(default="azure", kw_only=True)
+    """
+    Attributes:
+        azure_deployment: An Azure OpenAi deployment id.
+        azure_endpoint: An Azure OpenAi endpoint.
+        azure_ad_token: An optional Azure Active Directory token.
+        azure_ad_token_provider: An optional Azure Active Directory token provider.
+        api_version: An Azure OpenAi API version.
+        client: An `openai.AzureOpenAI` client.
+    """
+
+    azure_deployment: str = field(kw_only=True)
+    azure_endpoint: str = field(kw_only=True)
+    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
+    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
     api_version: str = field(default="2023-05-15", kw_only=True)
-    tokenizer: OpenAiTokenizer = field(
-        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
+    client: openai.AzureOpenAI = field(
+        default=Factory(
+            lambda self: openai.AzureOpenAI(
+                organization=self.organization,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.azure_endpoint,
+                azure_deployment=self.azure_deployment,
+            ),
+            takes_self=True,
+        )
     )
-
-    def _base_params(self, prompt_stack: PromptStack) -> dict:
-        return super()._base_params(prompt_stack) | {"deployment_id": self.deployment_id}
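A hypothetical sketch of authenticating with an Azure AD token provider instead of an API key. The `azure-identity` package, the scope string, and all names below are assumptions and not part of this commit; the provider callable is simply forwarded to `openai.AzureOpenAI`:

from azure.identity import DefaultAzureCredential, get_bearer_token_provider  # assumed dependency
from griptape.drivers import AzureOpenAiCompletionPromptDriver

# Build a callable that returns fresh bearer tokens for the Azure OpenAI resource.
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

driver = AzureOpenAiCompletionPromptDriver(
    model="gpt-35-turbo-instruct",             # placeholder model name
    azure_deployment="gpt-35-turbo-instruct",  # placeholder deployment id
    azure_endpoint="https://example-resource.openai.azure.com/",  # placeholder endpoint
    azure_ad_token_provider=token_provider,    # forwarded to openai.AzureOpenAI
)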

griptape/drivers/prompt/base_prompt_driver.py (+14 −1)

@@ -14,6 +14,19 @@

 @define
 class BasePromptDriver(ExponentialBackoffMixin, ABC):
+    """Base class for Prompt Drivers.
+
+    Attributes:
+        temperature: The temperature to use for the completion.
+        max_tokens: The maximum number of tokens to generate. If not specified, the value will be automatically generated based by the tokenizer.
+        structure: An optional `Structure` to publish events to.
+        prompt_stack_to_string: A function that converts a `PromptStack` to a string.
+        ignored_exception_types: A tuple of exception types to ignore.
+        model: The model name.
+        tokenizer: An instance of `BaseTokenizer` to when calculating tokens.
+        stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
+    """
+
     temperature: float = field(default=0.1, kw_only=True)
     max_tokens: Optional[int] = field(default=None, kw_only=True)
     structure: Optional[Structure] = field(default=None, kw_only=True)
@@ -25,7 +38,7 @@ class BasePromptDriver(ExponentialBackoffMixin, ABC):
     tokenizer: BaseTokenizer
     stream: bool = field(default=False, kw_only=True)

-    def max_output_tokens(self, text: str) -> int:
+    def max_output_tokens(self, text: str | list) -> int:
         tokens_left = self.tokenizer.count_tokens_left(text)

         if self.max_tokens:
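The widened `str | list` signature lets callers pass either a rendered prompt string or, presumably, a list of chat messages for token counting. A purely illustrative example of the two accepted shapes (the `driver` variable is hypothetical):

prompt_as_string = "You are a helpful assistant.\nUser: hi"
prompt_as_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hi"},
]

# remaining = driver.max_output_tokens(prompt_as_string)    # previously the only accepted shape
# remaining = driver.max_output_tokens(prompt_as_messages)  # now also accepted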
