Skip to content

Commit 95213e8

Browse files
committed
Remove init=False, improve comments
1 parent bac8d64 commit 95213e8

8 files changed

+80
-47
lines changed

griptape/drivers/embedding/azure_openai_embedding_driver.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from typing import Optional
34
from attr import define, field, Factory
45
from griptape.drivers import OpenAiEmbeddingDriver
56
from griptape.tokenizers import OpenAiTokenizer
@@ -10,29 +11,36 @@
1011
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
1112
"""
1213
Attributes:
13-
model: OpenAI embedding model name.
14-
deployment_id: Azure OpenAI deployment ID.
15-
api_base: API URL.
16-
api_type: OpenAI API type. Defaults to 'azure'.
17-
api_version: API version. Defaults to '2023-05-15'.
18-
tokenizer: Optionally provide custom `OpenAiTokenizer`.
14+
azure_deployment: An Azure OpenAi deployment id.
15+
azure_endpoint: An Azure OpenAi endpoint.
16+
azure_ad_token: An optional Azure Active Directory token.
17+
azure_ad_token_provider: An optional Azure Active Directory token provider.
18+
api_version: An Azure OpenAi API version.
19+
tokenizer: An `OpenAiTokenizer`.
20+
client: An `openai.AzureOpenAI` client.
1921
"""
2022

21-
deployment_id: str = field(kw_only=True)
22-
api_base: str = field(kw_only=True)
23-
api_type: str = field(default="azure", kw_only=True)
23+
azure_deployment: str = field(kw_only=True)
24+
azure_endpoint: str = field(kw_only=True)
25+
azure_ad_token: Optional[str] = field(kw_only=True, default=None)
26+
azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
2427
api_version: str = field(default="2023-05-15", kw_only=True)
2528
tokenizer: OpenAiTokenizer = field(
2629
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
2730
)
2831
client: openai.AzureOpenAI = field(
29-
init=False,
3032
default=Factory(
3133
lambda self: openai.AzureOpenAI(
32-
api_key=self.api_key, base_url=self.base_url, organization=self.organization
34+
organization=self.organization,
35+
api_key=self.api_key,
36+
api_version=self.api_version,
37+
azure_endpoint=self.azure_endpoint,
38+
azure_deployment=self.azure_deployment,
39+
azure_ad_token=self.azure_ad_token,
40+
azure_ad_token_provider=self.azure_ad_token_provider,
3341
),
3442
takes_self=True,
35-
),
43+
)
3644
)
3745

3846
def _params(self, chunk: str) -> dict:

