forked from securego/gosec
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathai.go
142 lines (115 loc) · 4.02 KB
/
ai.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package autofix
import (
"context"
"errors"
"fmt"
"time"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
"github.com/securego/gosec/v2/issue"
)
// Exported configuration for the Gemini-backed autofix feature.
const (
	// GeminiProvider is the provider name that GenerateSolution accepts for
	// the Gemini backend.
	GeminiProvider = "gemini"

	// GeminiModel is the model name requested from the GenAI client.
	GeminiModel = "gemini-1.5-flash"

	// AIPrompt is the fmt template sent to the model; its single %q verb is
	// filled with the issue description.
	AIPrompt = `Provide a brief explanation and a solution to fix this security issue
in Go programming language: %q.
Answer in markdown format and keep the response limited to 200 words.`
)

// timeout bounds the contexts this package creates for Gemini API work.
const timeout = 30 * time.Second
// GenAIClient defines the interface for the GenAI client.
//
// It abstracts the SDK client so that an implementation other than the
// SDK-backed wrapper can be supplied.
type GenAIClient interface {
	// Close cleans up and closes the client.
	Close() error
	// GenerativeModel builds the generative model with the given name.
	GenerativeModel(name string) GenAIGenerativeModel
}

// GenAIGenerativeModel defines the interface for the generative model.
type GenAIGenerativeModel interface {
	// GenerateContent generates a response for the given prompt and returns
	// it as text.
	GenerateContent(ctx context.Context, prompt string) (string, error)
}
// genAIClientWrapper adapts a *genai.Client so that it satisfies GenAIClient.
type genAIClientWrapper struct {
	// client is the underlying Gemini SDK client.
	client *genai.Client
}

// Close releases the underlying genai client and returns its error, if any.
func (cw *genAIClientWrapper) Close() error {
	sdk := cw.client
	return sdk.Close()
}

// GenerativeModel fetches the named model from the underlying client and
// wraps it so that it satisfies GenAIGenerativeModel.
func (cw *genAIClientWrapper) GenerativeModel(name string) GenAIGenerativeModel {
	sdkModel := cw.client.GenerativeModel(name)
	return &genAIGenerativeModelWrapper{model: sdkModel}
}
// genAIGenerativeModelWrapper wraps the genai.GenerativeModel to implement
// GenAIGenerativeModel.
type genAIGenerativeModelWrapper struct {
	// model is the underlying generative model
	model *genai.GenerativeModel
}

// GenerateContent sends prompt to the Gemini model as a single text part and
// returns the first part of the first candidate, formatted with %+v.
//
// It returns an error when the API call fails, when no candidate is returned,
// or when the first candidate carries no content parts.
func (w *genAIGenerativeModelWrapper) GenerateContent(ctx context.Context, prompt string) (string, error) {
	resp, err := w.model.GenerateContent(ctx, genai.Text(prompt))
	if err != nil {
		return "", fmt.Errorf("generating autofix: %w", err)
	}
	if resp == nil || len(resp.Candidates) == 0 {
		return "", errors.New("no autofix returned by gemini")
	}
	// The SDK models candidates and their content as pointers. Guard them
	// before indexing into Parts so an empty candidate yields an error
	// rather than a nil-pointer panic.
	first := resp.Candidates[0]
	if first == nil || first.Content == nil || len(first.Content.Parts) == 0 {
		return "", errors.New("nothing found in the first autofix returned by gemini")
	}
	// Return the first part of the first candidate.
	return fmt.Sprintf("%+v", first.Content.Parts[0]), nil
}
// NewGenAIClient returns a GenAIClient backed by the Gemini SDK.
//
// aiApiKey is handed to the SDK as the API key. A non-empty endpoint is added
// as an endpoint override; an empty one leaves the SDK default in place.
// Any error from constructing the SDK client is wrapped and returned.
func NewGenAIClient(ctx context.Context, aiApiKey, endpoint string) (GenAIClient, error) {
	opts := []option.ClientOption{
		option.WithAPIKey(aiApiKey),
	}
	if len(endpoint) > 0 {
		opts = append(opts, option.WithEndpoint(endpoint))
	}

	sdkClient, err := genai.NewClient(ctx, opts...)
	if err != nil {
		return nil, fmt.Errorf("calling gemini API: %w", err)
	}

	wrapped := &genAIClientWrapper{client: sdkClient}
	return wrapped, nil
}
// generateSolutionByGemini fills in Autofix for every issue by asking the
// Gemini model for a fix.
//
// Issues that share the same What text are sent to the model only once; later
// duplicates reuse the cached answer. Each request gets its own timeout, so a
// long list of issues cannot exhaust one shared deadline. Nil entries in
// issues are skipped. Processing stops at the first error, which is returned;
// issues already processed keep their Autofix and later ones are left
// untouched.
func generateSolutionByGemini(client GenAIClient, issues []*issue.Issue) error {
	model := client.GenerativeModel(GeminiModel)
	cachedAutofix := make(map[string]string)

	// iss (not "issue") avoids shadowing the imported issue package.
	for _, iss := range issues {
		if iss == nil {
			continue
		}
		if val, ok := cachedAutofix[iss.What]; ok {
			iss.Autofix = val
			continue
		}

		// Scope the timeout to this single request. The closure lets the
		// deferred cancel run at the end of each iteration rather than at
		// function return.
		resp, err := func() (string, error) {
			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			defer cancel()
			return model.GenerateContent(ctx, fmt.Sprintf(AIPrompt, iss.What))
		}()
		if err != nil {
			return fmt.Errorf("generating autofix with gemini: %w", err)
		}
		if resp == "" {
			return errors.New("no autofix returned by gemini")
		}

		iss.Autofix = resp
		cachedAutofix[iss.What] = resp
	}
	return nil
}
// GenerateSolution fills in an AI-generated autofix for each of the given
// issues using the requested provider.
//
// Only GeminiProvider is supported; any other aiApiProvider value yields an
// error before any client is created. The timeout-bounded context created
// here is used only to construct the client. The client is closed before
// returning, and any error from Close is ignored.
func GenerateSolution(aiApiProvider, aiApiKey, endpoint string, issues []*issue.Issue) error {
	if aiApiProvider != GeminiProvider {
		return errors.New("ai provider not supported")
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	client, err := NewGenAIClient(ctx, aiApiKey, endpoint)
	if err != nil {
		return fmt.Errorf("generating autofix: %w", err)
	}
	defer client.Close()

	return generateSolutionByGemini(client, issues)
}