diff --git a/README.md b/README.md index 7f5943684b..ba93fa643f 100644 --- a/README.md +++ b/README.md @@ -334,6 +334,27 @@ k8sgpt analyze --explain --backend azureopenai + + +
+Cohere provider + +Prerequisites: a Cohere API key is needed, please visit the [Cohere dashboard](https://dashboard.cohere.ai/api-keys) to create one. + +To run k8sgpt, run `k8sgpt auth` with the `cohere` backend: + +``` +k8sgpt auth add --backend cohere --model command-nightly +``` + +Lastly, enter your Cohere API key, after the prompt. + +Now you are ready to analyze with the Cohere backend: + +``` +k8sgpt analyze --explain --backend cohere +``` +
diff --git a/go.mod b/go.mod index b7e8d114b1..9bac0000a6 100644 --- a/go.mod +++ b/go.mod @@ -27,10 +27,13 @@ require ( buf.build/gen/go/k8sgpt-ai/k8sgpt/grpc/go v1.3.0-20230620082254-6f80f9533908.1 buf.build/gen/go/k8sgpt-ai/k8sgpt/protocolbuffers/go v1.30.0-20230620082254-6f80f9533908.1 github.com/aws/aws-sdk-go v1.44.300 + github.com/cohere-ai/cohere-go v0.2.0 ) require ( github.com/anchore/go-struct-converter v0.0.0-20221118182256-c68fdcfa2092 // indirect + github.com/cohere-ai/tokenizer v1.1.1 // indirect + github.com/dlclark/regexp2 v1.4.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect ) diff --git a/go.sum b/go.sum index e650109854..f2a59d3c23 100644 --- a/go.sum +++ b/go.sum @@ -493,6 +493,10 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20220314180256-7f1daf1720fc/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cohere-ai/cohere-go v0.2.0 h1:Gljkn8LTtsAPy79ks1AVmZH9Av4kuQuXEgzEJ/1Ea34= +github.com/cohere-ai/cohere-go v0.2.0/go.mod h1:DFcCu5rwro4wAlluIXY9l17NLGiVBGb2bRio46RXBm8= +github.com/cohere-ai/tokenizer v1.1.1 h1:wCtmCj07O82TMrIiA/CORhIlEYsvMMM8ey+sUdEapHc= +github.com/cohere-ai/tokenizer v1.1.1/go.mod h1:9MNFPd9j1fuiEK3ua2HSCUxxcrfGMlSqpa93livg/C0= github.com/containerd/cgroups v1.1.0 h1:v8rEWFl6EoqHB+swVNjVoCJE8o3jX7e8nqBGPLaDFBM= github.com/containerd/containerd v1.7.0 h1:G/ZQr3gMZs6ZT0qPUZ15znx5QSdQdASW11nXTLTM2Pg= github.com/containerd/containerd v1.7.0/go.mod h1:QfR7Efgb/6X2BDpTPJRvPTYDE9rsF0FsXX9J8sIs/sc= @@ -518,6 +522,7 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2 h1:aBfCb7iqHmDEIp6fBvC/hQUddQfg+3qdYjwzaiP9Hnc= github.com/dlclark/regexp2 v1.4.0 h1:F1rxgk7p4uKjwIQxBs9oAXe5CqrXlCduYEJvrF4u93E= +github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= github.com/docker/cli v23.0.5+incompatible h1:ufWmAOuD3Vmr7JP2G5K3cyuNC4YZWiAsuDEvFVVDafE= github.com/docker/cli v23.0.5+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= diff --git a/pkg/ai/cohere.go b/pkg/ai/cohere.go new file mode 100644 index 0000000000..a09963c544 --- /dev/null +++ b/pkg/ai/cohere.go @@ -0,0 +1,116 @@ +/* +Copyright 2023 The K8sGPT Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ai + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "strings" + + "github.com/cohere-ai/cohere-go" + "github.com/fatih/color" + + "github.com/k8sgpt-ai/k8sgpt/pkg/cache" + "github.com/k8sgpt-ai/k8sgpt/pkg/util" +) + +type CohereClient struct { + client *cohere.Client + language string + model string +} + +func (c *CohereClient) Configure(config IAIConfig, language string) error { + token := config.GetPassword() + + client, err := cohere.CreateClient(token) + if err != nil { + return err + } + + baseURL := config.GetBaseURL() + if baseURL != "" { + client.BaseURL = baseURL + } + + if client == nil { + return errors.New("error creating Cohere client") + } + c.language = language + c.client = client + c.model = config.GetModel() + return nil +} + +func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl string) (string, error) { + // Create a completion request + if len(promptTmpl) == 0 { + promptTmpl = PromptMap["default"] + } + resp, err := c.client.Generate(cohere.GenerateOptions{ + Model: c.model, + Prompt: fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt), + MaxTokens: cohere.Uint(2048), + Temperature: cohere.Float64(0.75), + K: cohere.Int(0), + StopSequences: []string{}, + ReturnLikelihoods: "NONE", + }) + if err != nil { + return "", err + } + return resp.Generations[0].Text, nil +} + +func (a *CohereClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) { + inputKey := strings.Join(prompt, " ") + // Check for cached data + cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) + + if !cache.IsCacheDisabled() && cache.Exists(cacheKey) { + response, err := cache.Load(cacheKey) + if err != nil { + return "", err + } + + if response != "" { + output, err := base64.StdEncoding.DecodeString(response) + if err != nil { + color.Red("error decoding cached data: %v", err) + return "", nil + } + return string(output), nil + } + } + + response, err := a.GetCompletion(ctx, inputKey, promptTmpl) + if err != nil { + return "", err + } + + err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response))) + + if err != nil { + color.Red("error storing value to cache: %v", err) + return "", nil + } + + return response, nil +} + +func (a *CohereClient) GetName() string { + return "cohere" +} diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 6b3ea1d776..2d09e41231 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -25,12 +25,14 @@ var ( &AzureAIClient{}, &LocalAIClient{}, &NoOpAIClient{}, + &CohereClient{}, } Backends = []string{ "openai", "localai", "azureopenai", "noopai", + "cohere", } )