1
1
from __future__ import annotations
2
- import os
3
2
from typing import Optional
4
- import openai
5
3
from attr import define , field , Factory
6
4
from griptape .drivers import BaseEmbeddingDriver
7
5
from griptape .tokenizers import OpenAiTokenizer
6
+ import openai
8
7
9
8
10
9
@define
@@ -13,9 +12,7 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
13
12
Attributes:
14
13
model: OpenAI embedding model name. Defaults to `text-embedding-ada-002`.
15
14
dimensions: Vector dimensions. Defaults to `1536`.
16
- api_type: OpenAI API type, for example 'open_ai' or 'azure'. Defaults to 'open_ai'.
17
- api_version: API version. Defaults to 'OPENAI_API_VERSION' environment variable.
18
- api_base: API URL. Defaults to OpenAI's v1 API URL.
15
+ base_url: API URL. Defaults to OpenAI's v1 API URL.
19
16
api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
20
17
organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
21
18
tokenizer: Optionally provide custom `OpenAiTokenizer`.
@@ -26,14 +23,19 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
26
23
27
24
model : str = field (default = DEFAULT_MODEL , kw_only = True )
28
25
dimensions : int = field (default = DEFAULT_DIMENSIONS , kw_only = True )
29
- api_type : str = field (default = openai .api_type , kw_only = True )
30
- api_version : Optional [str ] = field (default = openai .api_version , kw_only = True )
31
- api_base : str = field (default = openai .api_base , kw_only = True )
32
- api_key : Optional [str ] = field (
33
- default = Factory (lambda : os .environ .get ("OPENAI_API_KEY" )), kw_only = True
34
- )
35
- organization : Optional [str ] = field (
36
- default = openai .organization , kw_only = True
26
+ base_url : str = field (default = None , kw_only = True )
27
+ api_key : Optional [str ] = field (default = None , kw_only = True )
28
+ organization : Optional [str ] = field (default = None , kw_only = True )
29
+ client : openai .OpenAI = field (
30
+ init = False ,
31
+ default = Factory (
32
+ lambda self : openai .OpenAI (
33
+ api_key = self .api_key ,
34
+ base_url = self .base_url ,
35
+ organization = self .organization ,
36
+ ),
37
+ takes_self = True ,
38
+ ),
37
39
)
38
40
tokenizer : OpenAiTokenizer = field (
39
41
default = Factory (
@@ -42,29 +44,16 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
42
44
kw_only = True ,
43
45
)
44
46
45
- def __attrs_post_init__ (self ) -> None :
46
- openai .api_type = self .api_type
47
- openai .api_version = self .api_version
48
- openai .api_base = self .api_base
49
- openai .api_key = self .api_key
50
- openai .organization = self .organization
51
-
52
47
def try_embed_chunk (self , chunk : str ) -> list [float ]:
53
48
# Address a performance issue in older ada models
54
49
# https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
55
50
if self .model .endswith ("001" ):
56
51
chunk = chunk .replace ("\n " , " " )
57
- return openai .Embedding .create (** self ._params (chunk ))["data" ][0 ][
58
- "embedding"
59
- ]
52
+ return (
53
+ self .client .embeddings .create (** self ._params (chunk ))
54
+ .data [0 ]
55
+ .embedding
56
+ )
60
57
61
58
def _params (self , chunk : str ) -> dict :
62
- return {
63
- "input" : chunk ,
64
- "model" : self .model ,
65
- "api_key" : self .api_key ,
66
- "organization" : self .organization ,
67
- "api_version" : self .api_version ,
68
- "api_base" : self .api_base ,
69
- "api_type" : self .api_type ,
70
- }
59
+ return {"input" : chunk , "model" : self .model }
0 commit comments