-
Notifications
You must be signed in to change notification settings - Fork 307
/
Copy pathquestions.go
113 lines (94 loc) · 3.55 KB
/
questions.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package postapi
import (
"encoding/json"
"log"
"net/http"
"github.com/pashpashpash/vault/form"
openai "github.com/sashabaranov/go-openai"
)
// Context is one retrieved chunk of source material used to ground an
// answer. Text is the chunk body and Title identifies the document it
// came from; both are populated from vector-DB match metadata.
type Context struct {
Text string `json:"text"`
Title string `json:"title"`
}
// Answer is the JSON payload returned by QuestionHandler: the model's
// answer text, the context chunks that were fed to the model, and the
// token count reported for the OpenAI call.
type Answer struct {
Answer string `json:"answer"`
Context []Context `json:"context"`
Tokens int `json:"tokens"`
}
// QuestionHandler answers a user question with retrieval-augmented
// generation: it embeds the question, retrieves the nearest context
// chunks from the vector DB, builds a prompt around them, and asks an
// OpenAI model for the final answer. On success it writes an Answer as
// JSON; on any failure it responds 500 with the underlying error text.
func (ctx *HandlerContext) QuestionHandler(w http.ResponseWriter, r *http.Request) {
	form := new(form.QuestionForm)
	if errs := FormParseVerify(form, "QuestionForm", w, r); errs != nil {
		return
	}

	log.Println("[QuestionHandler] Question:", form.Question)
	log.Println("[QuestionHandler] Model:", form.Model)
	log.Println("[QuestionHandler] UUID:", form.UUID)

	// Use the caller-supplied OpenAI key when present, otherwise the
	// server's configured client. SECURITY FIX: never log the key
	// itself — previously it was written to the log in plaintext.
	clientToUse := ctx.openAIClient
	if form.ApiKey != "" {
		log.Println("[QuestionHandler] Using provided custom API key")
		clientToUse = openai.NewClient(form.ApiKey)
	}

	// Step 1: embed the question with the OpenAI embeddings API.
	questionEmbedding, err := getEmbedding(clientToUse, form.Question, openai.AdaEmbeddingV2)
	if err != nil {
		log.Println("[QuestionHandler ERR] OpenAI get embedding request error\n", err.Error())
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	log.Println("[QuestionHandler] Question Embedding Length:", len(questionEmbedding))

	// Step 2: query the vector DB for the top 4 context matches,
	// scoped to this user's UUID.
	matches, err := ctx.vectorDB.Retrieve(questionEmbedding, 4, form.UUID)
	if err != nil {
		log.Println("[QuestionHandler ERR] Vector DB query error\n", err.Error())
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	log.Println("[QuestionHandler] Got matches from vector DB:", matches)

	// Extract context text and titles from the matches.
	contexts := make([]Context, len(matches))
	for i, match := range matches {
		contexts[i].Text = match.Metadata["text"]
		contexts[i].Title = match.Metadata["title"]
	}
	log.Println("[QuestionHandler] Retrieved context from vector DB:\n", contexts)

	// Step 3: build the prompt from the retrieved context texts plus
	// the question.
	contextTexts := make([]string, len(contexts))
	for i, context := range contexts {
		contextTexts[i] = context.Text
	}
	prompt, err := buildPrompt(contextTexts, form.Question)
	if err != nil {
		// BUG FIX: the error is checked BEFORE the empty-prompt
		// fallback; the original applied the fallback first, making it
		// dead and misleading on the error path.
		log.Println("[QuestionHandler ERR] Error building prompt\n", err.Error())
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if prompt == "" {
		// No usable context — fall back to asking the bare question.
		prompt = form.Question
	}

	// Default to GPT-3.5 Turbo unless the caller explicitly picked
	// Davinci.
	model := openai.GPT3Dot5Turbo
	if form.Model == "GPT Davinci" {
		model = openai.GPT3TextDavinci003
	}

	log.Printf("[QuestionHandler] Sending OpenAI api request...\nPrompt:%s\n", prompt)
	openAIResponse, tokens, err := callOpenAI(clientToUse, prompt, model,
		"You are a helpful assistant answering questions based on the context provided.",
		512)
	if err != nil {
		log.Println("[QuestionHandler ERR] OpenAI answer questions request error\n", err.Error())
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	log.Println("[QuestionHandler] OpenAI response:\n", openAIResponse)

	response := OpenAIResponse{openAIResponse, tokens}
	answer := Answer{response.Response, contexts, response.Tokens}

	jsonResponse, err := json.Marshal(answer)
	if err != nil {
		log.Println("[QuestionHandler ERR] OpenAI response marshalling error", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Headers are already sent at this point, so a write failure can
	// only be logged, not reported to the client.
	if _, err := w.Write(jsonResponse); err != nil {
		log.Println("[QuestionHandler ERR] Response write error", err)
	}
}