// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tokenizer

import (
	"archive/zip"
	"bytes"
	"context"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"testing"

	"cloud.google.com/go/vertexai/genai"
	"golang.org/x/text/encoding"
	"golang.org/x/text/encoding/charmap"
	"golang.org/x/text/encoding/japanese"
	"golang.org/x/text/encoding/simplifiedchinese"
	"golang.org/x/text/transform"
)
| 35 | + |
// corporaInfo holds the name and content of a single file extracted
// from the downloaded corpus zip archive.
type corporaInfo struct {
	Name    string // path within the archive, with the leading "udhr/" prefix removed
	Content []byte // raw file bytes, still in the file's original character encoding
}
| 41 | + |
| 42 | +// corporaGenerator is a helper function that downloads a zip archive from a given URL, |
| 43 | +// extracts the content of each file in the archive, |
| 44 | +// and returns a slice of corporaInfo objects containing the name and content of each file. |
| 45 | +func corporaGenerator(url string) ([]corporaInfo, error) { |
| 46 | + var corpora []corporaInfo |
| 47 | + |
| 48 | + // Download the zip file |
| 49 | + resp, err := http.Get(url) |
| 50 | + if err != nil { |
| 51 | + return nil, fmt.Errorf("error downloading file: %v", err) |
| 52 | + } |
| 53 | + defer resp.Body.Close() |
| 54 | + |
| 55 | + // Read the content of the response body |
| 56 | + body, err := io.ReadAll(resp.Body) |
| 57 | + if err != nil { |
| 58 | + return nil, fmt.Errorf("error reading response body: %v", err) |
| 59 | + } |
| 60 | + |
| 61 | + // Create a zip reader from the downloaded content |
| 62 | + zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) |
| 63 | + if err != nil { |
| 64 | + return nil, fmt.Errorf("error creating zip reader: %v", err) |
| 65 | + } |
| 66 | + |
| 67 | + // Iterate over each file in the zip archive |
| 68 | + for _, file := range zipReader.File { |
| 69 | + fileReader, err := file.Open() |
| 70 | + if err != nil { |
| 71 | + return nil, fmt.Errorf("error opening file: %v", err) |
| 72 | + } |
| 73 | + |
| 74 | + // Check if the file is a text file |
| 75 | + if !file.FileInfo().IsDir() && file.FileInfo().Mode().IsRegular() { |
| 76 | + content, err := io.ReadAll(fileReader) |
| 77 | + fileReader.Close() |
| 78 | + if err != nil { |
| 79 | + return nil, fmt.Errorf("error reading file content: %v", err) |
| 80 | + } |
| 81 | + |
| 82 | + corpora = append(corpora, corporaInfo{ |
| 83 | + Name: file.Name[len("udhr/"):], |
| 84 | + Content: content, |
| 85 | + }) |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + return corpora, nil |
| 90 | +} |
| 91 | + |
// udhrCorpus represents the Universal Declaration of Human Rights (UDHR) corpus.
// This corpus contains translations of the UDHR into many languages,
// stored in a specific directory structure within a zip archive.
//
// The files in the corpus USUALLY follow a naming convention:
//
//	<Language>_<Script>-<Encoding>
//
// For example:
//   - English_English-UTF8
//   - French_Français-Latin1
//   - Spanish_Español-UTF8
//
// The Language and Script parts are self-explanatory.
// The Encoding part indicates the character encoding used in the file.
//
// This corpus is used to test the token counting functionality
// against a diverse set of languages and encodings.
type udhrCorpus struct {
	// EncodingByFileSuffix maps the "<Encoding>" suffix of a
	// conventionally named file to its decoder.
	EncodingByFileSuffix map[string]encoding.Encoding

	// EncodingByFilename maps whole filenames that do not follow the
	// naming convention to their decoder; consulted before the suffix map.
	EncodingByFilename map[string]encoding.Encoding

	// Skip lists files that should be skipped during testing.
	// This is useful for excluding files that are known to cause issues
	// or are not relevant for the test.
	Skip map[string]bool
}
| 119 | + |
| 120 | +// newUdhrCorpus initializes a new udhrCorpus with encoding patterns and skip set |
| 121 | +// func newUdhrCorpus() *udhrCorpus { |
| 122 | +func newUdhrCorpus() *udhrCorpus { |
| 123 | + |
| 124 | + EncodingByFileSuffix := map[string]encoding.Encoding{ |
| 125 | + "Latin1": charmap.ISO8859_1, |
| 126 | + "Hebrew": charmap.ISO8859_8, |
| 127 | + "Arabic": charmap.Windows1256, |
| 128 | + "UTF8": encoding.Nop, |
| 129 | + "Cyrillic": charmap.Windows1251, |
| 130 | + "SJIS": japanese.ShiftJIS, |
| 131 | + "GB2312": simplifiedchinese.HZGB2312, |
| 132 | + "Latin2": charmap.ISO8859_2, |
| 133 | + "Greek": charmap.ISO8859_7, |
| 134 | + "Turkish": charmap.ISO8859_9, |
| 135 | + "Baltic": charmap.ISO8859_4, |
| 136 | + "EUC": japanese.EUCJP, |
| 137 | + "VPS": charmap.Windows1258, |
| 138 | + "Agra": encoding.Nop, |
| 139 | + "T61": charmap.ISO8859_3, |
| 140 | + } |
| 141 | + |
| 142 | + // For non-conventional filenames: |
| 143 | + EncodingByFilename := map[string]encoding.Encoding{ |
| 144 | + "Czech_Cesky-UTF8": charmap.Windows1250, |
| 145 | + "Polish-Latin2": charmap.Windows1250, |
| 146 | + "Polish_Polski-Latin2": charmap.Windows1250, |
| 147 | + "Amahuaca": charmap.ISO8859_1, |
| 148 | + "Turkish_Turkce-Turkish": charmap.ISO8859_9, |
| 149 | + "Lithuanian_Lietuviskai-Baltic": charmap.ISO8859_4, |
| 150 | + "Abkhaz-Cyrillic+Abkh": charmap.Windows1251, |
| 151 | + "Azeri_Azerbaijani_Cyrillic-Az.Times.Cyr.Normal0117": charmap.Windows1251, |
| 152 | + "Azeri_Azerbaijani_Latin-Az.Times.Lat0117": charmap.ISO8859_2, |
| 153 | + } |
| 154 | + |
| 155 | + // The skip list comes from the NLTK source code which says these are unsupported encodings, |
| 156 | + // or in general encodings Go doesn't support. |
| 157 | + // See NLTK source code reference: https://github.com/nltk/nltk/blob/f6567388b4399000b9aa2a6b0db713bff3fe332a/nltk/corpus/reader/udhr.py#L14 |
| 158 | + Skip := map[string]bool{ |
| 159 | + // The following files are not fully decodable because they |
| 160 | + // were truncated at wrong bytes: |
| 161 | + "Burmese_Myanmar-UTF8": true, |
| 162 | + "Japanese_Nihongo-JIS": true, |
| 163 | + "Chinese_Mandarin-HZ": true, |
| 164 | + "Chinese_Mandarin-UTF8": true, |
| 165 | + "Gujarati-UTF8": true, |
| 166 | + "Hungarian_Magyar-Unicode": true, |
| 167 | + "Lao-UTF8": true, |
| 168 | + "Magahi-UTF8": true, |
| 169 | + "Marathi-UTF8": true, |
| 170 | + "Tamil-UTF8": true, |
| 171 | + "Magahi-Agrarpc": true, |
| 172 | + "Magahi-Agra": true, |
| 173 | + // encoding not supported in Go. |
| 174 | + "Vietnamese-VIQR": true, |
| 175 | + "Vietnamese-TCVN": true, |
| 176 | + // The following files are encoded for specific fonts: |
| 177 | + "Burmese_Myanmar-WinResearcher": true, |
| 178 | + "Armenian-DallakHelv": true, |
| 179 | + "Tigrinya_Tigrigna-VG2Main": true, |
| 180 | + "Amharic-Afenegus6..60375": true, |
| 181 | + "Navaho_Dine-Navajo-Navaho-font": true, |
| 182 | + // The following files are unintended: |
| 183 | + "Czech-Latin2-err": true, |
| 184 | + "Russian_Russky-UTF8~": true, |
| 185 | + } |
| 186 | + |
| 187 | + return &udhrCorpus{ |
| 188 | + EncodingByFileSuffix: EncodingByFileSuffix, |
| 189 | + EncodingByFilename: EncodingByFilename, |
| 190 | + Skip: Skip, |
| 191 | + } |
| 192 | +} |
| 193 | + |
| 194 | +// getEncoding returns the encoding for a given filename based on patterns |
| 195 | +func (ucr *udhrCorpus) getEncoding(filename string) (encoding.Encoding, bool) { |
| 196 | + if enc, exists := ucr.EncodingByFilename[filename]; exists { |
| 197 | + return enc, true |
| 198 | + } |
| 199 | + |
| 200 | + parts := strings.Split(filename, "-") |
| 201 | + encodingKey := parts[len(parts)-1] |
| 202 | + if enc, exists := ucr.EncodingByFileSuffix[encodingKey]; exists { |
| 203 | + return enc, true |
| 204 | + } |
| 205 | + |
| 206 | + return nil, false |
| 207 | +} |
| 208 | + |
| 209 | +// shouldSkip checks if the file should be skipped |
| 210 | +func (ucr *udhrCorpus) shouldSkip(filename string) bool { |
| 211 | + return ucr.Skip[filename] |
| 212 | +} |
| 213 | + |
| 214 | +// decodeBytes decodes the given byte slice using the specified encoding |
| 215 | +func decodeBytes(enc encoding.Encoding, content []byte) (string, error) { |
| 216 | + decodedBytes, _, err := transform.Bytes(enc.NewDecoder(), content) |
| 217 | + if err != nil { |
| 218 | + return "", fmt.Errorf("error decoding bytes: %v", err) |
| 219 | + } |
| 220 | + return string(decodedBytes), nil |
| 221 | +} |
| 222 | + |
// Default Vertex AI model and region used by the live token-count test.
const (
	defaultModel    = "gemini-1.0-pro"
	defaultLocation = "us-central1"
)
| 225 | + |
| 226 | +func TestCountTokensWithCorpora(t *testing.T) { |
| 227 | + projectID := os.Getenv("VERTEX_PROJECT_ID") |
| 228 | + if testing.Short() { |
| 229 | + t.Skip("skipping live test in -short mode") |
| 230 | + } |
| 231 | + |
| 232 | + if projectID == "" { |
| 233 | + t.Skip("set a VERTEX_PROJECT_ID env var to run live tests") |
| 234 | + } |
| 235 | + ctx := context.Background() |
| 236 | + client, err := genai.NewClient(ctx, projectID, defaultLocation) |
| 237 | + if err != nil { |
| 238 | + t.Fatal(err) |
| 239 | + } |
| 240 | + defer client.Close() |
| 241 | + model := client.GenerativeModel(defaultModel) |
| 242 | + ucr := newUdhrCorpus() |
| 243 | + |
| 244 | + corporaURL := "https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/udhr.zip" |
| 245 | + files, err := corporaGenerator(corporaURL) |
| 246 | + if err != nil { |
| 247 | + t.Fatalf("Failed to generate corpora: %v", err) |
| 248 | + } |
| 249 | + |
| 250 | + // Iterate over files generated by the generator function |
| 251 | + for _, fileInfo := range files { |
| 252 | + if ucr.shouldSkip(fileInfo.Name) { |
| 253 | + fmt.Printf("Skipping file: %s\n", fileInfo.Name) |
| 254 | + continue |
| 255 | + } |
| 256 | + |
| 257 | + enc, found := ucr.getEncoding(fileInfo.Name) |
| 258 | + if !found { |
| 259 | + fmt.Printf("No encoding found for file: %s\n", fileInfo.Name) |
| 260 | + continue |
| 261 | + } |
| 262 | + |
| 263 | + decodedContent, err := decodeBytes(enc, fileInfo.Content) |
| 264 | + if err != nil { |
| 265 | + log.Fatalf("Failed to decode bytes: %v", err) |
| 266 | + } |
| 267 | + |
| 268 | + tok, err := New(defaultModel) |
| 269 | + if err != nil { |
| 270 | + log.Fatal(err) |
| 271 | + } |
| 272 | + |
| 273 | + localNtoks, err := tok.CountTokens(genai.Text(decodedContent)) |
| 274 | + if err != nil { |
| 275 | + log.Fatal(err) |
| 276 | + } |
| 277 | + remoteNtoks, err := model.CountTokens(ctx, genai.Text(decodedContent)) |
| 278 | + if err != nil { |
| 279 | + log.Fatal(fileInfo.Name, err) |
| 280 | + } |
| 281 | + if localNtoks.TotalTokens != remoteNtoks.TotalTokens { |
| 282 | + t.Errorf("expected %d(remote count-token results), but got %d(local count-token results)", remoteNtoks, localNtoks) |
| 283 | + } |
| 284 | + |
| 285 | + } |
| 286 | + |
| 287 | +} |