
Commit b2c7ba5

Update openai sdk

Migrates griptape's OpenAI and Azure OpenAI drivers from the pre-1.0 module-level openai API to the 1.x client-object API (openai.OpenAI / openai.AzureOpenAI).

1 parent 658a3d3 commit b2c7ba5

19 files changed: +1338 −1218 lines (six of the changed files are excerpted below)

griptape/drivers/embedding/azure_openai_embedding_driver.py (+12)

@@ -3,6 +3,7 @@
 from attr import define, field, Factory
 from griptape.drivers import OpenAiEmbeddingDriver
 from griptape.tokenizers import OpenAiTokenizer
+import openai
 
 
 @define
@@ -28,6 +29,17 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
         ),
         kw_only=True,
     )
+    client: openai.AzureOpenAI = field(
+        init=False,
+        default=Factory(
+            lambda self: openai.AzureOpenAI(
+                api_key=self.api_key,
+                base_url=self.base_url,
+                organization=self.organization,
+            ),
+            takes_self=True,
+        ),
+    )
 
     def _params(self, chunk: list[int] | str) -> dict:
         return super()._params(chunk) | {"deployment_id": self.deployment_id}
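
The pattern to note here, repeated throughout this commit, is attrs' Factory(..., takes_self=True): the field default is built from the instance's other fields, so each driver owns a pre-configured client instead of mutating module-level openai.* globals. A minimal, runnable sketch of that mechanism (the Greeter class is illustrative, not part of the commit):

from attr import define, field, Factory

@define
class Greeter:
    name: str = field(kw_only=True)
    # The Factory lambda receives the partially initialized instance as `self`,
    # so the default can read fields declared above it.
    greeting: str = field(
        init=False,
        default=Factory(lambda self: f"Hello, {self.name}!", takes_self=True),
    )

print(Greeter(name="world").greeting)  # -> Hello, world!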
griptape/drivers/embedding/openai_embedding_driver.py

@@ -1,10 +1,10 @@
 from __future__ import annotations
 import os
 from typing import Optional
-import openai
 from attr import define, field, Factory
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.tokenizers import OpenAiTokenizer
+import openai
 
 
 @define
@@ -26,14 +26,21 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
 
     model: str = field(default=DEFAULT_MODEL, kw_only=True)
     dimensions: int = field(default=DEFAULT_DIMENSIONS, kw_only=True)
-    api_type: str = field(default=openai.api_type, kw_only=True)
-    api_version: Optional[str] = field(default=openai.api_version, kw_only=True)
-    api_base: str = field(default=openai.api_base, kw_only=True)
+    base_url: str = field(default=None, kw_only=True)
     api_key: Optional[str] = field(
         default=Factory(lambda: os.environ.get("OPENAI_API_KEY")), kw_only=True
     )
-    organization: Optional[str] = field(
-        default=openai.organization, kw_only=True
+    organization: Optional[str] = field(default=None, kw_only=True)
+    client: openai.OpenAI = field(
+        init=False,
+        default=Factory(
+            lambda self: openai.OpenAI(
+                api_key=self.api_key,
+                base_url=self.base_url,
+                organization=self.organization,
+            ),
+            takes_self=True,
+        ),
     )
     tokenizer: OpenAiTokenizer = field(
         default=Factory(
@@ -42,29 +49,16 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
         kw_only=True,
     )
 
-    def __attrs_post_init__(self) -> None:
-        openai.api_type = self.api_type
-        openai.api_version = self.api_version
-        openai.api_base = self.api_base
-        openai.api_key = self.api_key
-        openai.organization = self.organization
-
     def try_embed_chunk(self, chunk: str) -> list[float]:
         # Address a performance issue in older ada models
         # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
         if self.model.endswith("001"):
             chunk = chunk.replace("\n", " ")
-        return openai.Embedding.create(**self._params(chunk))["data"][0][
-            "embedding"
-        ]
+        return (
+            self.client.embeddings.create(**self._params(chunk))
+            .data[0]
+            .embedding
+        )
 
     def _params(self, chunk: str) -> dict:
-        return {
-            "input": chunk,
-            "model": self.model,
-            "api_key": self.api_key,
-            "organization": self.organization,
-            "api_version": self.api_version,
-            "api_base": self.api_base,
-            "api_type": self.api_type,
-        }
+        return {"input": chunk, "model": self.model}

griptape/drivers/prompt/azure_openai_chat_prompt_driver.py (+24 −11)

@@ -1,31 +1,44 @@
+import os
 from attr import define, field, Factory
-from griptape.utils import PromptStack
+from typing import Optional
 from griptape.drivers import OpenAiChatPromptDriver
 from griptape.tokenizers import OpenAiTokenizer
+import openai
 
 
 @define
 class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
     """
     Attributes:
-        api_base: API URL.
+        api_version: API version.
+        azure_deployment: Azure deployment id.
+        azure_endpoint: Azure endpoint.
         deployment_id: Azure OpenAI deployment ID.
         model: OpenAI model name.
     """
 
-    api_base: str = field(kw_only=True)
-    model: str = field(kw_only=True)
-    deployment_id: str = field(kw_only=True)
-    api_type: str = field(default="azure", kw_only=True)
+    azure_deployment: str = field(kw_only=True)
+    azure_endpoint: str = field(kw_only=True)
     api_version: str = field(default="2023-05-15", kw_only=True)
+    api_key: Optional[str] = field(
+        default=Factory(lambda: os.environ.get("AZURE_OPENAI_API_KEY")),
+        kw_only=True,
+    )
+    client: openai.AzureOpenAI = field(
+        default=Factory(
+            lambda self: openai.AzureOpenAI(
+                organization=self.organization,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.azure_endpoint,
+                azure_deployment=self.azure_deployment,
+            ),
+            takes_self=True,
+        )
+    )
     tokenizer: OpenAiTokenizer = field(
         default=Factory(
            lambda self: OpenAiTokenizer(model=self.model), takes_self=True
         ),
         kw_only=True,
     )
-
-    def _base_params(self, prompt_stack: PromptStack) -> dict:
-        return super()._base_params(prompt_stack) | {
-            "deployment_id": self.deployment_id
-        }
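
The Azure drivers swap per-request routing (api_type="azure" plus a deployment_id merged into _base_params) for a dedicated openai.AzureOpenAI client that carries the endpoint, deployment, and API version. A hedged sketch of constructing and calling that client (endpoint, deployment, and model names are placeholders):

import openai

client = openai.AzureOpenAI(
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
    azure_deployment="my-gpt4-deployment",                  # placeholder
    api_version="2023-05-15",
)  # api_key falls back to the AZURE_OPENAI_API_KEY environment variable

response = client.chat.completions.create(
    model="gpt-4",  # with azure_deployment set on the client, routing uses the deployment
    messages=[{"role": "user", "content": "hi"}],
)
print(response.choices[0].message.content)

One quirk worth flagging: the updated docstring still lists deployment_id, even though that field no longer exists on this class after the change.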
griptape/drivers/prompt/azure_openai_completion_prompt_driver.py

@@ -1,24 +1,28 @@
+import os
 from attr import define, field, Factory
-from griptape.utils import PromptStack
+from typing import Optional
 from griptape.drivers import OpenAiCompletionPromptDriver
-from griptape.tokenizers import OpenAiTokenizer
+import openai
 
 
 @define
 class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
-    api_base: str = field(kw_only=True)
-    model: str = field(kw_only=True)
-    deployment_id: str = field(kw_only=True)
-    api_type: str = field(default="azure", kw_only=True)
+    azure_deployment: str = field(kw_only=True)
+    azure_endpoint: str = field(kw_only=True)
     api_version: str = field(default="2023-05-15", kw_only=True)
-    tokenizer: OpenAiTokenizer = field(
-        default=Factory(
-            lambda self: OpenAiTokenizer(model=self.model), takes_self=True
-        ),
+    api_key: Optional[str] = field(
+        default=Factory(lambda: os.environ.get("AZURE_OPENAI_API_KEY")),
         kw_only=True,
     )
-
-    def _base_params(self, prompt_stack: PromptStack) -> dict:
-        return super()._base_params(prompt_stack) | {
-            "deployment_id": self.deployment_id
-        }
+    client: openai.AzureOpenAI = field(
+        default=Factory(
+            lambda self: openai.AzureOpenAI(
+                organization=self.organization,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                azure_endpoint=self.azure_endpoint,
+                azure_deployment=self.azure_deployment,
+            ),
+            takes_self=True,
+        )
+    )
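
Same client migration as the chat driver above, applied here to the legacy (non-chat) completions endpoint. A hedged usage sketch (deployment and model names are placeholders; assumes AZURE_OPENAI_API_KEY is set):

import openai

client = openai.AzureOpenAI(
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
    azure_deployment="my-instruct-deployment",              # placeholder
    api_version="2023-05-15",
)

response = client.completions.create(
    model="gpt-35-turbo-instruct",  # placeholder deployment/model name
    prompt="Say hi.",
    max_tokens=16,
)
print(response.choices[0].text)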

griptape/drivers/prompt/base_prompt_driver.py (+1 −1)

@@ -32,7 +32,7 @@ class BasePromptDriver(ExponentialBackoffMixin, ABC):
     tokenizer: BaseTokenizer
     stream: bool = field(default=False, kw_only=True)
 
-    def max_output_tokens(self, text: str) -> int:
+    def max_output_tokens(self, text: str | list) -> int:
         tokens_left = self.tokenizer.count_tokens_left(text)
 
         if self.max_tokens:
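
The widened text: str | list signature exists because the chat driver below now passes the OpenAI-style messages list straight into token counting instead of a flattened string. A hedged, self-contained illustration (the function body is a stand-in, not griptape's implementation):

prompt_as_string = "Say hi."
prompt_as_messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Say hi."},
]

def max_output_tokens(text: str | list) -> int:
    # Stand-in for the driver method: bound output by a fake context window.
    used = len(str(text)) // 4  # crude token estimate, illustration only
    return max(0, 4096 - used)

print(max_output_tokens(prompt_as_string), max_output_tokens(prompt_as_messages))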

griptape/drivers/prompt/openai_chat_prompt_driver.py (+46 −35)

@@ -1,6 +1,6 @@
 from __future__ import annotations
 import os
-from typing import Iterator, Optional
+from typing import Iterator, Optional, Any, Literal
 import openai
 from attr import define, field, Factory
 from griptape.artifacts import TextArtifact
@@ -10,23 +10,22 @@
 from typing import Tuple, Type
 import dateparser
 from datetime import datetime, timedelta
-import requests
 
 
 @define
 class OpenAiChatPromptDriver(BasePromptDriver):
     """
     Attributes:
-        api_type: Can be changed to use OpenAI models on Azure.
-        api_version: API version.
-        api_base: API URL.
+        base_url: API URL.
         api_key: API key to pass directly; by default uses `OPENAI_API_KEY_PATH` environment variable.
         max_tokens: Optional maximum return tokens. If not specified, no value will be passed to the API. If set, the
             value will be bounded to the maximum possible as determined by the tokenizer.
         model: OpenAI model name. Uses `gpt-4` by default.
         organization: OpenAI organization.
         tokenizer: Custom `OpenAiTokenizer`.
         user: OpenAI user.
+        response_format: Optional response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
+        seed: Optional seed.
         _ratelimit_request_limit: The maximum number of requests allowed in the current rate limit window.
         _ratelimit_requests_remaining: The number of requests remaining in the current rate limit window.
         _ratelimit_requests_reset_at: The time at which the current rate limit window resets.
@@ -35,14 +34,23 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         _ratelimit_tokens_reset_at: The time at which the current rate limit window resets.
     """
 
-    api_type: str = field(default=openai.api_type, kw_only=True)
-    api_version: Optional[str] = field(default=openai.api_version, kw_only=True)
-    api_base: str = field(default=openai.api_base, kw_only=True)
+    base_url: str = field(default=None, kw_only=True)
     api_key: Optional[str] = field(
         default=Factory(lambda: os.environ.get("OPENAI_API_KEY")), kw_only=True
     )
     organization: Optional[str] = field(
-        default=openai.organization, kw_only=True
+        default=os.environ.get("OPENAI_ORG_ID"), kw_only=True
+    )
+    seed: Optional[int] = field(default=None, kw_only=True)
+    client: openai.OpenAI = field(
+        default=Factory(
+            lambda self: openai.OpenAI(
+                api_key=self.api_key,
+                base_url=self.base_url,
+                organization=self.organization,
+            ),
+            takes_self=True,
+        )
     )
     model: str = field(kw_only=True)
     tokenizer: OpenAiTokenizer = field(
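
One subtlety in this hunk: organization uses a plain default=os.environ.get("OPENAI_ORG_ID"), which is evaluated once when the class body is defined, whereas api_key wraps the lookup in Factory, which re-reads the environment on every instantiation. A minimal sketch of the difference (names are illustrative):

import os
from typing import Optional
from attr import define, field, Factory

@define
class Cfg:
    # Captured once, at class-definition (import) time:
    at_import: Optional[str] = field(default=os.environ.get("MY_VAR"), kw_only=True)
    # Re-evaluated for each new instance:
    at_init: Optional[str] = field(
        default=Factory(lambda: os.environ.get("MY_VAR")), kw_only=True
    )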
@@ -52,8 +60,11 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         kw_only=True,
     )
     user: str = field(default="", kw_only=True)
+    response_format: Optional[Literal["json_object"]] = field(
+        default=None, kw_only=True
+    )
     ignored_exception_types: Tuple[Type[Exception], ...] = field(
-        default=Factory(lambda: openai.InvalidRequestError), kw_only=True
+        default=Factory(lambda: openai.BadRequestError), kw_only=True
     )
     _ratelimit_request_limit: Optional[int] = field(init=False, default=None)
     _ratelimit_requests_remaining: Optional[int] = field(
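
The exception swap tracks the SDK rename: v0's openai.InvalidRequestError became openai.BadRequestError in 1.x, and it remains the "do not retry" signal for the backoff mixin. A hedged sketch of how such a non-retryable set is typically consulted (the helper is illustrative, not griptape's code):

import openai

NON_RETRYABLE = (openai.BadRequestError,)

def should_retry(exc: Exception) -> bool:
    # Malformed requests will fail identically on retry, so skip them.
    return not isinstance(exc, NON_RETRYABLE)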
@@ -68,40 +79,36 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         init=False, default=None
     )
 
-    def __attrs_post_init__(self) -> None:
-        # Define a hook to pull rate limit metadata from the OpenAI API response header.
-        openai.requestssession = requests.Session()
-        openai.requestssession.hooks = {
-            "response": self._extract_ratelimit_metadata
-        }
-
     def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
-        result = openai.ChatCompletion.create(**self._base_params(prompt_stack))
+        result = self.client.chat.completions.with_raw_response.create(
+            **self._base_params(prompt_stack)
+        )
+
+        self._extract_ratelimit_metadata(result)
 
+        result = result.parse()
         if len(result.choices) == 1:
-            return TextArtifact(
-                value=result.choices[0]["message"]["content"].strip()
-            )
+            return TextArtifact(value=result.choices[0].message.content.strip())
         else:
             raise Exception(
                 "Completion with more than one choice is not supported yet."
             )
 
     def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
-        result = openai.ChatCompletion.create(
+        result = self.client.chat.completions.create(
             **self._base_params(prompt_stack), stream=True
         )
 
         for chunk in result:
             if len(chunk.choices) == 1:
-                delta = chunk.choices[0]["delta"]
+                delta = chunk.choices[0].delta
             else:
                 raise Exception(
                     "Completion with more than one choice is not supported yet."
                 )
 
-            if "content" in delta:
-                delta_content = delta["content"]
+            if delta.content is not None:
+                delta_content = delta.content
 
                 yield TextArtifact(value=delta_content)
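
with_raw_response is the 1.x replacement for the removed global requests hook: it returns the raw HTTP response, whose headers carry the rate-limit metadata, and .parse() then yields the typed completion. A hedged standalone sketch (model and message are placeholders; assumes OPENAI_API_KEY is set):

import openai

client = openai.OpenAI()
raw = client.chat.completions.with_raw_response.create(
    model="gpt-4", messages=[{"role": "user", "content": "hi"}]
)
print(raw.headers.get("x-ratelimit-remaining-requests"))  # header, may be absent
completion = raw.parse()  # typed ChatCompletion object
print(completion.choices[0].message.content)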

@@ -112,33 +119,37 @@ def token_count(self, prompt_stack: PromptStack) -> int:
 
     def _prompt_stack_to_messages(
         self, prompt_stack: PromptStack
-    ) -> list[dict]:
+    ) -> list[dict[str, Any]]:
         return [
             {"role": self.__to_openai_role(i), "content": i.content}
             for i in prompt_stack.inputs
         ]
 
     def _base_params(self, prompt_stack: PromptStack) -> dict:
-        messages = self._prompt_stack_to_messages(prompt_stack)
-
         params = {
             "model": self.model,
             "temperature": self.temperature,
             "stop": self.tokenizer.stop_sequences,
             "user": self.user,
-            "api_key": self.api_key,
-            "organization": self.organization,
-            "api_version": self.api_version,
-            "api_base": self.api_base,
-            "api_type": self.api_type,
-            "messages": messages,
+            "seed": self.seed,
         }
 
+        if self.response_format == "json_object":
+            params["response_format"] = {"type": "json_object"}
+            # JSON mode still requires a system input instructing the LLM to output JSON.
+            prompt_stack.add_system_input(
+                "Provide your response as valid JSON."
+            )
+
+        messages = self._prompt_stack_to_messages(prompt_stack)
+
         # A max_tokens parameter is not required, but if it is specified by the caller, bound it to
         # the maximum value as determined by the tokenizer and pass it to the API.
         if self.max_tokens:
             params["max_tokens"] = self.max_output_tokens(messages)
 
+        params["messages"] = messages
+
         return params
 
     def __to_openai_role(self, prompt_input: PromptStack.Input) -> str:
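
The response_format branch wires up OpenAI's JSON mode, which requires both the {"type": "json_object"} parameter and an instruction in the messages to produce JSON; hence the injected system input. A hedged direct-API sketch (model and prompt are placeholders):

import json
import openai

client = openai.OpenAI()
completion = client.chat.completions.create(
    model="gpt-4-1106-preview",  # placeholder; must be a JSON-mode-capable model
    response_format={"type": "json_object"},
    messages=[
        {"role": "system", "content": "Provide your response as valid JSON."},
        {"role": "user", "content": "List three primes."},
    ],
)
print(json.loads(completion.choices[0].message.content))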
@@ -149,7 +160,7 @@ def __to_openai_role(self, prompt_input: PromptStack.Input) -> str:
         else:
             return "user"
 
-    def _extract_ratelimit_metadata(self, response, *args, **kwargs):
+    def _extract_ratelimit_metadata(self, response):
         # The OpenAI SDK's requestssession variable is global, so this hook will fire for all API requests.
         # The following headers are not reliably returned in every API call, so we check for the presence of the
         # headers before reading and parsing their values to prevent other SDK users from encountering KeyErrors.
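
Note that the first context comment above still describes the old global-session hook, which this commit deletes; only the "check before parsing" rationale still applies, since OpenAI does not return rate-limit headers on every response. A hedged sketch of that guarded style (header values are illustrative):

headers = {
    "x-ratelimit-limit-requests": "60",
    "x-ratelimit-remaining-requests": "59",
}

# Read each header only if present, leaving the metadata unset otherwise,
# rather than risking a KeyError on responses that omit it.
if "x-ratelimit-remaining-requests" in headers:
    remaining = int(headers["x-ratelimit-remaining-requests"])
else:
    remaining = None
print(remaining)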
