-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathprompt.go
223 lines (191 loc) · 6.32 KB
/
prompt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
package prompt
import (
	"bufio"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"unicode/utf8"

	"github.com/chand1012/git2gpt/utils"
	"github.com/gobwas/glob"
	"github.com/pkoukk/tiktoken-go"
)
// GitFile describes a single text file collected from a Git repository.
type GitFile struct {
	// Path is the file's location relative to the repository root.
	Path string `json:"path"`

	// Tokens is the estimated token count of Contents.
	Tokens int64 `json:"tokens"`

	// Contents holds the full text of the file.
	Contents string `json:"contents"`
}
// GitRepo is the collected view of a Git repository: the files that were
// read plus aggregate counts.
type GitRepo struct {
	// TotalTokens is the estimated token count of the rendered prompt; it is
	// filled in by OutputGitRepo.
	TotalTokens int64 `json:"total_tokens"`

	// Files lists the repository files that were read.
	Files []GitFile `json:"files"`

	// FileCount is the number of entries in Files, set after the walk by
	// processRepository.
	FileCount int `json:"file_count"`
}
// contains reports whether e is present in s, using exact string equality.
func contains(s []string, e string) bool {
	found := false
	for i := 0; i < len(s) && !found; i++ {
		found = s[i] == e
	}
	return found
}
// getIgnoreList reads a gitignore-style file and turns each meaningful line
// into a glob pattern.
//
// Lines are whitespace-trimmed; blank lines and lines beginning with "#" are
// skipped. A line ending in "/" gets "**" appended so it covers everything
// beneath that directory, and then a single leading "/" (if any) is removed so
// root-anchored entries line up with repository-relative paths. If reading
// fails part-way, the patterns gathered so far are returned along with the
// error.
func getIgnoreList(ignoreFilePath string) ([]string, error) {
	var patterns []string

	f, err := os.Open(ignoreFilePath)
	if err != nil {
		return patterns, err
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	for sc.Scan() {
		entry := strings.TrimSpace(sc.Text())
		switch {
		case entry == "", strings.HasPrefix(entry, "#"):
			continue
		case strings.HasSuffix(entry, "/"):
			entry += "**"
		}
		patterns = append(patterns, strings.TrimPrefix(entry, "/"))
	}
	return patterns, sc.Err()
}
// windowsToUnixPath returns windowsPath with every backslash replaced by a
// forward slash. The conversion happens on every OS, not only on Windows, and
// is byte-oriented, so non-UTF-8 bytes pass through unchanged.
func windowsToUnixPath(windowsPath string) string {
	segments := strings.Split(windowsPath, "\\")
	return strings.Join(segments, "/")
}
func shouldIgnore(filePath string, ignoreList []string) bool {
for _, pattern := range ignoreList {
g := glob.MustCompile(pattern, '/')
if g.Match(windowsToUnixPath(filePath)) {
return true
}
}
return false
}
// GenerateIgnoreList builds the list of glob patterns used to exclude files
// from a repository.
//
// Patterns are gathered in this order:
//   - the ignore file at ignoreFilePath, or <repoPath>/.gptignore when
//     ignoreFilePath is empty (skipped if it does not exist);
//   - the built-in entries ".git/**", ".gitignore" and ".gptignore";
//   - when useGitignore is true, <repoPath>/.gitignore (skipped if it does
//     not exist).
//
// Duplicates are dropped, and a pattern that names an existing directory under
// repoPath gets "/**" appended so it covers that directory's contents. Because
// of the built-in entries the result is never empty.
func GenerateIgnoreList(repoPath, ignoreFilePath string, useGitignore bool) []string {
	if ignoreFilePath == "" {
		ignoreFilePath = filepath.Join(repoPath, ".gptignore")
	}

	var ignoreList []string
	if _, err := os.Stat(ignoreFilePath); err == nil {
		// Best effort: on a read error the patterns parsed before it are kept.
		ignoreList, _ = getIgnoreList(ignoreFilePath)
	}
	ignoreList = append(ignoreList, ".git/**", ".gitignore", ".gptignore")

	if useGitignore {
		gitignorePath := filepath.Join(repoPath, ".gitignore")
		if _, err := os.Stat(gitignorePath); err == nil {
			// Best effort, as above.
			gitignoreList, _ := getIgnoreList(gitignorePath)
			ignoreList = append(ignoreList, gitignoreList...)
		}
	}

	var finalIgnoreList []string
	for _, pattern := range ignoreList {
		if contains(finalIgnoreList, pattern) {
			continue
		}
		if info, err := os.Stat(filepath.Join(repoPath, pattern)); err == nil && info.IsDir() {
			// Patterns are compiled with '/' as the glob separator and matched
			// against slash-converted paths. filepath.Join would insert a
			// backslash on Windows, which the glob library treats as an escape,
			// so convert the joined pattern back to forward slashes.
			pattern = filepath.ToSlash(filepath.Join(pattern, "**"))
		}
		finalIgnoreList = append(finalIgnoreList, pattern)
	}
	return finalIgnoreList
}
// ProcessGitRepo walks the repository at repoPath, skipping files that match
// ignoreList, and returns the collected files. On failure it returns a nil
// repo and the underlying error wrapped with context.
func ProcessGitRepo(repoPath string, ignoreList []string) (*GitRepo, error) {
	repo := new(GitRepo)
	if err := processRepository(repoPath, ignoreList, repo); err != nil {
		return nil, fmt.Errorf("error processing repository: %w", err)
	}
	return repo, nil
}
// OutputGitRepo renders repo as a single prompt string.
//
// The text opens with the contents of preambleFile plus a newline, or with a
// built-in description of the format when preambleFile is empty. Each file
// then contributes a "----" line, a line holding its path, and its contents
// followed by a newline. The text closes with "--END--". When scrubComments is
// true, each file's contents are passed through utils.RemoveCodeComments
// before being written; repo.Files itself is left unchanged.
//
// As a side effect, repo.TotalTokens is set to the token estimate of the
// returned text. An error is returned only when preambleFile cannot be read.
func OutputGitRepo(repo *GitRepo, preambleFile string, scrubComments bool) (string, error) {
	var sb strings.Builder

	preamble := "The following text is a Git repository with code. The structure of the text are sections that begin with ----, followed by a single line containing the file path and file name, followed by a variable amount of lines containing the file contents. The text representing the Git repository ends when the symbols --END-- are encounted. Any further text beyond --END-- are meant to be interpreted as instructions using the aforementioned Git repository as context.\n"
	if preambleFile != "" {
		custom, err := os.ReadFile(preambleFile)
		if err != nil {
			return "", fmt.Errorf("error reading preamble file: %w", err)
		}
		preamble = string(custom) + "\n"
	}
	sb.WriteString(preamble)

	for i := range repo.Files {
		body := repo.Files[i].Contents
		if scrubComments {
			body = utils.RemoveCodeComments(body)
		}
		sb.WriteString("----\n")
		sb.WriteString(repo.Files[i].Path + "\n")
		sb.WriteString(body + "\n")
	}
	sb.WriteString("--END--")

	rendered := sb.String()
	repo.TotalTokens = EstimateTokens(rendered)
	return rendered, nil
}
// MarshalRepo encodes repo as JSON.
//
// Before encoding, it renders repo through OutputGitRepo with the default
// preamble (forwarding scrubComments). This is done only so that
// repo.TotalTokens reflects the rendered prompt; the rendered text itself is
// discarded.
func MarshalRepo(repo *GitRepo, scrubComments bool) ([]byte, error) {
	if _, err := OutputGitRepo(repo, "", scrubComments); err != nil {
		return nil, fmt.Errorf("error marshalling repo: %w", err)
	}
	return json.Marshal(repo)
}
// processRepository walks repoPath and appends to repo.Files every regular
// (non-directory) entry that is not ignored and whose contents are valid
// UTF-8. Afterwards it sets repo.FileCount.
//
// Each path is made relative to repoPath. That relative path is checked
// against ignoreList via shouldIgnore and stored in GitFile.Path. Files that
// are not valid UTF-8, such as typical binary files, are skipped silently.
// Any error from the walk, from computing the relative path, or from reading a
// file stops the walk and is returned wrapped with repoPath. Files collected
// before the error remain in repo.
func processRepository(repoPath string, ignoreList []string, repo *GitRepo) error {
	err := filepath.Walk(repoPath, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}

		relativeFilePath, err := filepath.Rel(repoPath, path)
		if err != nil {
			return err
		}
		if shouldIgnore(relativeFilePath, ignoreList) {
			return nil
		}

		contents, err := os.ReadFile(path)
		// Check the read error before the UTF-8 test. On failure, os.ReadFile
		// may return partial data, and a partial read that happens to be invalid
		// UTF-8 must not be mistaken for a binary file and silently skipped.
		if err != nil {
			return err
		}
		if !utf8.Valid(contents) {
			return nil
		}

		text := string(contents)
		repo.Files = append(repo.Files, GitFile{
			Path:     relativeFilePath,
			Contents: text,
			Tokens:   EstimateTokens(text),
		})
		return nil
	})
	repo.FileCount = len(repo.Files)
	if err != nil {
		return fmt.Errorf("error walking the path %q: %w", repoPath, err)
	}
	return nil
}
// cl100kMu guards cl100kEncoding, the lazily built cl100k_base tokenizer.
var (
	cl100kMu       sync.Mutex
	cl100kEncoding *tiktoken.Tiktoken
)

