forked from tmc/langchaingo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
googleai: add initial Vertex (GCP) implementation of Model (tmc#540)
* googleai: add initial Vertex (GCP) implementation of Model re tmc#410
- Loading branch information
Showing
8 changed files
with
541 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package vertex | ||
|
||
import ( | ||
"context" | ||
"log" | ||
|
||
"cloud.google.com/go/vertexai/genai" | ||
"github.com/tmc/langchaingo/callbacks" | ||
"github.com/tmc/langchaingo/llms" | ||
) | ||
|
||
// Vertex is a type that represents a Vertex AI API client. | ||
// | ||
// TODO: This isn't in common code; may need PaLM client for embeddings, etc. | ||
// Note the deltas: type of topk, candidate count. | ||
type Vertex struct { | ||
CallbacksHandler callbacks.Handler | ||
client *genai.Client | ||
opts options | ||
} | ||
|
||
var _ llms.Model = &Vertex{} | ||
|
||
// NewVertex creates a new Vertex struct. | ||
func NewVertex(ctx context.Context, opts ...Option) (*Vertex, error) { | ||
clientOptions := defaultOptions() | ||
for _, opt := range opts { | ||
opt(&clientOptions) | ||
} | ||
|
||
v := &Vertex{ | ||
opts: clientOptions, | ||
} | ||
|
||
client, err := genai.NewClient(ctx, clientOptions.cloudProject, clientOptions.cloudLocation) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
v.client = client | ||
return v, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package vertex | ||
|
||
// options is a set of options for GoogleAI clients. | ||
type options struct { | ||
cloudProject string | ||
cloudLocation string | ||
defaultModel string | ||
defaultEmbeddingModel string | ||
defaultCandidateCount int | ||
defaultMaxTokens int | ||
defaultTemperature float64 | ||
defaultTopK int | ||
defaultTopP float64 | ||
} | ||
|
||
func defaultOptions() options { | ||
return options{ | ||
cloudProject: "", | ||
cloudLocation: "", | ||
defaultModel: "gemini-pro", | ||
defaultEmbeddingModel: "embedding-001", | ||
defaultCandidateCount: 1, | ||
defaultMaxTokens: 256, | ||
defaultTemperature: 0.5, | ||
defaultTopK: 3, | ||
defaultTopP: 0.95, | ||
} | ||
} | ||
|
||
type Option func(*options) | ||
|
||
// WithCloudProject passes the GCP cloud project name to the client. | ||
func WithCloudProject(p string) Option { | ||
return func(opts *options) { | ||
opts.cloudProject = p | ||
} | ||
} | ||
|
||
// WithCloudLocation passes the GCP cloud location (region) name to the client. | ||
func WithCloudLocation(l string) Option { | ||
return func(opts *options) { | ||
opts.cloudLocation = l | ||
} | ||
} | ||
|
||
// WithDefaultModel passes a default content model name to the client. This | ||
// model name is used if not explicitly provided in specific client invocations. | ||
func WithDefaultModel(defaultModel string) Option { | ||
return func(opts *options) { | ||
opts.defaultModel = defaultModel | ||
} | ||
} | ||
|
||
// WithDefaultModel passes a default embedding model name to the client. This | ||
// model name is used if not explicitly provided in specific client invocations. | ||
func WithDefaultEmbeddingModel(defaultEmbeddingModel string) Option { | ||
return func(opts *options) { | ||
opts.defaultEmbeddingModel = defaultEmbeddingModel | ||
} | ||
} |
Oops, something went wrong.