1
1
from __future__ import annotations
2
- import os
3
2
from typing import Optional
4
- import openai
5
3
from attr import define , field , Factory
6
4
from griptape .drivers import BaseEmbeddingDriver
7
5
from griptape .tokenizers import OpenAiTokenizer
6
+ import openai
8
7
9
8
10
9
@define
@@ -13,49 +12,42 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
13
12
Attributes:
14
13
model: OpenAI embedding model name. Defaults to `text-embedding-ada-002`.
15
14
dimensions: Vector dimensions. Defaults to `1536`.
16
- api_type: OpenAI API type, for example 'open_ai' or 'azure'. Defaults to 'open_ai'.
17
- api_version: API version. Defaults to 'OPENAI_API_VERSION' environment variable.
18
- api_base: API URL. Defaults to OpenAI's v1 API URL.
15
+ base_url: API URL. Defaults to OpenAI's v1 API URL.
19
16
api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
20
17
organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
21
18
tokenizer: Optionally provide custom `OpenAiTokenizer`.
19
+ client: Optionally provide a custom `openai.OpenAI` client.
20
+ azure_deployment: An Azure OpenAI deployment ID.
21
+ azure_endpoint: An Azure OpenAI endpoint.
22
+ azure_ad_token: An optional Azure Active Directory token.
23
+ azure_ad_token_provider: An optional Azure Active Directory token provider.
24
+ api_version: An Azure OpenAI API version.
22
25
"""
23
26
24
27
DEFAULT_MODEL = "text-embedding-ada-002"
25
28
DEFAULT_DIMENSIONS = 1536
26
29
27
30
model : str = field (default = DEFAULT_MODEL , kw_only = True )
28
31
dimensions : int = field (default = DEFAULT_DIMENSIONS , kw_only = True )
29
- api_type : str = field (default = openai .api_type , kw_only = True )
30
- api_version : Optional [str ] = field (default = openai .api_version , kw_only = True )
31
- api_base : str = field (default = openai .api_base , kw_only = True )
32
- api_key : Optional [str ] = field (default = Factory (lambda : os .environ .get ("OPENAI_API_KEY" )), kw_only = True )
33
- organization : Optional [str ] = field (default = openai .organization , kw_only = True )
32
+ base_url : str = field (default = None , kw_only = True )
33
+ api_key : Optional [str ] = field (default = None , kw_only = True )
34
+ organization : Optional [str ] = field (default = None , kw_only = True )
35
+ client : openai .OpenAI = field (
36
+ default = Factory (
37
+ lambda self : openai .OpenAI (api_key = self .api_key , base_url = self .base_url , organization = self .organization ),
38
+ takes_self = True ,
39
+ )
40
+ )
34
41
tokenizer : OpenAiTokenizer = field (
35
42
default = Factory (lambda self : OpenAiTokenizer (model = self .model ), takes_self = True ), kw_only = True
36
43
)
37
44
38
- def __attrs_post_init__ (self ) -> None :
39
- openai .api_type = self .api_type
40
- openai .api_version = self .api_version
41
- openai .api_base = self .api_base
42
- openai .api_key = self .api_key
43
- openai .organization = self .organization
44
-
45
45
def try_embed_chunk (self , chunk : str ) -> list [float ]:
46
46
# Address a performance issue in older ada models
47
47
# https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
48
48
if self .model .endswith ("001" ):
49
49
chunk = chunk .replace ("\n " , " " )
50
- return openai . Embedding . create (** self ._params (chunk ))[ " data" ] [0 ][ " embedding" ]
50
+ return self . client . embeddings . create (** self ._params (chunk )). data [0 ]. embedding
51
51
52
52
def _params (self , chunk : str ) -> dict :
53
- return {
54
- "input" : chunk ,
55
- "model" : self .model ,
56
- "api_key" : self .api_key ,
57
- "organization" : self .organization ,
58
- "api_version" : self .api_version ,
59
- "api_base" : self .api_base ,
60
- "api_type" : self .api_type ,
61
- }
53
+ return {"input" : chunk , "model" : self .model }
0 commit comments