From c6d1cc4a9747d4a04a58b2534e1af10536541458 Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Fri, 28 Jun 2024 17:22:28 +0200 Subject: [PATCH] Fixes #4091: Update RAG docs with vector db examples (#4116) --- .../database-integration/vectordb/chroma.adoc | 11 +++- .../database-integration/vectordb/milvus.adoc | 10 +++- .../vectordb/pinecone.adoc | 9 +++- .../database-integration/vectordb/qdrant.adoc | 9 +++- .../vectordb/weaviate.adoc | 9 +++- .../test/java/apoc/vectordb/ChromaDbTest.java | 47 +++++++++++++++-- .../test/java/apoc/vectordb/MilvusTest.java | 43 ++++++++++++++-- .../test/java/apoc/vectordb/QdrantTest.java | 45 ++++++++++++++-- .../test/java/apoc/vectordb/WeaviateTest.java | 51 +++++++++++++++++-- .../test/java/apoc/vectordb/PineconeTest.java | 35 ++++++++++++- .../java/apoc/vectordb/VectorDbTestUtil.java | 14 +++++ 11 files changed, 263 insertions(+), 20 deletions(-) diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc index f773e83976..8e064db1d6 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc @@ -1,4 +1,3 @@ - == ChromaDB Here is a list of all available ChromaDB procedures, @@ -221,6 +220,16 @@ For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: + +[source,cypher] +---- +CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, [, ], $conf) YIELD node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- + .Delete vectors (it leverages https://docs.trychroma.com/usage-guide#deleting-data-from-a-collection[this API]) [source,cypher] ---- diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc index 02f3bd7db5..6bbe4b9ac9 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/milvus.adoc @@ -1,4 +1,3 @@ - == Milvus Here is a list of all available Milvus procedures: @@ -223,6 +222,15 @@ For example, by executing a `CALL apoc.vectordb.milvus.query(...) YIELD metadata so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: + +[source,cypher] +---- +CALL apoc.vectordb.milvus.getAndUpdate($host, $collection, [, ], $conf) YIELD node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- .Delete vectors (it leverages https://milvus.io/api-reference/restful/v2.4.x/v2/Vector%20(v2)/Delete.md[this API]) diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc index d397507319..2b45ac7530 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/pinecone.adoc @@ -1,4 +1,3 @@ - == Pinecone Here is a list of all available Pinecone procedures: @@ -237,7 +236,15 @@ For example, by executing a `CALL apoc.vectordb.pinecone.query(...) YIELD metada so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: +[source,cypher] +---- +CALL apoc.vectordb.pinecone.getAndUpdate($host, $collection, [, ], $conf) YIELD node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- .Delete vectors (it leverages https://docs.pinecone.io/reference/api/data-plane/delete[this API]) [source,cypher] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc index c604766abf..e21917f9f7 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc @@ -1,4 +1,3 @@ - == Qdrant Here is a list of all available Qdrant procedures, @@ -224,7 +223,15 @@ For example, by executing a `CALL apoc.vectordb.qdrant.query(...) YIELD metadata so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: +[source,cypher] +---- +CALL apoc.vectordb.qdrant.getAndUpdate($host, $collection, [, ], $conf) YIELD node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- .Delete vectors (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/points/operation/delete_vectors[this API]) [source,cypher] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc index 5064a28975..d54d61e401 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc @@ -1,4 +1,3 @@ - == Weaviate Here is a list of all available Weaviate procedures, @@ -240,7 +239,15 @@ For example, by executing a `CALL apoc.vectordb.weaviate.query(...) YIELD metada so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: +[source,cypher] +---- +CALL apoc.vectordb.weaviate.getAndUpdate($host, $collection, [, ], $conf) YIELD score, node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- .Delete vectors (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/objects/delete/objects/\{className\}/\{id\}[this API]) [source,cypher] diff --git a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java index 7904f12dba..3e90b52883 100644 --- a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java +++ b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java @@ -1,5 +1,6 @@ package apoc.vectordb; +import apoc.ml.Prompt; import apoc.util.TestUtil; import org.junit.AfterClass; import org.junit.Before; @@ -16,30 +17,39 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicReference; +import static apoc.ml.Prompt.API_KEY_CONF; +import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; import static apoc.vectordb.VectorDbHandler.Type.CHROMA; +import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.REL; import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertRagWithVectors; import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; -import static apoc.vectordb.VectorDbTestUtil.EntityType.*; +import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import static org.junit.Assert.fail; +import static org.junit.Assert.assertTrue; import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; public class ChromaDbTest { private static final AtomicReference COLL_ID = new AtomicReference<>(); private static final ChromaDBContainer CHROMA_CONTAINER = new ChromaDBContainer("chromadb/chroma:0.4.25.dev137"); + private static final String READONLY_KEY = "my_readonly_api_key"; + private static final Map READONLY_AUTHORIZATION = getAuthHeader(READONLY_KEY); private static String HOST; @@ -60,7 +70,7 @@ public static void setUp() throws Exception { CHROMA_CONTAINER.start(); HOST = "localhost:" + CHROMA_CONTAINER.getMappedPort(8000); - TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class); + TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class, Prompt.class); testCall(db, "CALL apoc.vectordb.chroma.createCollection($host, 'test_collection', 'cosine', 4)", map("host", HOST), @@ -123,7 +133,7 @@ public void getVectorsWithoutVectorResult() { assertNull(row.get("id")); }); } - + @Test public void deleteVector() { testCall(db, """ @@ -421,4 +431,33 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } + + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map(NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + """ + CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector + WITH collect(node) as paths + CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value + RETURN value + """ + , + map( + "host", HOST, + "conf", conf, + "collection", COLL_ID.get(), + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + VectorDbTestUtil::assertRagWithVectors); + } } diff --git a/extended-it/src/test/java/apoc/vectordb/MilvusTest.java b/extended-it/src/test/java/apoc/vectordb/MilvusTest.java index 8ab4d342cb..2597c6cd7f 100644 --- a/extended-it/src/test/java/apoc/vectordb/MilvusTest.java +++ b/extended-it/src/test/java/apoc/vectordb/MilvusTest.java @@ -1,5 +1,6 @@ package apoc.vectordb; +import apoc.ml.Prompt; import apoc.util.TestUtil; import apoc.util.Util; import org.junit.AfterClass; @@ -16,6 +17,8 @@ import java.util.List; import java.util.Map; +import static apoc.ml.Prompt.API_KEY_CONF; +import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -29,16 +32,18 @@ import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; -import static apoc.vectordb.VectorMappingConfig.MODE_KEY; import static apoc.vectordb.VectorMappingConfig.EMBEDDING_KEY; import static apoc.vectordb.VectorMappingConfig.ENTITY_KEY; import static apoc.vectordb.VectorMappingConfig.METADATA_KEY; +import static apoc.vectordb.VectorMappingConfig.MODE_KEY; +import static apoc.vectordb.VectorMappingConfig.MappingMode; import static apoc.vectordb.VectorMappingConfig.NODE_LABEL; import static apoc.vectordb.VectorMappingConfig.REL_TYPE; -import static apoc.vectordb.VectorMappingConfig.MappingMode; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -49,6 +54,8 @@ public class MilvusTest { private static final List FIELDS = List.of("city", "foo"); private static final MilvusContainer MILVUS_CONTAINER = new MilvusContainer("milvusdb/milvus:v2.4.0"); + private static final String READONLY_KEY = "my_readonly_api_key"; + private static final Map READONLY_AUTHORIZATION = getAuthHeader(READONLY_KEY); private static String HOST; @@ -69,7 +76,7 @@ public static void setUp() throws Exception { MILVUS_CONTAINER.start(); HOST = MILVUS_CONTAINER.getEndpoint(); - TestUtil.registerProcedure(db, Milvus.class, VectorDb.class); + TestUtil.registerProcedure(db, Milvus.class, VectorDb.class, Prompt.class); testCall(db, "CALL apoc.vectordb.milvus.createCollection($host, 'test_collection', 'COSINE', 4)", map("host", HOST), @@ -429,4 +436,34 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = map( + FIELDS_KEY, FIELDS, + ALL_RESULTS_KEY, true, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map(NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + """ + CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector + WITH collect(node) as paths + CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value + RETURN value + """ + , + map( + "host", HOST, + "conf", conf, + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + VectorDbTestUtil::assertRagWithVectors); + } + } diff --git a/extended-it/src/test/java/apoc/vectordb/QdrantTest.java b/extended-it/src/test/java/apoc/vectordb/QdrantTest.java index 5d1d4c01cc..1ff1bec6da 100644 --- a/extended-it/src/test/java/apoc/vectordb/QdrantTest.java +++ b/extended-it/src/test/java/apoc/vectordb/QdrantTest.java @@ -3,6 +3,7 @@ import apoc.util.TestUtil; import apoc.util.Util; import org.junit.AfterClass; +import org.junit.Assume; import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -13,9 +14,12 @@ import org.neo4j.test.TestDatabaseManagementServiceBuilder; import org.testcontainers.qdrant.QdrantContainer; +import java.util.List; import java.util.Map; +import apoc.ml.Prompt; import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -35,9 +39,9 @@ import static apoc.vectordb.VectorMappingConfig.*; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.neo4j.configuration.GraphDatabaseSettings.*; @@ -72,7 +76,7 @@ public static void setUp() throws Exception { QDRANT_CONTAINER.start(); HOST = "localhost:" + QDRANT_CONTAINER.getMappedPort(6333); - TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class); + TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class, Prompt.class); testCall(db, "CALL apoc.vectordb.qdrant.createCollection($host, 'test_collection', 'Cosine', 4, $conf)", map("host", HOST, "conf", ADMIN_HEADER_CONF), @@ -177,6 +181,41 @@ public void deleteVector() { }); } + @Test + public void queryVectorsWithRag() { + String openAIKey = System.getenv("OPENAI_KEY");; + Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey); + + db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})"); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map(NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + """ + CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector + WITH collect(node) as paths + CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value + RETURN value + """ + , + map( + "host", HOST, + "conf", conf, + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + r -> { + Map row = r.next(); + Object value = row.get("value"); + assertTrue("The actual value is: " + value, value.toString().contains("Berlin")); + }); + } + @Test public void queryVectors() { testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", @@ -463,5 +502,5 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } - + } diff --git a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java index 4461960ba7..b76869d003 100644 --- a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java +++ b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java @@ -1,8 +1,10 @@ package apoc.vectordb; +import apoc.ml.Prompt; import apoc.util.MapUtil; import apoc.util.TestUtil; import org.junit.AfterClass; +import org.junit.Assume; import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -18,14 +20,24 @@ import java.util.List; import java.util.Map; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testCallEmpty; import static apoc.util.TestUtil.testResult; import static apoc.util.Util.map; import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE; -import static apoc.vectordb.VectorDbTestUtil.*; -import static apoc.vectordb.VectorDbTestUtil.EntityType.*; +import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.REL; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; +import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; +import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; @@ -36,7 +48,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; - import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; @@ -81,7 +92,7 @@ public static void setUp() throws Exception { WEAVIATE_CONTAINER.start(); HOST = WEAVIATE_CONTAINER.getHttpHostAddress(); - TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class); + TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class, Prompt.class); testCall(db, "CALL apoc.vectordb.weaviate.createCollection($host, 'TestCollection', 'cosine', 4, $conf)", MapUtil.map("host", HOST, "conf", ADMIN_HEADER_CONF), @@ -496,4 +507,36 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } + + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = MapUtil.map( + FIELDS_KEY, FIELDS, + ALL_RESULTS_KEY, true, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + """ + CALL apoc.vectordb.weaviate.getAndUpdate($host, 'TestCollection', [$id1], $conf) YIELD score, node, metadata, id, vector + WITH collect(node) as paths + CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value + RETURN value + """ + , + MapUtil.map( + "host", HOST, + "id1", ID_1, + "conf", conf, + "confPrompt", MapUtil.map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + VectorDbTestUtil::assertRagWithVectors); + } } diff --git a/extended/src/test/java/apoc/vectordb/PineconeTest.java b/extended/src/test/java/apoc/vectordb/PineconeTest.java index ead9db3e53..ff9d9e2abb 100644 --- a/extended/src/test/java/apoc/vectordb/PineconeTest.java +++ b/extended/src/test/java/apoc/vectordb/PineconeTest.java @@ -1,5 +1,6 @@ package apoc.vectordb; +import apoc.ml.Prompt; import apoc.util.MapUtil; import apoc.util.TestUtil; import apoc.util.Util; @@ -13,8 +14,10 @@ import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.test.TestDatabaseManagementServiceBuilder; +import java.util.List; import java.util.Map; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; @@ -31,6 +34,7 @@ import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; @@ -67,7 +71,7 @@ public static void setUp() { db = databaseManagementService.database(DEFAULT_DATABASE_NAME); sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); - TestUtil.registerProcedure(db, VectorDb.class, Pinecone.class); + TestUtil.registerProcedure(db, VectorDb.class, Pinecone.class, Prompt.class); ADMIN_AUTHORIZATION = map("Api-Key", API_KEY); ADMIN_HEADER_CONF = map(HEADERS_KEY, ADMIN_AUTHORIZATION); @@ -466,4 +470,33 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } + + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, ADMIN_AUTHORIZATION, + MAPPING_KEY, map(NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + """ + CALL apoc.vectordb.pinecone.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector + WITH collect(node) as paths + CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value + RETURN value + """ + , + map( + "host", HOST, + "conf", conf, + "collection", collName, + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + VectorDbTestUtil::assertRagWithVectors); + } } diff --git a/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java b/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java index ab68f98d9f..1f63223949 100644 --- a/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java +++ b/extended/src/test/java/apoc/vectordb/VectorDbTestUtil.java @@ -1,6 +1,7 @@ package apoc.vectordb; import apoc.util.MapUtil; +import org.junit.Assume; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.ResourceIterator; @@ -100,4 +101,17 @@ public static void assertReadOnlyProcWithMappingResults(Result r, String node) { assertFalse(r.hasNext()); } + + public static void assertRagWithVectors(Result r) { + Map row = r.next(); + Object value = row.get("value"); + assertTrue("The actual value is: " + value, value.toString().contains("Berlin")); + } + + public static String ragSetup(GraphDatabaseService db) { + String openAIKey = System.getenv("OPENAI_KEY");; + Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey); + db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})"); + return openAIKey; + } }