griptape/drivers/embedding/openai_embedding_driver.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
1616
api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
1717
organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
1818
tokenizer: Optionally provide custom `OpenAiTokenizer`.
19+
client: Optionally provide custom `openai.OpenAI` client.
20+
azure_deployment: An Azure OpenAi deployment id.
21+
azure_endpoint: An Azure OpenAi endpoint.
22+
azure_ad_token: An optional Azure Active Directory token.
23+
azure_ad_token_provider: An optional Azure Active Directory token provider.
24+
api_version: An Azure OpenAi API version.
1925
"""
2026

2127
DEFAULT_MODEL = "text-embedding-ada-002"
@@ -27,11 +33,10 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
2733
api_key: Optional[str] = field(default=None, kw_only=True)
2834
organization: Optional[str] = field(default=None, kw_only=True)
2935
client: openai.OpenAI = field(
30-
init=False,
3136
default=Factory(
3237
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
3338
takes_self=True,
34-
),
39+
)
3540
)
3641
tokenizer: OpenAiTokenizer = field(
3742
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True

griptape/drivers/prompt/azure_openai_chat_prompt_driver.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@
22
from typing import Optional
33
from griptape.utils import PromptStack
44
from griptape.drivers import OpenAiChatPromptDriver
5-
from griptape.tokenizers import OpenAiTokenizer
65
import openai
76

87

98
@define
109
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
1110
"""
1211
Attributes:
13-
azure_deployment: Azure deployment id.
14-
azure_endpoint: Azure endpoint.
15-
azure_ad_token: Azure Active Directory token.
16-
azure_ad_token_provider: Azure Active Directory token provider.
17-
api_version: API version.
12+
azure_deployment: An Azure OpenAi deployment id.
13+
azure_endpoint: An Azure OpenAi endpoint.
14+
azure_ad_token: An optional Azure Active Directory token.
15+
azure_ad_token_provider: An optional Azure Active Directory token provider.
16+
api_version: An Azure OpenAi API version.
17+
client: An `openai.AzureOpenAI` client.
1818
"""
1919

2020
azure_deployment: str = field(kw_only=True)
@@ -36,9 +36,6 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
3636
takes_self=True,
3737
)
3838
)
39-
tokenizer: OpenAiTokenizer = field(
40-
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
41-
)
4239

4340
def _base_params(self, prompt_stack: PromptStack) -> dict:
4441
params = super()._base_params(prompt_stack)

griptape/drivers/prompt/azure_openai_completion_prompt_driver.py

+10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@
66

77
@define
88
class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
9+
"""
10+
Attributes:
11+
azure_deployment: An Azure OpenAi deployment id.
12+
azure_endpoint: An Azure OpenAi endpoint.
13+
azure_ad_token: An optional Azure Active Directory token.
14+
azure_ad_token_provider: An optional Azure Active Directory token provider.
15+
api_version: An Azure OpenAi API version.
16+
client: An `openai.AzureOpenAI` client.
17+
"""
18+
919
azure_deployment: str = field(kw_only=True)
1020
azure_endpoint: str = field(kw_only=True)
1121
azure_ad_token: Optional[str] = field(kw_only=True, default=None)

griptape/drivers/prompt/base_prompt_driver.py

+13
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,19 @@
1414

1515
@define
1616
class BasePromptDriver(ExponentialBackoffMixin, ABC):
17+
"""Base class for Prompt Drivers.
18+
19+
Attributes:
20+
temperature: The temperature to use for the completion.
21+
max_tokens: The maximum number of tokens to generate. If not specified, the value will be automatically generated based by the tokenizer.
22+
structure: An optional `Structure` to publish events to.
23+
prompt_stack_to_string: A function that converts a `PromptStack` to a string.
24+
ignored_exception_types: A tuple of exception types to ignore.
25+
model: The model name.
26+
tokenizer: An instance of `BaseTokenizer` to when calculating tokens.
27+
stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
28+
"""
29+
1730
temperature: float = field(default=0.1, kw_only=True)
1831
max_tokens: Optional[int] = field(default=None, kw_only=True)
1932
structure: Optional[Structure] = field(default=None, kw_only=True)

griptape/drivers/prompt/openai_chat_prompt_driver.py

+13-14
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,16 @@
1515
class OpenAiChatPromptDriver(BasePromptDriver):
1616
"""
1717
Attributes:
18-
base_url: API URL.
19-
api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
20-
max_tokens: Optional maximum return tokens. If not specified, no value will be passed to the API. If set, the
21-
value will be bounded to the maximum possible as determined by the tokenizer.
22-
model: OpenAI model name.
23-
organization: OpenAI organization. Defaults to `OPENAI_ORG_ID` environment variable.
24-
client: OpenAI client. Defaults to `openai.OpenAI`.
25-
tokenizer: Custom `OpenAiTokenizer`.
26-
user: OpenAI user.
27-
response_format: Optional response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
28-
seed: Optional seed.
18+
base_url: An optional OpenAi API URL.
19+
api_key: An optional OpenAi API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
20+
organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
21+
client: An `openai.OpenAI` client.
22+
model: An OpenAI model name.
23+
tokenizer: An `OpenAiTokenizer`.
24+
user: An optional user id. Can be used to track requests by user.
25+
response_format: An optional OpenAi Chat Completion response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
26+
seed: An optional OpenAi Chat Completion seed.
27+
ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
2928
_ratelimit_request_limit: The maximum number of requests allowed in the current rate limit window.
3029
_ratelimit_requests_remaining: The number of requests remaining in the current rate limit window.
3130
_ratelimit_requests_reset_at: The time at which the current rate limit window resets.
@@ -34,10 +33,9 @@ class OpenAiChatPromptDriver(BasePromptDriver):
3433
_ratelimit_tokens_reset_at: The time at which the current rate limit window resets.
3534
"""
3635

37-
base_url: str = field(default=None, kw_only=True)
36+
base_url: Optional[str] = field(default=None, kw_only=True)
3837
api_key: Optional[str] = field(default=None, kw_only=True)
3938
organization: Optional[str] = field(default=None, kw_only=True)
40-
seed: Optional[int] = field(default=None, kw_only=True)
4139
client: openai.OpenAI = field(
4240
default=Factory(
4341
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
@@ -48,8 +46,9 @@ class OpenAiChatPromptDriver(BasePromptDriver):
4846
tokenizer: OpenAiTokenizer = field(
4947
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
5048
)
51-
user: str = field(default="", kw_only=True)
49+
user: Optional[str] = field(default=None, kw_only=True)
5250
response_format: Optional[Literal["json_object"]] = field(default=None, kw_only=True)
51+
seed: Optional[int] = field(default=None, kw_only=True)
5352
ignored_exception_types: Tuple[Type[Exception], ...] = field(
5453
default=Factory(
5554
lambda: (

griptape/drivers/prompt/openai_completion_prompt_driver.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
class OpenAiCompletionPromptDriver(BasePromptDriver):
1313
"""
1414
Attributes:
15-
base_url: API URL.
16-
api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
17-
max_tokens: Optional maximum return tokens. If not specified, the value will be automatically generated based by the tokenizer.
18-
model: OpenAI model name. Uses `gpt-4` by default.
19-
organization: OpenAI organization. Defaults to `OPENAI_ORG_ID` environment variable.
20-
client: OpenAI client. Defaults to `openai.OpenAI`.
21-
tokenizer: Custom `OpenAiTokenizer`.
22-
user: OpenAI user.
15+
base_url: An optional OpenAi API URL.
16+
api_key: An optional OpenAi API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
17+
organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
18+
client: An `openai.OpenAI` client.
19+
model: An OpenAI model name.
20+
tokenizer: An `OpenAiTokenizer`.
21+
user: An optional user id. Can be used to track requests by user.
22+
ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
2323
"""
2424

25-
base_url: str = field(default=None, kw_only=True)
25+
base_url: Optional[str] = field(default=None, kw_only=True)
2626
api_key: Optional[str] = field(default=None, kw_only=True)
2727
organization: Optional[str] = field(default=None, kw_only=True)
2828
client: openai.OpenAI = field(
@@ -35,7 +35,7 @@ class OpenAiCompletionPromptDriver(BasePromptDriver):
3535
tokenizer: OpenAiTokenizer = field(
3636
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
3737
)
38-
user: str = field(default="", kw_only=True)
38+
user: Optional[str] = field(default=None, kw_only=True)
3939
ignored_exception_types: Tuple[Type[Exception], ...] = field(
4040
default=Factory(
4141
lambda: (

griptape/tokenizers/openai_tokenizer.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class OpenAiTokenizer(BaseTokenizer):
1515
DEFAULT_MAX_TOKENS = 2049
1616
TOKEN_OFFSET = 8
1717

18+
# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
1819
MODEL_PREFIXES_TO_MAX_TOKENS = {
1920
"gpt-4-1106": 128000,
2021
"gpt-4-32k": 32768,

0 commit comments

Comments
 (0)