-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit test for OpenAI compatible EmbeddingFunc
- Loading branch information
1 parent
ac9e437
commit c9f3d82
Showing
1 changed file
with
86 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package chromem_test | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"slices" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/philippgille/chromem-go" | ||
) | ||
|
||
type openAIResponse struct { | ||
Data []struct { | ||
Embedding []float32 `json:"embedding"` | ||
} `json:"data"` | ||
} | ||
|
||
func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { | ||
apiKey := "secret" | ||
model := "model-small" | ||
baseURLSuffix := "/v1" | ||
document := "hello world" | ||
|
||
wantBody, err := json.Marshal(map[string]string{ | ||
"input": document, | ||
"model": model, | ||
}) | ||
if err != nil { | ||
t.Error("unexpected error:", err) | ||
} | ||
wantRes := []float32{-0.1, 0.1, 0.2} | ||
|
||
// Mock server | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
// Check URL | ||
if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") { | ||
t.Error("expected URL", baseURLSuffix+"/embedding", "got", r.URL.Path) | ||
} | ||
// Check method | ||
if r.Method != "POST" { | ||
t.Error("expected method POST, got", r.Method) | ||
} | ||
// Check headers | ||
if r.Header.Get("Authorization") != "Bearer "+apiKey { | ||
t.Error("expected Authorization header", "Bearer "+apiKey, "got", r.Header.Get("Authorization")) | ||
} | ||
if r.Header.Get("Content-Type") != "application/json" { | ||
t.Error("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type")) | ||
} | ||
// Check body | ||
body, err := io.ReadAll(r.Body) | ||
if err != nil { | ||
t.Error("unexpected error:", err) | ||
} | ||
if !bytes.Equal(body, wantBody) { | ||
t.Error("expected body", wantBody, "got", body) | ||
} | ||
|
||
// Write response | ||
resp := openAIResponse{ | ||
Data: []struct { | ||
Embedding []float32 `json:"embedding"` | ||
}{ | ||
{Embedding: wantRes}, | ||
}, | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
_ = json.NewEncoder(w).Encode(resp) | ||
})) | ||
defer ts.Close() | ||
baseURL := ts.URL + baseURLSuffix | ||
|
||
f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model) | ||
res, err := f(context.Background(), document) | ||
if err != nil { | ||
t.Error("expected nil, got", err) | ||
} | ||
if slices.Compare[[]float32](wantRes, res) != 0 { | ||
t.Error("expected res", wantRes, "got", res) | ||
} | ||
} |