Skip to content

Commit

Permalink
Add unit test for OpenAI compatible EmbeddingFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
philippgille committed Feb 18, 2024
1 parent ac9e437 commit c9f3d82
Showing 1 changed file with 86 additions and 0 deletions.
86 changes: 86 additions & 0 deletions embed_openai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package chromem_test

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"testing"

"github.com/philippgille/chromem-go"
)

type openAIResponse struct {
Data []struct {
Embedding []float32 `json:"embedding"`
} `json:"data"`
}

func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
apiKey := "secret"
model := "model-small"
baseURLSuffix := "/v1"
document := "hello world"

wantBody, err := json.Marshal(map[string]string{
"input": document,
"model": model,
})
if err != nil {
t.Error("unexpected error:", err)
}
wantRes := []float32{-0.1, 0.1, 0.2}

// Mock server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check URL
if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") {
t.Error("expected URL", baseURLSuffix+"/embedding", "got", r.URL.Path)
}
// Check method
if r.Method != "POST" {
t.Error("expected method POST, got", r.Method)
}
// Check headers
if r.Header.Get("Authorization") != "Bearer "+apiKey {
t.Error("expected Authorization header", "Bearer "+apiKey, "got", r.Header.Get("Authorization"))
}
if r.Header.Get("Content-Type") != "application/json" {
t.Error("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type"))
}
// Check body
body, err := io.ReadAll(r.Body)
if err != nil {
t.Error("unexpected error:", err)
}
if !bytes.Equal(body, wantBody) {
t.Error("expected body", wantBody, "got", body)
}

// Write response
resp := openAIResponse{
Data: []struct {
Embedding []float32 `json:"embedding"`
}{
{Embedding: wantRes},
},
}
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}))
defer ts.Close()
baseURL := ts.URL + baseURLSuffix

f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model)
res, err := f(context.Background(), document)
if err != nil {
t.Error("expected nil, got", err)
}
if slices.Compare[[]float32](wantRes, res) != 0 {
t.Error("expected res", wantRes, "got", res)
}
}

0 comments on commit c9f3d82

Please sign in to comment.