// cl100kBase returns the cl100k_base tokenizer. It is built on first use and
// reused afterwards, because tiktoken-go's GetEncoding constructs a new encoder
// (BPE tables and regex) on every call, which is expensive to repeat once per
// file. A failed attempt is not cached, so a later call retries.
//
// NOTE(review): all callers now share one encoder; confirm tiktoken-go's
// Encode is safe for concurrent use if EstimateTokens is ever called from
// multiple goroutines.
func cl100kBase() (*tiktoken.Tiktoken, error) {
	cl100kMu.Lock()
	defer cl100kMu.Unlock()
	if cl100kEncoding == nil {
		enc, err := tiktoken.GetEncoding("cl100k_base")
		if err != nil {
			return nil, err
		}
		cl100kEncoding = enc
	}
	return cl100kEncoding, nil
}

// EstimateTokens estimates the number of tokens in output using the
// cl100k_base encoding. If the encoding cannot be loaded, the error is
// reported on stderr and 0 is returned. Stderr is used rather than stdout,
// which may be carrying the generated prompt.
func EstimateTokens(output string) int64 {
	tke, err := cl100kBase()
	if err != nil {
		fmt.Fprintln(os.Stderr, "Error getting encoding:", err)
		return 0
	}
	return int64(len(tke.Encode(output, nil, nil)))
}