Skip to content

Commit 119a9b2

Browse files
authored
Merge branch 'main' into release-please--branches--main--components--pubsub
2 parents 81cf153 + ce82b22 commit 119a9b2

File tree

4 files changed

+309
-3
lines changed

4 files changed

+309
-3
lines changed

bigtable/type.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ func Equal(a, b Type) bool {
5656
return proto.Equal(a.proto(), b.proto())
5757
}
5858

59+
// TypeUnspecified represents the absence of a type.
60+
type TypeUnspecified struct{}
61+
62+
func (n TypeUnspecified) proto() *btapb.Type {
63+
return &btapb.Type{}
64+
}
65+
5966
// unknown carries a pointer to a value of type T that could not be mapped to
// a more specific type. The pointer may be nil: ProtoToType returns
// unknown[btapb.Type] with a nil wrapped value when it is given a nil proto.
type unknown[T any] struct {
	wrapped *T
}
@@ -240,7 +247,9 @@ func ProtoToType(pb *btapb.Type) Type {
240247
if pb == nil {
241248
return unknown[btapb.Type]{wrapped: nil}
242249
}
243-
250+
if pb.Kind == nil {
251+
return TypeUnspecified{}
252+
}
244253
switch t := pb.Kind.(type) {
245254
case *btapb.Type_Int64Type:
246255
return int64ProtoToType(t.Int64Type)

bigtable/type_test.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ func TestNilChecks(t *testing.T) {
184184
if val, ok := ProtoToType(nil).(unknown[btapb.Type]); !ok {
185185
t.Errorf("got: %T, wanted unknown[btapb.Type]", val)
186186
}
187-
if val, ok := ProtoToType(&btapb.Type{}).(unknown[btapb.Type]); !ok {
187+
if val, ok := ProtoToType(&btapb.Type{}).(TypeUnspecified); !ok {
188188
t.Errorf("got: %T, wanted unknown[btapb.Type]", val)
189189
}
190190

@@ -221,3 +221,13 @@ func TestNilChecks(t *testing.T) {
221221
t.Errorf("got: %T, wanted unknown[btapb.Type]", val)
222222
}
223223
}
224+
225+
func TestTypeUnspecified(t *testing.T) {
226+
pb := &btapb.Type{}
227+
tpe := ProtoToType(pb)
228+
assertType(t, tpe, &btapb.Type{})
229+
expect := TypeUnspecified{}
230+
if tpe != expect {
231+
t.Errorf("got: %v, wanted: %v", tpe, expect)
232+
}
233+
}
+287
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package tokenizer
15+
16+
import (
17+
"archive/zip"
18+
"bytes"
19+
"context"
20+
"fmt"
21+
"io"
22+
"log"
23+
"net/http"
24+
"os"
25+
"strings"
26+
"testing"
27+
28+
"cloud.google.com/go/vertexai/genai"
29+
"golang.org/x/text/encoding"
30+
"golang.org/x/text/encoding/charmap"
31+
"golang.org/x/text/encoding/japanese"
32+
"golang.org/x/text/encoding/simplifiedchinese"
33+
"golang.org/x/text/transform"
34+
)
35+
36+
// corporaInfo holds the name and content of a file in the zip archive.
type corporaInfo struct {
	// Name is the entry's path inside the archive with a leading "udhr/"
	// removed when present.
	Name string
	// Content is the entry's raw bytes exactly as stored in the archive.
	Content []byte
}

// corporaGenerator downloads the zip archive at url and returns one
// corporaInfo per regular file in it, in archive order. Directory entries and
// other non-regular entries are ignored.
//
// An error is returned when the download fails, the server answers with a
// status other than 200 OK, the payload is not a readable zip archive, or any
// regular entry cannot be opened or read.
func corporaGenerator(url string) ([]corporaInfo, error) {
	// Download the zip file.
	resp, err := http.Get(url)
	if err != nil {
		return nil, fmt.Errorf("error downloading file: %w", err)
	}
	defer resp.Body.Close()

	// Fail early with the real cause instead of letting an HTML error page
	// surface later as "not a valid zip file".
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("error downloading file: unexpected HTTP status %q", resp.Status)
	}

	// zip.NewReader needs random access, so the archive is buffered in memory.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading response body: %w", err)
	}

	zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body)))
	if err != nil {
		return nil, fmt.Errorf("error creating zip reader: %w", err)
	}

	var corpora []corporaInfo
	for _, file := range zipReader.File {
		// IsRegular is false for directories, so this single check also skips
		// them; nothing is opened for entries we do not read.
		if !file.FileInfo().Mode().IsRegular() {
			continue
		}

		content, err := readZipEntry(file)
		if err != nil {
			return nil, err
		}

		corpora = append(corpora, corporaInfo{
			// TrimPrefix leaves names outside "udhr/" untouched, where slicing
			// by len("udhr/") would mangle them or panic on short names.
			Name:    strings.TrimPrefix(file.Name, "udhr/"),
			Content: content,
		})
	}

	return corpora, nil
}

