
Commit 17836be

Update openai sdk
1 parent a314190 commit 17836be

19 files changed (+1378, -1230 lines)

griptape/drivers/embedding/azure_openai_embedding_driver.py (+13, -2)

@@ -3,6 +3,7 @@
 from attr import define, field, Factory
 from griptape.drivers import OpenAiEmbeddingDriver
 from griptape.tokenizers import OpenAiTokenizer
+import openai


 @define
@@ -17,7 +18,6 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
         tokenizer: Optionally provide custom `OpenAiTokenizer`.
     """

-    model: str = field(kw_only=True)
     deployment_id: str = field(kw_only=True)
     api_base: str = field(kw_only=True)
     api_type: str = field(default="azure", kw_only=True)
@@ -28,6 +28,17 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
         ),
         kw_only=True,
     )
+    client: openai.AzureOpenAI = field(
+        init=False,
+        default=Factory(
+            lambda self: openai.AzureOpenAI(
+                api_key=self.api_key,
+                base_url=self.base_url,
+                organization=self.organization,
+            ),
+            takes_self=True,
+        ),
+    )

-    def _params(self, chunk: list[int] | str) -> dict:
+    def _params(self, chunk: str) -> dict:
         return super()._params(chunk) | {"deployment_id": self.deployment_id}
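The Azure embedding driver now owns an `openai.AzureOpenAI` client built lazily from its own fields via attrs' `Factory(..., takes_self=True)`. A minimal, self-contained sketch of that pattern follows; the class and field names are illustrative, not part of this commit, and no request is made:

import openai
from attr import define, field, Factory


# Hypothetical driver, shown only to illustrate the Factory(takes_self=True)
# wiring used by the real drivers in this commit.
@define
class ExampleOpenAiDriver:
    api_key: str = field(kw_only=True)
    base_url: str = field(default=None, kw_only=True)
    organization: str = field(default=None, kw_only=True)
    # takes_self=True hands the instance to the lambda, so the SDK client
    # is constructed from the fields above when the driver is instantiated.
    client: openai.OpenAI = field(
        init=False,
        default=Factory(
            lambda self: openai.OpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
                organization=self.organization,
            ),
            takes_self=True,
        ),
    )


driver = ExampleOpenAiDriver(api_key="sk-placeholder")
print(type(driver.client))  # <class 'openai.OpenAI'>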
griptape/drivers/embedding/openai_embedding_driver.py

@@ -1,10 +1,9 @@
 from __future__ import annotations
-import os
 from typing import Optional
-import openai
 from attr import define, field, Factory
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.tokenizers import OpenAiTokenizer
+import openai


 @define
@@ -13,9 +12,7 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
     Attributes:
         model: OpenAI embedding model name. Defaults to `text-embedding-ada-002`.
         dimensions: Vector dimensions. Defaults to `1536`.
-        api_type: OpenAI API type, for example 'open_ai' or 'azure'. Defaults to 'open_ai'.
-        api_version: API version. Defaults to 'OPENAI_API_VERSION' environment variable.
-        api_base: API URL. Defaults to OpenAI's v1 API URL.
+        base_url: API URL. Defaults to OpenAI's v1 API URL.
         api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
         organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
         tokenizer: Optionally provide custom `OpenAiTokenizer`.
@@ -26,14 +23,19 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):

     model: str = field(default=DEFAULT_MODEL, kw_only=True)
     dimensions: int = field(default=DEFAULT_DIMENSIONS, kw_only=True)
-    api_type: str = field(default=openai.api_type, kw_only=True)
-    api_version: Optional[str] = field(default=openai.api_version, kw_only=True)
-    api_base: str = field(default=openai.api_base, kw_only=True)
-    api_key: Optional[str] = field(
-        default=Factory(lambda: os.environ.get("OPENAI_API_KEY")), kw_only=True
-    )
-    organization: Optional[str] = field(
-        default=openai.organization, kw_only=True
+    base_url: str = field(default=None, kw_only=True)
+    api_key: Optional[str] = field(default=None, kw_only=True)
+    organization: Optional[str] = field(default=None, kw_only=True)
+    client: openai.OpenAI = field(
+        init=False,
+        default=Factory(
+            lambda self: openai.OpenAI(
+                api_key=self.api_key,
+                base_url=self.base_url,
+                organization=self.organization,
+            ),
+            takes_self=True,
+        ),
     )
     tokenizer: OpenAiTokenizer = field(
         default=Factory(
@@ -42,29 +44,16 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
         kw_only=True,
     )

-    def __attrs_post_init__(self) -> None:
-        openai.api_type = self.api_type
-        openai.api_version = self.api_version
-        openai.api_base = self.api_base
-        openai.api_key = self.api_key
-        openai.organization = self.organization
-
     def try_embed_chunk(self, chunk: str) -> list[float]:
         # Address a performance issue in older ada models
         # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
         if self.model.endswith("001"):
             chunk = chunk.replace("\n", " ")
-        return openai.Embedding.create(**self._params(chunk))["data"][0][
-            "embedding"
-        ]
+        return (
+            self.client.embeddings.create(**self._params(chunk))
+            .data[0]
+            .embedding
+        )

     def _params(self, chunk: str) -> dict:
-        return {
-            "input": chunk,
-            "model": self.model,
-            "api_key": self.api_key,
-            "organization": self.organization,
-            "api_version": self.api_version,
-            "api_base": self.api_base,
-            "api_type": self.api_type,
-        }
+        return {"input": chunk, "model": self.model}
griptape/drivers/prompt/azure_openai_chat_prompt_driver.py

@@ -1,23 +1,41 @@
 from attr import define, field, Factory
-from griptape.utils import PromptStack
+from typing import Optional
 from griptape.drivers import OpenAiChatPromptDriver
+from griptape.utils.prompt_stack import PromptStack
 from griptape.tokenizers import OpenAiTokenizer
+import openai


 @define
 class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
     """
     Attributes:
-        api_base: API URL.
-        deployment_id: Azure OpenAI deployment ID.
-        model: OpenAI model name.
+        azure_deployment: Azure deployment id.
+        azure_endpoint: Azure endpoint.
+        azure_ad_token: Azure Active Directory token.
+        azure_ad_token_provider: Azure Active Directory token provider.
+        api_version: API version.
     """

-    api_base: str = field(kw_only=True)
-    model: str = field(kw_only=True)
-    deployment_id: str = field(kw_only=True)
-    api_type: str = field(default="azure", kw_only=True)
+    azure_deployment: str = field(kw_only=True)
+    azure_endpoint: str = field(kw_only=True)
+    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
+    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
     api_version: str = field(default="2023-05-15", kw_only=True)
+    client: openai.AzureOpenAI = field(
+        default=Factory(
+            lambda self: openai.AzureOpenAI(
+                organization=self.organization,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.azure_endpoint,
+                azure_deployment=self.azure_deployment,
+                azure_ad_token=self.azure_ad_token,
+                azure_ad_token_provider=self.azure_ad_token_provider,
+            ),
+            takes_self=True,
+        )
+    )
     tokenizer: OpenAiTokenizer = field(
         default=Factory(
             lambda self: OpenAiTokenizer(model=self.model), takes_self=True
@@ -26,6 +44,27 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
     )

     def _base_params(self, prompt_stack: PromptStack) -> dict:
-        return super()._base_params(prompt_stack) | {
-            "deployment_id": self.deployment_id
+        params = {
+            "model": self.model,
+            "temperature": self.temperature,
+            "stop": self.tokenizer.stop_sequences,
+            "user": self.user,
         }
+
+        if self.response_format == "json_object":
+            params["response_format"] = {"type": "json_object"}
+            # JSON mode still requires a system input instructing the LLM to output JSON.
+            prompt_stack.add_system_input(
+                "Provide your response as a valid JSON object."
+            )
+
+        messages = self._prompt_stack_to_messages(prompt_stack)
+
+        # A max_tokens parameter is not required, but if it is specified by the caller, bound it to
+        # the maximum value as determined by the tokenizer and pass it to the API.
+        if self.max_tokens:
+            params["max_tokens"] = self.max_output_tokens(messages)
+
+        params["messages"] = messages
+
+        return params
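For reference, the params assembled in `_base_params` map onto a v1 `chat.completions.create` call roughly as sketched below. The endpoint, deployment, key, and API version are placeholders, and JSON mode additionally needs a model and API version that support `response_format`:

import openai

# Placeholder resource values; substitute your own endpoint, deployment, and key.
client = openai.AzureOpenAI(
    azure_endpoint="https://my-resource.openai.azure.com",
    azure_deployment="my-gpt-35-turbo-deployment",
    api_key="placeholder",
    api_version="2023-05-15",
)

completion = client.chat.completions.create(
    model="my-gpt-35-turbo-deployment",
    temperature=0.1,
    user="example-user",
    response_format={"type": "json_object"},
    messages=[
        # JSON mode requires a message instructing the model to emit JSON,
        # which is why the driver injects a system input above.
        {"role": "system", "content": "Provide your response as a valid JSON object."},
        {"role": "user", "content": "List three colors."},
    ],
)

print(completion.choices[0].message.content)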

griptape/drivers/prompt/azure_openai_completion_prompt_driver.py (+16, -15)

@@ -1,24 +1,25 @@
+from typing import Optional
 from attr import define, field, Factory
-from griptape.utils import PromptStack
 from griptape.drivers import OpenAiCompletionPromptDriver
-from griptape.tokenizers import OpenAiTokenizer
+import openai


 @define
 class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
-    api_base: str = field(kw_only=True)
-    model: str = field(kw_only=True)
-    deployment_id: str = field(kw_only=True)
-    api_type: str = field(default="azure", kw_only=True)
+    azure_deployment: str = field(kw_only=True)
+    azure_endpoint: str = field(kw_only=True)
+    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
+    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
     api_version: str = field(default="2023-05-15", kw_only=True)
-    tokenizer: OpenAiTokenizer = field(
+    client: openai.AzureOpenAI = field(
         default=Factory(
-            lambda self: OpenAiTokenizer(model=self.model), takes_self=True
-        ),
-        kw_only=True,
+            lambda self: openai.AzureOpenAI(
+                organization=self.organization,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.azure_endpoint,
+                azure_deployment=self.azure_deployment,
+            ),
+            takes_self=True,
+        )
     )
-
-    def _base_params(self, prompt_stack: PromptStack) -> dict:
-        return super()._base_params(prompt_stack) | {
-            "deployment_id": self.deployment_id
-        }
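The completion driver gets the same client wiring, minus the chat-specific options. Below is a minimal sketch of a legacy (non-chat) completions call through the v1 `openai.AzureOpenAI` client, with placeholder resource values; the parent `OpenAiCompletionPromptDriver` (not shown in this excerpt) is presumably what consumes `self.client`:

import openai

# Placeholder resource values for illustration only.
client = openai.AzureOpenAI(
    azure_endpoint="https://my-resource.openai.azure.com",
    azure_deployment="my-instruct-deployment",
    api_key="placeholder",
    api_version="2023-05-15",
)

# The legacy (non-chat) completions endpoint in openai-python >= 1.0.
response = client.completions.create(
    model="my-instruct-deployment",
    prompt="Say hello in one word:",
    max_tokens=5,
    temperature=0,
)

print(response.choices[0].text)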

griptape/drivers/prompt/base_prompt_driver.py (+1, -1)

@@ -32,7 +32,7 @@ class BasePromptDriver(ExponentialBackoffMixin, ABC):
     tokenizer: BaseTokenizer
     stream: bool = field(default=False, kw_only=True)

-    def max_output_tokens(self, text: str) -> int:
+    def max_output_tokens(self, text: str | list) -> int:
         tokens_left = self.tokenizer.count_tokens_left(text)

         if self.max_tokens:
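The widened annotation reflects that chat drivers now pass the assembled message list, not a flat string, when bounding `max_tokens`. A small sketch, assuming griptape's `OpenAiTokenizer`, whose token counting is expected to handle chat message lists as well as plain strings:

from griptape.tokenizers import OpenAiTokenizer

tokenizer = OpenAiTokenizer(model="gpt-3.5-turbo")

# Both call shapes should now be countable, matching the str | list annotation.
print(tokenizer.count_tokens_left("Summarize this document."))
print(
    tokenizer.count_tokens_left(
        [{"role": "user", "content": "Summarize this document."}]
    )
)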
