diff --git a/example/chat.go b/example/chat.go index 538257d..3e61b30 100644 --- a/example/chat.go +++ b/example/chat.go @@ -2,7 +2,9 @@ package main import ( "encoding/json" + "fmt" "log" + "strings" "github.com/Simplou/openai" ) @@ -73,7 +75,7 @@ func chatByEmbedding(largeText string, query string) { Model: "text-embedding-ada-002", Input: chunks, }) - if err != nil{ + if err != nil { return &openai.EmbeddingResponse[[]float64]{}, nil } return emb, nil @@ -103,8 +105,48 @@ func chatByEmbedding(largeText string, query string) { relevantChunks = append(relevantChunks, chunks[i]) } summary, err := openai.ChunksSummary(client, httpClient, relevantChunks, query) - if err != nil{ + if err != nil { log.Println(err) } log.Println(summary) } + +func chatModerator(customerMessage string) { + moderation, err := openai.Moderator(client, httpClient, &openai.ModerationRequest[string]{ + Input: customerMessage, + }) + if err != nil { + log.Println(err) + } + categories := make([]string, 0) + for _, v := range moderation.Results { + if v.Flagged { + for category, value := range v.Categories { + if value { + categories = append(categories, category) + } + } + } + } + if len(categories) == 0 { + res, err := openai.ChatCompletion[openai.DefaultMessages]( + client, + httpClient, + &openai.CompletionRequest[openai.DefaultMessages]{ + Model: "gpt-3.5-turbo", + Messages: openai.DefaultMessages{ + {Role: "user", Content: customerMessage}, + }, + }, + ) + if err != nil { + log.Println(err) + } + log.Println(res.Choices[0].Message.Content) + } else { + s := strings.Join(categories, ", ") + moderatorMessage := fmt.Sprintf("Your statement contains several disrespectful things: (%s)", s) + log.Println(moderatorMessage) + } + +} diff --git a/example/main.go b/example/main.go index 620b467..c848f02 100644 --- a/example/main.go +++ b/example/main.go @@ -24,6 +24,10 @@ func main() { `, "oi, o que é o actor paradigm?", ) + badSentence := "I want to kill them." + query := "Hi, could you please explain what the actor paradigm is?" + chatModerator(badSentence) + chatModerator(query) //functionCall() //tts() //whisper() diff --git a/moderation.go b/moderation.go new file mode 100644 index 0000000..4569bf4 --- /dev/null +++ b/moderation.go @@ -0,0 +1,46 @@ +package openai + +import ( + "encoding/json" + + "github.com/Simplou/goxios" +) + +type ModerationRequest[Input string | []string] struct { + Input string `json:"input"` + Model string `json:"model,omitempty"` +} + +type ModerationResponse struct { + Id string `json:"id"` + Model string `json:"model"` + Results []struct { + Flagged bool `json:"flagged"` + Categories goxios.GenericJSON[bool] `json:"categories"` + } `json:"results"` + CategoryScores goxios.GenericJSON[float64] `json:"category_scores"` +} + +func Moderator[Input string | []string](api OpenAIClient, httpClient HTTPClient, body *ModerationRequest[Input]) (*ModerationResponse, *OpenAIErr) { + api.AddHeader(contentTypeJSON) + b, err := json.Marshal(body) + if err != nil { + return nil, errCannotMarshalJSON(err) + } + options := goxios.RequestOpts{ + Body: ioReader(b), + Headers: Headers(), + } + res, err := httpClient.Post(api.BaseURL()+"/moderations", &options) + if err != nil { + return nil, errCannotSendRequest(err) + } + response := new(ModerationResponse) + if err := goxios.DecodeJSON(res.Body, response); err != nil { + return nil, errCannotDecodeJSON(err) + } + if err := res.Body.Close(); err != nil { + return nil, errCloseBody(err) + } + return response, nil +}