@@ -1,6 +1,6 @@
 from __future__ import annotations
 import os
-from typing import Iterator, Optional
+from typing import Iterator, Optional, Any, Literal
 import openai
 from attr import define, field, Factory
 from griptape.artifacts import TextArtifact
@@ -10,23 +10,22 @@
 from typing import Tuple, Type
 import dateparser
 from datetime import datetime, timedelta
-import requests


 @define
 class OpenAiChatPromptDriver(BasePromptDriver):
     """
     Attributes:
-        api_type: Can be changed to use OpenAI models on Azure.
-        api_version: API version.
-        api_base: API URL.
+        base_url: API URL.
         api_key: API key to pass directly; by default uses the `OPENAI_API_KEY` environment variable.
         max_tokens: Optional maximum return tokens. If not specified, no value will be passed to the API. If set, the
             value will be bounded to the maximum possible as determined by the tokenizer.
         model: OpenAI model name. Uses `gpt-4` by default.
         organization: OpenAI organization.
         tokenizer: Custom `OpenAiTokenizer`.
         user: OpenAI user.
+        response_format: Optional response format. Currently only supports `json_object`, which enables OpenAI's JSON mode.
+        seed: Optional seed.
         _ratelimit_request_limit: The maximum number of requests allowed in the current rate limit window.
         _ratelimit_requests_remaining: The number of requests remaining in the current rate limit window.
         _ratelimit_requests_reset_at: The time at which the current rate limit window resets.
@@ -35,14 +34,23 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         _ratelimit_tokens_reset_at: The time at which the current rate limit window resets.
     """

-    api_type: str = field(default=openai.api_type, kw_only=True)
-    api_version: Optional[str] = field(default=openai.api_version, kw_only=True)
-    api_base: str = field(default=openai.api_base, kw_only=True)
+    base_url: Optional[str] = field(default=None, kw_only=True)
     api_key: Optional[str] = field(
         default=Factory(lambda: os.environ.get("OPENAI_API_KEY")), kw_only=True
     )
     organization: Optional[str] = field(
-        default=openai.organization, kw_only=True
+        default=os.environ.get("OPENAI_ORG_ID"), kw_only=True
+    )
+    seed: Optional[int] = field(default=None, kw_only=True)
+    client: openai.OpenAI = field(
+        default=Factory(
+            lambda self: openai.OpenAI(
+                api_key=self.api_key,
+                base_url=self.base_url,
+                organization=self.organization,
+            ),
+            takes_self=True,
+        )
     )
     model: str = field(kw_only=True)
     tokenizer: OpenAiTokenizer = field(
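The new `client` field uses an attrs idiom worth calling out: `Factory(..., takes_self=True)` defers construction until the instance exists, so the lambda can read sibling fields such as `api_key` and `base_url`. A tiny self-contained sketch of the same pattern (class and field names are illustrative, not from the diff):

```python
from attr import Factory, define, field


@define
class Greeter:
    name: str = field(kw_only=True)
    # takes_self=True passes the partially built instance to the lambda,
    # so the default can be derived from fields declared earlier.
    greeting: str = field(
        default=Factory(lambda self: f"Hello, {self.name}!", takes_self=True),
        kw_only=True,
    )


print(Greeter(name="world").greeting)  # -> Hello, world!
```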
@@ -52,8 +60,11 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         kw_only=True,
     )
     user: str = field(default="", kw_only=True)
+    response_format: Optional[Literal["json_object"]] = field(
+        default=None, kw_only=True
+    )
     ignored_exception_types: Tuple[Type[Exception], ...] = field(
-        default=Factory(lambda: openai.InvalidRequestError), kw_only=True
+        default=Factory(lambda: openai.BadRequestError), kw_only=True
     )
     _ratelimit_request_limit: Optional[int] = field(init=False, default=None)
     _ratelimit_requests_remaining: Optional[int] = field(
@@ -68,40 +79,36 @@ class OpenAiChatPromptDriver(BasePromptDriver):
         init=False, default=None
     )

-    def __attrs_post_init__(self) -> None:
-        # Define a hook to pull rate limit metadata from the OpenAI API response header.
-        openai.requestssession = requests.Session()
-        openai.requestssession.hooks = {
-            "response": self._extract_ratelimit_metadata
-        }
-
     def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
-        result = openai.ChatCompletion.create(**self._base_params(prompt_stack))
+        result = self.client.chat.completions.with_raw_response.create(
+            **self._base_params(prompt_stack)
+        )
+
+        self._extract_ratelimit_metadata(result)

+        result = result.parse()
         if len(result.choices) == 1:
-            return TextArtifact(
-                value=result.choices[0]["message"]["content"].strip()
-            )
+            return TextArtifact(value=result.choices[0].message.content.strip())
         else:
             raise Exception(
                 "Completion with more than one choice is not supported yet."
             )

     def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
-        result = openai.ChatCompletion.create(
+        result = self.client.chat.completions.create(
             **self._base_params(prompt_stack), stream=True
         )

         for chunk in result:
             if len(chunk.choices) == 1:
-                delta = chunk.choices[0]["delta"]
+                delta = chunk.choices[0].delta
             else:
                 raise Exception(
                     "Completion with more than one choice is not supported yet."
                 )

-            if "content" in delta:
-                delta_content = delta["content"]
+            if delta.content is not None:
+                delta_content = delta.content

                 yield TextArtifact(value=delta_content)

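For context on the new `try_run` flow: in the v1 `openai` SDK, `with_raw_response` exposes the HTTP response, including the rate limit headers this driver records, alongside the parsed body. A minimal standalone sketch of that pattern, assuming `OPENAI_API_KEY` is set in the environment (the model name and prompt are placeholders):

```python
import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

# with_raw_response returns the raw HTTP response; .parse() yields the
# usual ChatCompletion object afterwards.
raw = client.chat.completions.with_raw_response.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Say hello."}],
)

# Rate limit metadata lives in the response headers; the headers are not
# guaranteed on every call, so .get() avoids KeyErrors.
print(raw.headers.get("x-ratelimit-remaining-requests"))
print(raw.headers.get("x-ratelimit-reset-tokens"))

completion = raw.parse()
print(completion.choices[0].message.content)
```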
@@ -112,33 +119,37 @@ def token_count(self, prompt_stack: PromptStack) -> int:

     def _prompt_stack_to_messages(
         self, prompt_stack: PromptStack
-    ) -> list[dict]:
+    ) -> list[dict[str, Any]]:
         return [
             {"role": self.__to_openai_role(i), "content": i.content}
             for i in prompt_stack.inputs
         ]

     def _base_params(self, prompt_stack: PromptStack) -> dict:
-        messages = self._prompt_stack_to_messages(prompt_stack)
-
         params = {
             "model": self.model,
             "temperature": self.temperature,
             "stop": self.tokenizer.stop_sequences,
             "user": self.user,
-            "api_key": self.api_key,
-            "organization": self.organization,
-            "api_version": self.api_version,
-            "api_base": self.api_base,
-            "api_type": self.api_type,
-            "messages": messages,
+            "seed": self.seed,
         }

+        if self.response_format == "json_object":
+            params["response_format"] = {"type": "json_object"}
+            # JSON mode still requires a system input instructing the LLM to output JSON.
+            prompt_stack.add_system_input(
+                "Provide your response as a valid JSON object."
+            )
+
+        messages = self._prompt_stack_to_messages(prompt_stack)
+
         # A max_tokens parameter is not required, but if it is specified by the caller, bound it to
         # the maximum value as determined by the tokenizer and pass it to the API.
         if self.max_tokens:
             params["max_tokens"] = self.max_output_tokens(messages)

+        params["messages"] = messages
+
         return params

     def __to_openai_role(self, prompt_input: PromptStack.Input) -> str:
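The `response_format` handling in `_base_params` above follows OpenAI's JSON mode contract: the request must carry `response_format={"type": "json_object"}` and at least one message telling the model to emit JSON, which is why a system input is appended before the messages list is built. A minimal sketch against the raw SDK (the model name is an assumption; any JSON-mode-capable model works):

```python
import json

import openai

client = openai.OpenAI()

completion = client.chat.completions.create(
    model="gpt-4-1106-preview",  # assumed; substitute any model with JSON mode support
    seed=42,  # optional: best-effort deterministic sampling, as with the new seed field
    response_format={"type": "json_object"},
    messages=[
        # Without a message instructing the model to emit JSON, the API rejects the request.
        {"role": "system", "content": "Provide your response as a valid JSON object."},
        {"role": "user", "content": "List three primary colors."},
    ],
)

print(json.loads(completion.choices[0].message.content))
```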
@@ -149,7 +160,7 @@ def __to_openai_role(self, prompt_input: PromptStack.Input) -> str:
         else:
             return "user"

-    def _extract_ratelimit_metadata(self, response, *args, **kwargs):
+    def _extract_ratelimit_metadata(self, response):
         # The OpenAI SDK's requestssession variable is global, so this hook will fire for all API requests.
         # The following headers are not reliably returned in every API call, so we check for the presence of the
         # headers before reading and parsing their values to prevent other SDK users from encountering KeyErrors.
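Taken together, the reworked driver can be exercised end to end roughly like this; a sketch assuming the usual griptape import paths (`griptape.drivers`, `griptape.utils`) and the `PromptStack` helper methods seen in the diff:

```python
from griptape.drivers import OpenAiChatPromptDriver
from griptape.utils import PromptStack

# base_url stays None for the default OpenAI endpoint; set it to point the
# driver at a compatible proxy or alternative deployment.
driver = OpenAiChatPromptDriver(
    model="gpt-4-1106-preview",  # assumed JSON-mode-capable model
    response_format="json_object",
    seed=42,
)

stack = PromptStack()
stack.add_user_input("Return the capitals of France and Japan as JSON.")

# try_run sends the request via the attrs-built openai.OpenAI client,
# records rate limit headers, and returns a TextArtifact.
artifact = driver.try_run(stack)
print(artifact.value)
```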