diff --git a/docs/docs/modules/data_connection/vector_stores/pgvector.mdx b/docs/docs/modules/data_connection/vector_stores/pgvector.mdx new file mode 100644 index 000000000..091da8e71 --- /dev/null +++ b/docs/docs/modules/data_connection/vector_stores/pgvector.mdx @@ -0,0 +1,84 @@ +--- +sidebar_label: pgvector +sidebar_position: 1 +draft: true +--- + +import CodeBlock from "@theme/CodeBlock"; +import ExamplePGVector from "@examples/pgvector-vectorstore-example/pgvector_vectorstore_example.go"; + +# Getting Started: pgvector + +[PGVector](https://github.com/pgvector/pgvector) is an open-source vector similarity search for Postgres + +PGVector supports: +* exact and approximate nearest neighbor search +* L2 distance, inner product, and cosine distance +* IVFFlat and HNSW index types + +See the [installation instructions](https://github.com/pgvector/pgvector#installation-notes). + +## Usage with Langchain Go + +In code, create an embedder based on an LLM (OpenAI, Ollama, etc.): +```go + llm, _:= openai.New() + emb, _ := embeddings.NewEmbedder(llm) +``` + +For OpenAI embeddings, you will need obtain an API key and provide as an environment variable to the program: + +```bash + export OPENAI_API_KEY=your_openai_api_key_here +``` + +Create a vector store: +```go + ctx := context.Background() + store, err := pgvector.New( + ctx, + pgvector.WithConnectionURL("postgres://testuser:testpass@localhost:5432/testdb?sslmode=disable"), + pgvector.WithEmbedder(emb), + ) +``` + +Document tables will be created automatically. + +Add documents: +```go + _, err = store.AddDocuments(context.Background(), []schema.Document{ + { + PageContent: "Tokyo", + Metadata: map[string]any{ + "population": 38, + "area": 2190, + }, + }, + { + PageContent: "Sao Paulo", + Metadata: map[string]any{ + "population": 22.6, + "area": 1523, + }, + }, + }) +``` + +Run a similarity search using cosine distance (`<=>`): + +```go + filter := map[string]any{"area": "1523"} + + docs, err = store.SimilaritySearch(ctx, "only cities in south america", + 10, + vectorstores.WithScoreThreshold(0.80), + vectorstores.WithFilters(filter), + ) +``` + +For now, pgvector integration only supports simple key-value filters and cosine distance search. + +## Full example + +Here is the entire program (from [pgvector-vectorstore-example](https://github.com/tmc/langchaingo/blob/main/examples/pgvector-vectorstore-example/pgvector_vectorstore_example.go)): +{ExamplePGVector} \ No newline at end of file diff --git a/go.mod b/go.mod index 8783358da..1504d4fb2 100644 --- a/go.mod +++ b/go.mod @@ -13,12 +13,12 @@ require ( ) require ( - cloud.google.com/go v0.110.8 // indirect + cloud.google.com/go v0.111.0 // indirect cloud.google.com/go/ai v0.3.0 // indirect - cloud.google.com/go/compute v1.23.1 // indirect + cloud.google.com/go/compute v1.23.3 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect - cloud.google.com/go/iam v1.1.3 // indirect - cloud.google.com/go/longrunning v0.5.2 // indirect + cloud.google.com/go/iam v1.1.5 // indirect + cloud.google.com/go/longrunning v0.5.4 // indirect dario.cat/mergo v1.0.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Masterminds/goutils v1.1.1 // indirect @@ -122,21 +122,23 @@ require ( go.opencensus.io v0.24.0 // indirect golang.org/x/crypto v0.17.0 // indirect golang.org/x/mod v0.11.0 // indirect - golang.org/x/net v0.17.0 // indirect - golang.org/x/oauth2 v0.13.0 // indirect - golang.org/x/sync v0.4.0 // indirect + golang.org/x/net v0.19.0 // indirect + golang.org/x/oauth2 v0.15.0 // indirect + golang.org/x/sync v0.5.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect + golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.10.0 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20231016165738-49dd2c1f3d0b // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect + google.golang.org/appengine v1.6.8 // indirect + google.golang.org/genproto v0.0.0-20231120223509-83a465c0220f // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20231211222908-989df2bf70f3 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20231211222908-989df2bf70f3 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) require ( - cloud.google.com/go/aiplatform v1.51.1 + cloud.google.com/go/aiplatform v1.58.0 + cloud.google.com/go/vertexai v0.6.0 github.com/Masterminds/sprig/v3 v3.2.3 github.com/PuerkitoBio/goquery v1.8.1 github.com/amikos-tech/chroma-go v0.0.0-20231228181736-e8f5e927093e @@ -162,7 +164,7 @@ require ( gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a go.starlark.net v0.0.0-20230302034142-4b1e35fe2254 golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 - google.golang.org/api v0.149.0 - google.golang.org/grpc v1.59.0 + google.golang.org/api v0.152.0 + google.golang.org/grpc v1.60.0 google.golang.org/protobuf v1.31.0 ) diff --git a/go.sum b/go.sum index fd23199fb..f92f3b497 100644 --- a/go.sum +++ b/go.sum @@ -13,28 +13,28 @@ cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKV cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= -cloud.google.com/go v0.110.8 h1:tyNdfIxjzaWctIiLYOTalaLKZ17SI44SKFW26QbOhME= -cloud.google.com/go v0.110.8/go.mod h1:Iz8AkXJf1qmxC3Oxoep8R1T36w8B92yU29PcBhHO5fk= +cloud.google.com/go v0.111.0 h1:YHLKNupSD1KqjDbQ3+LVdQ81h/UJbJyZG203cEfnQgM= +cloud.google.com/go v0.111.0/go.mod h1:0mibmpKP1TyOOFYQY5izo0LnT+ecvOQ0Sg3OdmMiNRU= cloud.google.com/go/ai v0.3.0 h1:M617N0brv+XFch2KToZUhv6ggzgFZMUnmDkNQjW2pYg= cloud.google.com/go/ai v0.3.0/go.mod h1:dTuQIBA8Kljuas5z1WNot1QZOl476A9TsFqEi6pzJlI= -cloud.google.com/go/aiplatform v1.51.1 h1:g+y03dll9HnX9U0oBKIqUOI+8VQWT1QJF12VGxkal0Q= -cloud.google.com/go/aiplatform v1.51.1/go.mod h1:kY3nIMAVQOK2XDqDPHaOuD9e+FdMA6OOpfBjsvaFSOo= +cloud.google.com/go/aiplatform v1.58.0 h1:xyCAfpI4yUMOQ4VtHN/bdmxPQ8xoEkTwFM1nbVmuQhs= +cloud.google.com/go/aiplatform v1.58.0/go.mod h1:pwZMGvqe0JRkI1GWSZCtnAfrR4K1bv65IHILGA//VEU= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/compute v1.23.1 h1:V97tBoDaZHb6leicZ1G6DLK2BAaZLJ/7+9BB/En3hR0= -cloud.google.com/go/compute v1.23.1/go.mod h1:CqB3xpmPKKt3OJpW2ndFIXnA9A4xAy/F3Xp1ixncW78= +cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiVlk= +cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= -cloud.google.com/go/iam v1.1.3 h1:18tKG7DzydKWUnLjonWcJO6wjSCAtzh4GcRKlH/Hrzc= -cloud.google.com/go/iam v1.1.3/go.mod h1:3khUlaBXfPKKe7huYgEpDn6FtgRyMEqbkvBxrQyY5SE= -cloud.google.com/go/longrunning v0.5.2 h1:u+oFqfEwwU7F9dIELigxbe0XVnBAo9wqMuQLA50CZ5k= -cloud.google.com/go/longrunning v0.5.2/go.mod h1:nqo6DQbNV2pXhGDbDMoN2bWz68MjZUzqv2YttZiveCs= +cloud.google.com/go/iam v1.1.5 h1:1jTsCu4bcsNsE4iiqNT5SHwrDRCfRmIaaaVFhRveTJI= +cloud.google.com/go/iam v1.1.5/go.mod h1:rB6P/Ic3mykPbFio+vo7403drjlgvoWfYpJhMXEbzv8= +cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg= +cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= @@ -44,6 +44,8 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= +cloud.google.com/go/vertexai v0.6.0 h1:f/2hvwTI/MsVKz1IwMBJpPNTgRv9RZFIXF1tDEGyww8= +cloud.google.com/go/vertexai v0.6.0/go.mod h1:aX7eXETSezwz1aSIXc0kljpOfJ420YJBNIXV72HHsqA= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= @@ -832,16 +834,16 @@ golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.13.0 h1:jDDenyj+WgFtmV3zYVoi8aE2BwtXFLWOA67ZfNWftiY= -golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0= +golang.org/x/oauth2 v0.15.0 h1:s8pnnxNVzjWyrvYdFUQq5llS1PX2zhPXmccZv99h7uQ= +golang.org/x/oauth2 v0.15.0/go.mod h1:q48ptWNTY5XWf+JNten23lcvHpLJ0ZSxF5ttTHKVCAM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -855,8 +857,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= -golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -943,6 +945,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= @@ -953,8 +956,8 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1029,16 +1032,16 @@ google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0M google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= -google.golang.org/api v0.149.0 h1:b2CqT6kG+zqJIVKRQ3ELJVLN1PwHZ6DJ3dW8yl82rgY= -google.golang.org/api v0.149.0/go.mod h1:Mwn1B7JTXrzXtnvmzQE2BD6bYZQ8DShKZDZbeN9I7qI= +google.golang.org/api v0.152.0 h1:t0r1vPnfMc260S2Ci+en7kfCZaLOPs5KI0sVV/6jZrY= +google.golang.org/api v0.152.0/go.mod h1:3qNJX5eOmhiWYc67jRA/3GsDw97UFb5ivv7Y2PrriAY= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20180518175338-11a468237815/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1073,12 +1076,12 @@ google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210624195500-8bfb893ecb84/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24= google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= -google.golang.org/genproto v0.0.0-20231016165738-49dd2c1f3d0b h1:+YaDE2r2OG8t/z5qmsh7Y+XXwCbvadxxZ0YY6mTdrVA= -google.golang.org/genproto v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:CgAqfJo+Xmu0GwA0411Ht3OU3OntXwsGmrmjI8ioGXI= -google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b h1:CIC2YMXmIhYw6evmhPxBKJ4fmLbOFtXQN/GV3XOZR8k= -google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:IBQ646DjkDkvUIsVq/cc03FUFQ9wbZu7yE396YcL870= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b h1:ZlWIi1wSK56/8hn4QcBp/j9M7Gt3U/3hZw3mC7vDICo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:swOH3j0KzcDDgGUWr+SNpyTen5YrXjS3eyPzFYKc6lc= +google.golang.org/genproto v0.0.0-20231120223509-83a465c0220f h1:Vn+VyHU5guc9KjB5KrjI2q0wCOWEOIh0OEsleqakHJg= +google.golang.org/genproto v0.0.0-20231120223509-83a465c0220f/go.mod h1:nWSwAFPb+qfNJXsoeO3Io7zf4tMSfN8EA8RlDA04GhY= +google.golang.org/genproto/googleapis/api v0.0.0-20231211222908-989df2bf70f3 h1:EWIeHfGuUf00zrVZGEgYFxok7plSAXBGcH7NNdMAWvA= +google.golang.org/genproto/googleapis/api v0.0.0-20231211222908-989df2bf70f3/go.mod h1:k2dtGpRrbsSyKcNPKKI5sstZkrNCZwpU/ns96JoHbGg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20231211222908-989df2bf70f3 h1:kzJAXnzZoFbe5bhZd4zjUuHos/I31yH4thfMb/13oVY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20231211222908-989df2bf70f3/go.mod h1:eJVxU6o+4G1PSczBr85xmyvSNYAKvAYgkub40YGomFM= google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= @@ -1098,8 +1101,8 @@ google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= -google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= -google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= +google.golang.org/grpc v1.60.0 h1:6FQAR0kM31P6MRdeluor2w2gPaS4SVNrD/DNTxrQ15k= +google.golang.org/grpc v1.60.0/go.mod h1:OlCHIeLYqSSsLi6i49B5QGdzaMZK9+M7LXN2FKz4eGM= google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f h1:rqzndB2lIQGivcXdTuY3Y9NBvr70X+y77woofSRluec= google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f/go.mod h1:gxndsbNG1n4TZcHGgsYEfVGnTxqfEdfiDv6/DADXX9o= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/llms/googleai/download.go b/llms/googleai/download.go index ec2c618c6..60bcce2a7 100644 --- a/llms/googleai/download.go +++ b/llms/googleai/download.go @@ -10,7 +10,7 @@ import ( // downloadImageData downloads the content from the given URL and returns the // image type and data. The image type is the second part of the response's // MIME (e.g. "png" from "image/png"). -func downloadImageData(url string) (string, []byte, error) { +func DownloadImageData(url string) (string, []byte, error) { resp, err := http.Get(url) //nolint if err != nil { return "", nil, fmt.Errorf("failed to fetch image from url: %w", err) diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go index eb8161025..2a9099cb8 100644 --- a/llms/googleai/googleai_llm.go +++ b/llms/googleai/googleai_llm.go @@ -178,7 +178,7 @@ func convertParts(parts []llms.ContentPart) ([]genai.Part, error) { case llms.BinaryContent: out = genai.Blob{MIMEType: p.MIMEType, Data: p.Data} case llms.ImageURLContent: - typ, data, err := downloadImageData(p.URL) + typ, data, err := DownloadImageData(p.URL) if err != nil { return nil, err } diff --git a/llms/googleai/vertex/new.go b/llms/googleai/vertex/new.go new file mode 100644 index 000000000..cf85a1ff9 --- /dev/null +++ b/llms/googleai/vertex/new.go @@ -0,0 +1,42 @@ +package vertex + +import ( + "context" + "log" + + "cloud.google.com/go/vertexai/genai" + "github.com/tmc/langchaingo/callbacks" + "github.com/tmc/langchaingo/llms" +) + +// Vertex is a type that represents a Vertex AI API client. +// +// TODO: This isn't in common code; may need PaLM client for embeddings, etc. +// Note the deltas: type of topk, candidate count. +type Vertex struct { + CallbacksHandler callbacks.Handler + client *genai.Client + opts options +} + +var _ llms.Model = &Vertex{} + +// NewVertex creates a new Vertex struct. +func NewVertex(ctx context.Context, opts ...Option) (*Vertex, error) { + clientOptions := defaultOptions() + for _, opt := range opts { + opt(&clientOptions) + } + + v := &Vertex{ + opts: clientOptions, + } + + client, err := genai.NewClient(ctx, clientOptions.cloudProject, clientOptions.cloudLocation) + if err != nil { + log.Fatal(err) + } + + v.client = client + return v, nil +} diff --git a/llms/googleai/vertex/option.go b/llms/googleai/vertex/option.go new file mode 100644 index 000000000..d6f7f5854 --- /dev/null +++ b/llms/googleai/vertex/option.go @@ -0,0 +1,60 @@ +package vertex + +// options is a set of options for GoogleAI clients. +type options struct { + cloudProject string + cloudLocation string + defaultModel string + defaultEmbeddingModel string + defaultCandidateCount int + defaultMaxTokens int + defaultTemperature float64 + defaultTopK int + defaultTopP float64 +} + +func defaultOptions() options { + return options{ + cloudProject: "", + cloudLocation: "", + defaultModel: "gemini-pro", + defaultEmbeddingModel: "embedding-001", + defaultCandidateCount: 1, + defaultMaxTokens: 256, + defaultTemperature: 0.5, + defaultTopK: 3, + defaultTopP: 0.95, + } +} + +type Option func(*options) + +// WithCloudProject passes the GCP cloud project name to the client. +func WithCloudProject(p string) Option { + return func(opts *options) { + opts.cloudProject = p + } +} + +// WithCloudLocation passes the GCP cloud location (region) name to the client. +func WithCloudLocation(l string) Option { + return func(opts *options) { + opts.cloudLocation = l + } +} + +// WithDefaultModel passes a default content model name to the client. This +// model name is used if not explicitly provided in specific client invocations. +func WithDefaultModel(defaultModel string) Option { + return func(opts *options) { + opts.defaultModel = defaultModel + } +} + +// WithDefaultModel passes a default embedding model name to the client. This +// model name is used if not explicitly provided in specific client invocations. +func WithDefaultEmbeddingModel(defaultEmbeddingModel string) Option { + return func(opts *options) { + opts.defaultEmbeddingModel = defaultEmbeddingModel + } +} diff --git a/llms/googleai/vertex/vertex.go b/llms/googleai/vertex/vertex.go new file mode 100644 index 000000000..96ffe1312 --- /dev/null +++ b/llms/googleai/vertex/vertex.go @@ -0,0 +1,284 @@ +// DO NOT EDIT: this code is auto-generated from llms/googleai/googleai_llm.go +package vertex + +import ( + "context" + "errors" + "fmt" + "log" + "strings" + + "cloud.google.com/go/vertexai/genai" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/googleai" + "github.com/tmc/langchaingo/schema" + "google.golang.org/api/iterator" +) + +var ( + ErrNoContentInResponse = errors.New("no content in generation response") + ErrUnknownPartInResponse = errors.New("unknown part type in generation response") + ErrInvalidMimeType = errors.New("invalid mime type on content") + ErrSystemRoleNotSupported = errors.New("system role isn't supporeted yet") +) + +const ( + CITATIONS = "citations" + SAFETY = "safety" + RoleModel = "model" + RoleUser = "user" +) + +// Call implements the [llms.Model] interface. +func (g *Vertex) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { + return llms.GenerateFromSinglePrompt(ctx, g, prompt, options...) +} + +// GenerateContent implements the [llms.Model] interface. +func (g *Vertex) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if g.CallbacksHandler != nil { + g.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := llms.CallOptions{ + Model: g.opts.defaultModel, + CandidateCount: g.opts.defaultCandidateCount, + MaxTokens: g.opts.defaultMaxTokens, + Temperature: g.opts.defaultTemperature, + TopP: g.opts.defaultTopP, + TopK: g.opts.defaultTopK, + } + for _, opt := range options { + opt(&opts) + } + + model := g.client.GenerativeModel(opts.Model) + model.SetCandidateCount(int32(opts.CandidateCount)) + model.SetMaxOutputTokens(int32(opts.MaxTokens)) + model.SetTemperature(float32(opts.Temperature)) + model.SetTopP(float32(opts.TopP)) + model.SetTopK(float32(opts.TopK)) + model.StopSequences = opts.StopWords + + var response *llms.ContentResponse + var err error + + if len(messages) == 1 { + theMessage := messages[0] + if theMessage.Role != schema.ChatMessageTypeHuman { + return nil, fmt.Errorf("got %v message role, want human", theMessage.Role) + } + response, err = generateFromSingleMessage(ctx, model, theMessage.Parts, &opts) + } else { + response, err = generateFromMessages(ctx, model, messages, &opts) + } + if err != nil { + return nil, err + } + + if g.CallbacksHandler != nil { + g.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) + } + + return response, nil +} + +// convertCandidates converts a sequence of genai.Candidate to a response. +func convertCandidates(candidates []*genai.Candidate) (*llms.ContentResponse, error) { + var contentResponse llms.ContentResponse + + for _, candidate := range candidates { + buf := strings.Builder{} + + if candidate.Content != nil { + for _, part := range candidate.Content.Parts { + if v, ok := part.(genai.Text); ok { + _, err := buf.WriteString(string(v)) + if err != nil { + return nil, err + } + } else { + return nil, ErrUnknownPartInResponse + } + } + } + + metadata := make(map[string]any) + metadata[CITATIONS] = candidate.CitationMetadata + metadata[SAFETY] = candidate.SafetyRatings + + contentResponse.Choices = append(contentResponse.Choices, + &llms.ContentChoice{ + Content: buf.String(), + StopReason: candidate.FinishReason.String(), + GenerationInfo: metadata, + }) + } + return &contentResponse, nil +} + +// CreateEmbedding creates embeddings from texts. +func (g *Vertex) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) { + panic("not implemented") +} + +// convertParts converts between a sequence of langchain parts and genai parts. +func convertParts(parts []llms.ContentPart) ([]genai.Part, error) { + convertedParts := make([]genai.Part, 0, len(parts)) + for _, part := range parts { + var out genai.Part + + switch p := part.(type) { + case llms.TextContent: + out = genai.Text(p.Text) + case llms.BinaryContent: + out = genai.Blob{MIMEType: p.MIMEType, Data: p.Data} + case llms.ImageURLContent: + typ, data, err := googleai.DownloadImageData(p.URL) + if err != nil { + return nil, err + } + out = genai.ImageData(typ, data) + } + + convertedParts = append(convertedParts, out) + } + return convertedParts, nil +} + +// convertContent converts between a langchain MessageContent and genai content. +func convertContent(content llms.MessageContent) (*genai.Content, error) { + parts, err := convertParts(content.Parts) + if err != nil { + return nil, err + } + + c := &genai.Content{ + Parts: parts, + } + + switch content.Role { + case schema.ChatMessageTypeSystem: + return nil, ErrSystemRoleNotSupported + case schema.ChatMessageTypeAI: + c.Role = RoleModel + case schema.ChatMessageTypeHuman: + c.Role = RoleUser + case schema.ChatMessageTypeGeneric: + c.Role = RoleUser + case schema.ChatMessageTypeFunction: + fallthrough + default: + return nil, fmt.Errorf("role %v not supported", content.Role) + } + + return c, nil +} + +// generateFromSingleMessage generates content from the parts of a single +// message. +func generateFromSingleMessage(ctx context.Context, model *genai.GenerativeModel, parts []llms.ContentPart, opts *llms.CallOptions) (*llms.ContentResponse, error) { + convertedParts, err := convertParts(parts) + if err != nil { + return nil, err + } + + if opts.StreamingFunc == nil { + // When no streaming is requested, just call GenerateContent and return + // the complete response with a list of candidates. + resp, err := model.GenerateContent(ctx, convertedParts...) + if err != nil { + return nil, err + } + + if len(resp.Candidates) == 0 { + return nil, ErrNoContentInResponse + } + return convertCandidates(resp.Candidates) + } + iter := model.GenerateContentStream(ctx, convertedParts...) + return convertAndStreamFromIterator(ctx, iter, opts) +} + +func generateFromMessages(ctx context.Context, model *genai.GenerativeModel, messages []llms.MessageContent, opts *llms.CallOptions) (*llms.ContentResponse, error) { + history := make([]*genai.Content, 0, len(messages)) + for _, mc := range messages { + content, err := convertContent(mc) + if err != nil { + return nil, err + } + history = append(history, content) + } + + // Given N total messages, genai's chat expects the first N-1 messages as + // history and the last message as the actual request. + n := len(history) + reqContent := history[n-1] + history = history[:n-1] + + if reqContent.Role != RoleUser { + return nil, fmt.Errorf("got %v message role, want user/human", reqContent.Role) + } + + session := model.StartChat() + session.History = history + + if opts.StreamingFunc == nil { + resp, err := session.SendMessage(ctx, reqContent.Parts...) + if err != nil { + return nil, err + } + + if len(resp.Candidates) == 0 { + return nil, ErrNoContentInResponse + } + return convertCandidates(resp.Candidates) + } + iter := session.SendMessageStream(ctx, reqContent.Parts...) + return convertAndStreamFromIterator(ctx, iter, opts) +} + +// convertAndStreamFromIterator takes an iterator of GenerateContentResponse +// and produces a llms.ContentResponse reply from it, while streaming the +// resulting text into the opts-provided streaming function. +// Note that this is tricky in the face of multiple +// candidates, so this code assumes only a single candidate for now. +func convertAndStreamFromIterator(ctx context.Context, iter *genai.GenerateContentResponseIterator, opts *llms.CallOptions) (*llms.ContentResponse, error) { + candidate := &genai.Candidate{ + Content: &genai.Content{}, + } +DoStream: + for { + resp, err := iter.Next() + if errors.Is(err, iterator.Done) { + break DoStream + } + if err != nil { + log.Fatal(err) + } + + if len(resp.Candidates) != 1 { + return nil, fmt.Errorf("expect single candidate in stream mode; got %v", len(resp.Candidates)) + } + respCandidate := resp.Candidates[0] + + if respCandidate.Content == nil { + break DoStream + } + candidate.Content.Parts = append(candidate.Content.Parts, respCandidate.Content.Parts...) + candidate.Content.Role = respCandidate.Content.Role + candidate.FinishReason = respCandidate.FinishReason + candidate.SafetyRatings = respCandidate.SafetyRatings + candidate.CitationMetadata = respCandidate.CitationMetadata + + for _, part := range respCandidate.Content.Parts { + if text, ok := part.(genai.Text); ok { + if opts.StreamingFunc(ctx, []byte(text)) != nil { + break DoStream + } + } + } + } + + return convertCandidates([]*genai.Candidate{candidate}) +} diff --git a/llms/googleai/vertex/vertex_test.go b/llms/googleai/vertex/vertex_test.go new file mode 100644 index 000000000..336c04a1d --- /dev/null +++ b/llms/googleai/vertex/vertex_test.go @@ -0,0 +1,91 @@ +package vertex + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/schema" +) + +func newClient(t *testing.T) *Vertex { + t.Helper() + + project := os.Getenv("VERTEX_PROJECT") + if project == "" { + t.Skip("VERTEX_PROJECT not set") + return nil + } + location := os.Getenv("VERTEX_LOCATION") + if location == "" { + location = "us-central1" + } + + llm, err := NewVertex(context.Background(), WithCloudProject(project), WithCloudLocation(location)) + require.NoError(t, err) + return llm +} + +func TestMultiContentText(t *testing.T) { + t.Parallel() + llm := newClient(t) + + parts := []llms.ContentPart{ + llms.TextContent{Text: "I'm a pomeranian"}, + llms.TextContent{Text: "What kind of mammal am I?"}, + } + content := []llms.MessageContent{ + { + Role: schema.ChatMessageTypeHuman, + Parts: parts, + }, + } + + rsp, err := llm.GenerateContent(context.Background(), content) + require.NoError(t, err) + + assert.NotEmpty(t, rsp.Choices) + c1 := rsp.Choices[0] + assert.Regexp(t, "(?i)dog|canid|canine", c1.Content) +} + +func TestMultiContentTextStream(t *testing.T) { + t.Parallel() + llm := newClient(t) + + parts := []llms.ContentPart{ + llms.TextContent{Text: "I'm a pomeranian"}, + llms.TextContent{Text: "Tell me more about my taxonomy"}, + } + content := []llms.MessageContent{ + { + Role: schema.ChatMessageTypeHuman, + Parts: parts, + }, + } + + var chunks [][]byte + var sb strings.Builder + rsp, err := llm.GenerateContent(context.Background(), content, + llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error { + chunks = append(chunks, chunk) + sb.Write(chunk) + return nil + })) + + require.NoError(t, err) + + assert.NotEmpty(t, rsp.Choices) + // Check that the combined response contains what we expect + c1 := rsp.Choices[0] + assert.Regexp(t, "(?i)dog|canid|canine", c1.Content) + + // Check that multiple chunks were received and they also have words + // we expect. + assert.GreaterOrEqual(t, len(chunks), 2) + assert.Regexp(t, "(?i)dog|canid|canine", sb.String()) +}