// readZipEntry returns the full decompressed content of a single zip entry and
// always closes the entry's reader.
func readZipEntry(file *zip.File) ([]byte, error) {
	rc, err := file.Open()
	if err != nil {
		return nil, fmt.Errorf("error opening file: %w", err)
	}
	defer rc.Close()

	content, err := io.ReadAll(rc)
	if err != nil {
		return nil, fmt.Errorf("error reading file content: %w", err)
	}
	return content, nil
}
91+
92+
// udhrCorpus represents the Universal Declaration of Human Rights (UDHR) corpus.
93+
// This corpus contains translations of the UDHR into many languages,
94+
// stored in a specific directory structure within a zip archive.
95+
//
96+
// The files in the corpus USUALLY follow a naming convention:
97+
//
98+
// <Language>_<Script>-<Encoding>
99+
//
100+
// For example:
101+
// - English_English-UTF8
102+
// - French_Français-Latin1
103+
// - Spanish_Español-UTF8
104+
//
105+
// The Language and Script parts are self-explanatory.
106+
// The Encoding part indicates the character encoding used in the file.
107+
//
108+
// This corpus is used to test the token counting functionality
109+
// against a diverse set of languages and encodings.
110+
type udhrCorpus struct {
111+
EncodingByFileSuffix map[string]encoding.Encoding
112+
EncodingByFilename map[string]encoding.Encoding
113+
114+
// Skip lists files that should be skipped during testing.
115+
// This is useful for excluding files that are known to cause issues
116+
// or are not relevant for the test.
117+
Skip map[string]bool
118+
}
119+
120+
// newUdhrCorpus initializes a new udhrCorpus with encoding patterns and skip set
121+
// func newUdhrCorpus() *udhrCorpus {
122+
func newUdhrCorpus() *udhrCorpus {
123+
124+
EncodingByFileSuffix := map[string]encoding.Encoding{
125+
"Latin1": charmap.ISO8859_1,
126+
"Hebrew": charmap.ISO8859_8,
127+
"Arabic": charmap.Windows1256,
128+
"UTF8": encoding.Nop,
129+
"Cyrillic": charmap.Windows1251,
130+
"SJIS": japanese.ShiftJIS,
131+
"GB2312": simplifiedchinese.HZGB2312,
132+
"Latin2": charmap.ISO8859_2,
133+
"Greek": charmap.ISO8859_7,
134+
"Turkish": charmap.ISO8859_9,
135+
"Baltic": charmap.ISO8859_4,
136+
"EUC": japanese.EUCJP,
137+
"VPS": charmap.Windows1258,
138+
"Agra": encoding.Nop,
139+
"T61": charmap.ISO8859_3,
140+
}
141+
142+
// For non-conventional filenames:
143+
EncodingByFilename := map[string]encoding.Encoding{
144+
"Czech_Cesky-UTF8": charmap.Windows1250,
145+
"Polish-Latin2": charmap.Windows1250,
146+
"Polish_Polski-Latin2": charmap.Windows1250,
147+
"Amahuaca": charmap.ISO8859_1,
148+
"Turkish_Turkce-Turkish": charmap.ISO8859_9,
149+
"Lithuanian_Lietuviskai-Baltic": charmap.ISO8859_4,
150+
"Abkhaz-Cyrillic+Abkh": charmap.Windows1251,
151+
"Azeri_Azerbaijani_Cyrillic-Az.Times.Cyr.Normal0117": charmap.Windows1251,
152+
"Azeri_Azerbaijani_Latin-Az.Times.Lat0117": charmap.ISO8859_2,
153+
}
154+
155+
// The skip list comes from the NLTK source code which says these are unsupported encodings,
156+
// or in general encodings Go doesn't support.
157+
// See NLTK source code reference: https://github.com/nltk/nltk/blob/f6567388b4399000b9aa2a6b0db713bff3fe332a/nltk/corpus/reader/udhr.py#L14
158+
Skip := map[string]bool{
159+
// The following files are not fully decodable because they
160+
// were truncated at wrong bytes:
161+
"Burmese_Myanmar-UTF8": true,
162+
"Japanese_Nihongo-JIS": true,
163+
"Chinese_Mandarin-HZ": true,
164+
"Chinese_Mandarin-UTF8": true,
165+
"Gujarati-UTF8": true,
166+
"Hungarian_Magyar-Unicode": true,
167+
"Lao-UTF8": true,
168+
"Magahi-UTF8": true,
169+
"Marathi-UTF8": true,
170+
"Tamil-UTF8": true,
171+
"Magahi-Agrarpc": true,
172+
"Magahi-Agra": true,
173+
// encoding not supported in Go.
174+
"Vietnamese-VIQR": true,
175+
"Vietnamese-TCVN": true,
176+
// The following files are encoded for specific fonts:
177+
"Burmese_Myanmar-WinResearcher": true,
178+
"Armenian-DallakHelv": true,
179+
"Tigrinya_Tigrigna-VG2Main": true,
180+
"Amharic-Afenegus6..60375": true,
181+
"Navaho_Dine-Navajo-Navaho-font": true,
182+
// The following files are unintended:
183+
"Czech-Latin2-err": true,
184+
"Russian_Russky-UTF8~": true,
185+
}
186+
187+
return &udhrCorpus{
188+
EncodingByFileSuffix: EncodingByFileSuffix,
189+
EncodingByFilename: EncodingByFilename,
190+
Skip: Skip,
191+
}
192+
}
193+
194+
// getEncoding returns the encoding for a given filename based on patterns
195+
func (ucr *udhrCorpus) getEncoding(filename string) (encoding.Encoding, bool) {
196+
if enc, exists := ucr.EncodingByFilename[filename]; exists {
197+
return enc, true
198+
}
199+
200+
parts := strings.Split(filename, "-")
201+
encodingKey := parts[len(parts)-1]
202+
if enc, exists := ucr.EncodingByFileSuffix[encodingKey]; exists {
203+
return enc, true
204+
}
205+
206+
return nil, false
207+
}
208+
209+
// shouldSkip checks if the file should be skipped
210+
func (ucr *udhrCorpus) shouldSkip(filename string) bool {
211+
return ucr.Skip[filename]
212+
}
213+
214+
// decodeBytes decodes the given byte slice using the specified encoding
215+
func decodeBytes(enc encoding.Encoding, content []byte) (string, error) {
216+
decodedBytes, _, err := transform.Bytes(enc.NewDecoder(), content)
217+
if err != nil {
218+
return "", fmt.Errorf("error decoding bytes: %v", err)
219+
}
220+
return string(decodedBytes), nil
221+
}
222+
223+
// Model name and location passed to the Vertex AI client and the tokenizer in
// TestCountTokensWithCorpora.
const (
	defaultModel    = "gemini-1.0-pro"
	defaultLocation = "us-central1"
)
225+
226+
func TestCountTokensWithCorpora(t *testing.T) {
227+
projectID := os.Getenv("VERTEX_PROJECT_ID")
228+
if testing.Short() {
229+
t.Skip("skipping live test in -short mode")
230+
}
231+
232+
if projectID == "" {
233+
t.Skip("set a VERTEX_PROJECT_ID env var to run live tests")
234+
}
235+
ctx := context.Background()
236+
client, err := genai.NewClient(ctx, projectID, defaultLocation)
237+
if err != nil {
238+
t.Fatal(err)
239+
}
240+
defer client.Close()
241+
model := client.GenerativeModel(defaultModel)
242+
ucr := newUdhrCorpus()
243+
244+
corporaURL := "https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/udhr.zip"
245+
files, err := corporaGenerator(corporaURL)
246+
if err != nil {
247+
t.Fatalf("Failed to generate corpora: %v", err)
248+
}
249+
250+
// Iterate over files generated by the generator function
251+
for _, fileInfo := range files {
252+
if ucr.shouldSkip(fileInfo.Name) {
253+
fmt.Printf("Skipping file: %s\n", fileInfo.Name)
254+
continue
255+
}
256+
257+
enc, found := ucr.getEncoding(fileInfo.Name)
258+
if !found {
259+
fmt.Printf("No encoding found for file: %s\n", fileInfo.Name)
260+
continue
261+
}
262+
263+
decodedContent, err := decodeBytes(enc, fileInfo.Content)
264+
if err != nil {
265+
log.Fatalf("Failed to decode bytes: %v", err)
266+
}
267+
268+
tok, err := New(defaultModel)
269+
if err != nil {
270+
log.Fatal(err)
271+
}
272+
273+
localNtoks, err := tok.CountTokens(genai.Text(decodedContent))
274+
if err != nil {
275+
log.Fatal(err)
276+
}
277+
remoteNtoks, err := model.CountTokens(ctx, genai.Text(decodedContent))
278+
if err != nil {
279+
log.Fatal(fileInfo.Name, err)
280+
}
281+
if localNtoks.TotalTokens != remoteNtoks.TotalTokens {
282+
t.Errorf("expected %d(remote count-token results), but got %d(local count-token results)", remoteNtoks, localNtoks)
283+
}
284+
285+
}
286+
287+
}

vertexai/go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
cloud.google.com/go v0.115.1
77
cloud.google.com/go/aiplatform v1.68.0
88
github.com/google/go-cmp v0.6.0
9+
golang.org/x/text v0.17.0
910
google.golang.org/api v0.196.0
1011
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1
1112
google.golang.org/protobuf v1.34.2
@@ -35,7 +36,6 @@ require (
3536
golang.org/x/oauth2 v0.22.0 // indirect
3637
golang.org/x/sync v0.8.0 // indirect
3738
golang.org/x/sys v0.24.0 // indirect
38-
golang.org/x/text v0.17.0 // indirect
3939
golang.org/x/time v0.6.0 // indirect
4040
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
4141
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect

0 commit comments

Comments
 (0)