Skip to content

Commit

Permalink
googleai: add initial Vertex (GCP) implementation of Model (tmc#540)
Browse files Browse the repository at this point in the history
* googleai: add initial Vertex (GCP) implementation of Model

re tmc#410
  • Loading branch information
eliben authored Jan 22, 2024
1 parent c3c9d57 commit 029ff8e
Show file tree
Hide file tree
Showing 8 changed files with 541 additions and 45 deletions.
36 changes: 21 additions & 15 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@ require (
)

require (
cloud.google.com/go v0.110.8 // indirect
cloud.google.com/go v0.111.0 // indirect
cloud.google.com/go/ai v0.3.0 // indirect
cloud.google.com/go/compute v1.23.1 // indirect
cloud.google.com/go/compute v1.23.3 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v1.1.3 // indirect
cloud.google.com/go/longrunning v0.5.2 // indirect
cloud.google.com/go/iam v1.1.5 // indirect
cloud.google.com/go/longrunning v0.5.4 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.2.0 // indirect
github.com/PuerkitoBio/purell v1.1.1 // indirect
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
github.com/alecthomas/colour v0.1.0 // indirect
github.com/alecthomas/repr v0.0.0-20210801044451-80ca428c5142 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/antchfx/htmlquery v1.3.0 // indirect
github.com/antchfx/xmlquery v1.3.17 // indirect
Expand Down Expand Up @@ -62,6 +64,7 @@ require (
github.com/kr/pretty v0.3.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2 // indirect
github.com/mitchellh/copystructure v1.0.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
Expand All @@ -74,6 +77,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
github.com/sergi/go-diff v1.2.0 // indirect
github.com/shopspring/decimal v1.2.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/cast v1.3.1 // indirect
Expand All @@ -89,22 +93,25 @@ require (
go.mongodb.org/mongo-driver v1.11.3 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/oauth2 v0.13.0 // indirect
golang.org/x/sync v0.4.0 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/oauth2 v0.15.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20231016165738-49dd2c1f3d0b // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20231120223509-83a465c0220f // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20231211222908-989df2bf70f3 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231211222908-989df2bf70f3 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

require (
cloud.google.com/go/aiplatform v1.51.1
cloud.google.com/go/aiplatform v1.58.0
cloud.google.com/go/vertexai v0.6.0
github.com/Masterminds/sprig/v3 v3.2.3
github.com/PuerkitoBio/goquery v1.8.1
github.com/alecthomas/assert v1.0.0
github.com/amikos-tech/chroma-go v0.0.0-20231228181736-e8f5e927093e
github.com/cohere-ai/tokenizer v1.1.2
github.com/go-openapi/strfmt v0.21.3
Expand All @@ -113,7 +120,6 @@ require (
github.com/google/generative-ai-go v0.5.0
github.com/google/go-cmp v0.6.0
github.com/jackc/pgx/v5 v5.4.1
github.com/joho/godotenv v1.5.1
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80
github.com/mattn/go-sqlite3 v1.14.17
github.com/metaphorsystems/metaphor-go v0.0.0-20230816231421-43794c04824e
Expand All @@ -129,7 +135,7 @@ require (
gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
google.golang.org/api v0.149.0
google.golang.org/grpc v1.59.0
google.golang.org/api v0.152.0
google.golang.org/grpc v1.60.0
google.golang.org/protobuf v1.31.0
)
69 changes: 41 additions & 28 deletions go.sum

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion llms/googleai/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
// downloadImageData downloads the content from the given URL and returns the
// image type and data. The image type is the second part of the response's
// MIME (e.g. "png" from "image/png").
func downloadImageData(url string) (string, []byte, error) {
func DownloadImageData(url string) (string, []byte, error) {
resp, err := http.Get(url) //nolint
if err != nil {
return "", nil, fmt.Errorf("failed to fetch image from url: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion llms/googleai/googleai_llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func convertParts(parts []llms.ContentPart) ([]genai.Part, error) {
case llms.BinaryContent:
out = genai.Blob{MIMEType: p.MIMEType, Data: p.Data}
case llms.ImageURLContent:
typ, data, err := downloadImageData(p.URL)
typ, data, err := DownloadImageData(p.URL)
if err != nil {
return nil, err
}
Expand Down
42 changes: 42 additions & 0 deletions llms/googleai/vertex/new.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package vertex

import (
"context"
"log"

"cloud.google.com/go/vertexai/genai"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
)

// Vertex is a type that represents a Vertex AI API client.
//
// TODO: This isn't in common code; may need PaLM client for embeddings, etc.
// Note the deltas: type of topk, candidate count.
type Vertex struct {
CallbacksHandler callbacks.Handler
client *genai.Client
opts options
}

var _ llms.Model = &Vertex{}

// NewVertex creates a new Vertex struct.
func NewVertex(ctx context.Context, opts ...Option) (*Vertex, error) {
clientOptions := defaultOptions()
for _, opt := range opts {
opt(&clientOptions)
}

v := &Vertex{
opts: clientOptions,
}

client, err := genai.NewClient(ctx, clientOptions.cloudProject, clientOptions.cloudLocation)
if err != nil {
log.Fatal(err)
}

v.client = client
return v, nil
}
60 changes: 60 additions & 0 deletions llms/googleai/vertex/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package vertex

// options is a set of options for GoogleAI clients.
type options struct {
cloudProject string
cloudLocation string
defaultModel string
defaultEmbeddingModel string
defaultCandidateCount int
defaultMaxTokens int
defaultTemperature float64
defaultTopK int
defaultTopP float64
}

func defaultOptions() options {
return options{
cloudProject: "",
cloudLocation: "",
defaultModel: "gemini-pro",
defaultEmbeddingModel: "embedding-001",
defaultCandidateCount: 1,
defaultMaxTokens: 256,
defaultTemperature: 0.5,
defaultTopK: 3,
defaultTopP: 0.95,
}
}

type Option func(*options)

// WithCloudProject passes the GCP cloud project name to the client.
func WithCloudProject(p string) Option {
return func(opts *options) {
opts.cloudProject = p
}
}

// WithCloudLocation passes the GCP cloud location (region) name to the client.
func WithCloudLocation(l string) Option {
return func(opts *options) {
opts.cloudLocation = l
}
}

// WithDefaultModel passes a default content model name to the client. This
// model name is used if not explicitly provided in specific client invocations.
func WithDefaultModel(defaultModel string) Option {
return func(opts *options) {
opts.defaultModel = defaultModel
}
}

// WithDefaultModel passes a default embedding model name to the client. This
// model name is used if not explicitly provided in specific client invocations.
func WithDefaultEmbeddingModel(defaultEmbeddingModel string) Option {
return func(opts *options) {
opts.defaultEmbeddingModel = defaultEmbeddingModel
}
}
Loading

0 comments on commit 029ff8e

Please sign in to comment.