From c7422bec1c237c9d57e3d1cf255f2e3a81b7c805 Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Fri, 28 Jun 2024 17:22:28 +0200 Subject: [PATCH 1/2] [NOID] Fixes #4091: Update RAG docs with vector db examples (#4116) --- .../database-integration/vectordb/chroma.adoc | 16 +++++ .../database-integration/vectordb/qdrant.adoc | 9 ++- .../vectordb/weaviate.adoc | 9 ++- .../apoc/full/it/vectordb/ChromaDbTest.java | 58 ++++++++++++++++- .../apoc/full/it/vectordb/QdrantTest.java | 44 ++++++++++++- .../apoc/full/it/vectordb/WeaviateTest.java | 65 ++++++++++++++++++- .../java/apoc/vectordb/VectorDbTestUtil.java | 33 ++++++++++ 7 files changed, 226 insertions(+), 8 deletions(-) diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc index a0fe25f804..3dacecc5ff 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc @@ -148,6 +148,22 @@ CALL apoc.vectordb.chroma.query($host, '', ---- +[NOTE] +==== +To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.chroma.query and the `apoc.vectordb.chroma.get` procedures. +For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"include": ["metadatas", "documents", "distances"]}, +so that we do not return the other values that we do not need. +==== + +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: + +[source,cypher] +---- +CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, [, ], $conf) YIELD node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- .Delete vectors (it leverages https://docs.trychroma.com/usage-guide#deleting-data-from-a-collection[this API]) [source,cypher] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc index 073fa9d146..30414b618a 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc @@ -1,4 +1,3 @@ - == Qdrant Here is a list of all available Qdrant procedures, @@ -200,7 +199,15 @@ For example, by executing a `CALL apoc.vectordb.qdrant.query(...) YIELD metadata so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: +[source,cypher] +---- +CALL apoc.vectordb.qdrant.getAndUpdate($host, $collection, [, ], $conf) YIELD node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- .Delete vectors (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/points/operation/delete_vectors[this API]) [source,cypher] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc index be672294aa..6268f4a7e8 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc @@ -1,4 +1,3 @@ - == Weaviate Here is a list of all available Weaviate procedures, @@ -213,7 +212,15 @@ For example, by executing a `CALL apoc.vectordb.weaviate.query(...) YIELD metada so that we do not return the other values that we do not need. ==== +It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follow: +[source,cypher] +---- +CALL apoc.vectordb.weaviate.getAndUpdate($host, $collection, [, ], $conf) YIELD score, node, metadata, id, vector +WITH collect(node) as paths +CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value +RETURN value +---- .Delete vectors (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/objects/delete/objects/\{className\}/\{id\}[this API]) [source,cypher] diff --git a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java index 4c60de3ecd..6d0274c12b 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java @@ -1,5 +1,23 @@ package apoc.full.it.vectordb; +import apoc.ml.Prompt; +import apoc.util.TestUtil; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.test.TestDatabaseManagementServiceBuilder; +import org.testcontainers.chromadb.ChromaDBContainer; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -10,7 +28,10 @@ import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; +import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; @@ -27,6 +48,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; + +import apoc.vectordb.VectorDbTestUtil; import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; @@ -60,9 +83,9 @@ public static void setUp() throws Exception { sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); CHROMA_CONTAINER.start(); - HOST = CHROMA_CONTAINER.getEndpoint(); - TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class); + HOST = CHROMA_CONTAINER.getEndpoint(); + TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class, Prompt.class)); testCall( db, @@ -414,7 +437,7 @@ public void queryVectorsWithCreateRel() { @Test public void queryVectorsWithSystemDbStorage() { String keyConfig = "chroma-config-foo"; - String baseUrl = HOST; + String baseUrl = "http://" + HOST; Map mapping = map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo"); sysDb.executeTransactionally( @@ -452,4 +475,33 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } + + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map(NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + """ + CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector + WITH collect(node) as paths + CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value + RETURN value + """ + , + map( + "host", HOST, + "conf", conf, + "collection", COLL_ID.get(), + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + VectorDbTestUtil::assertRagWithVectors); + } } diff --git a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java index 9ec5890ad3..7edca73f34 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java @@ -22,17 +22,22 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.neo4j.configuration.GraphDatabaseSettings.*; +import apoc.ml.Prompt; import apoc.util.TestUtil; import apoc.util.Util; import apoc.vectordb.Qdrant; import apoc.vectordb.VectorDb; import apoc.vectordb.VectorDbTestUtil; + +import java.util.List; import java.util.Map; import org.assertj.core.api.Assertions; import org.junit.AfterClass; +import org.junit.Assume; import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -72,9 +77,9 @@ public static void setUp() throws Exception { sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); QDRANT_CONTAINER.start(); - HOST = QDRANT_CONTAINER.getHost() + ":" + QDRANT_CONTAINER.getMappedPort(6333); - TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class); + HOST = QDRANT_CONTAINER.getHost() + ":" + QDRANT_CONTAINER.getMappedPort(6333); + TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class, Prompt.class); testCall( db, @@ -203,6 +208,41 @@ public void queryVectors() { }); } + @Test + public void queryVectorsWithRag() { + String openAIKey = System.getenv("OPENAI_KEY");; + Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey); + + db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})"); + + Map conf = map(ALL_RESULTS_KEY, true, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, map(NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + """ + CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector + WITH collect(node) as paths + CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value + RETURN value + """ + , + map( + "host", HOST, + "conf", conf, + "confPrompt", map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + r -> { + Map row = r.next(); + Object value = row.get("value"); + assertTrue("The actual value is: " + value, value.toString().contains("Berlin")); + }); + } + @Test public void queryVectorsWithoutVectorResult() { testResult( diff --git a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java index 14b76f4ef9..5f550871af 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java @@ -1,5 +1,25 @@ package apoc.full.it.vectordb; +import apoc.ml.Prompt; +import apoc.util.MapUtil; +import apoc.util.TestUtil; +import org.junit.AfterClass; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.Entity; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.ResourceIterator; +import org.neo4j.test.TestDatabaseManagementServiceBuilder; +import org.testcontainers.weaviate.WeaviateContainer; + +import java.util.List; +import java.util.Map; + import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testCallEmpty; @@ -9,7 +29,19 @@ import static apoc.vectordb.VectorDbTestUtil.*; import static apoc.vectordb.VectorDbTestUtil.EntityType.*; import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; +import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.REL; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; +import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; +import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; +import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; import static org.assertj.core.api.Assertions.assertThat; @@ -21,6 +53,7 @@ import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; +import apoc.ml.Prompt; import apoc.util.MapUtil; import apoc.util.TestUtil; import apoc.vectordb.VectorDb; @@ -82,7 +115,7 @@ public static void setUp() throws Exception { WEAVIATE_CONTAINER.start(); HOST = WEAVIATE_CONTAINER.getHttpHostAddress(); - TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class); + TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class, Prompt.class); testCall( db, @@ -552,4 +585,34 @@ public void queryVectorsWithSystemDbStorage() { assertNodesCreated(db); } + + @Test + public void queryVectorsWithRag() { + String openAIKey = ragSetup(db); + + Map conf = MapUtil.map( + FIELDS_KEY, FIELDS, + ALL_RESULTS_KEY, true, + HEADERS_KEY, READONLY_AUTHORIZATION, + MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Rag", + ENTITY_KEY, "readID", + METADATA_KEY, "foo") + ); + + testResult(db, + "CALL apoc.vectordb.weaviate.getAndUpdate($host, 'TestCollection', [$id1], $conf) YIELD score, node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value" + , + MapUtil.map( + "host", HOST, + "id1", ID_1, + "conf", conf, + "confPrompt", MapUtil.map(API_KEY_CONF, openAIKey), + "attributes", List.of("city", "foo") + ), + VectorDbTestUtil::assertRagWithVectors); + } } diff --git a/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java b/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java index d949adf64d..c430ecb19c 100644 --- a/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java +++ b/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java @@ -4,9 +4,12 @@ import static apoc.util.Util.map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import apoc.util.MapUtil; import java.util.Map; +import org.junit.Assume; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.ResourceIterator; @@ -89,4 +92,34 @@ private static void assertBerlinProperties(Map props) { public static Map getAuthHeader(String key) { return map("Authorization", "Bearer " + key); } + + public static void assertReadOnlyProcWithMappingResults(Result r, String node) { + Map row = r.next(); + Map props = ((Entity) row.get(node)).getAllProperties(); + assertEquals(MapUtil.map("readID", "one"), props); + assertNotNull(row.get("vector")); + assertNotNull(row.get("id")); + + row = r.next(); + props = ((Entity) row.get(node)).getAllProperties(); + assertEquals(MapUtil.map("readID", "two"), props); + assertNotNull(row.get("vector")); + assertNotNull(row.get("id")); + + assertFalse(r.hasNext()); + } + + public static void assertRagWithVectors(Result r) { + Map row = r.next(); + Object value = row.get("value"); + assertTrue("The actual value is: " + value, value.toString().contains("Berlin")); + } + + public static String ragSetup(GraphDatabaseService db) { + String openAIKey = System.getenv("OPENAI_KEY"); + ; + Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey); + db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})"); + return openAIKey; + } } From 6a0796145ecd1fba9a6c4b189caad535e6f5e5d2 Mon Sep 17 00:00:00 2001 From: vga91 Date: Thu, 5 Dec 2024 17:09:58 +0100 Subject: [PATCH 2/2] [NOID] test fixes --- .../apoc/full/it/vectordb/ChromaDbTest.java | 56 +++++------------- .../apoc/full/it/vectordb/QdrantTest.java | 47 ++++++++------- .../apoc/full/it/vectordb/WeaviateTest.java | 57 ++++++------------- 3 files changed, 58 insertions(+), 102 deletions(-) diff --git a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java index 6d0274c12b..9221b0112f 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java @@ -1,23 +1,6 @@ package apoc.full.it.vectordb; -import apoc.ml.Prompt; -import apoc.util.TestUtil; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.neo4j.dbms.api.DatabaseManagementService; -import org.neo4j.graphdb.GraphDatabaseService; -import org.neo4j.test.TestDatabaseManagementServiceBuilder; -import org.testcontainers.chromadb.ChromaDBContainer; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; - -import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -30,8 +13,6 @@ import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; -import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; -import static apoc.vectordb.VectorDbTestUtil.ragSetup; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; @@ -42,14 +23,14 @@ import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; +import apoc.ml.Prompt; import apoc.util.TestUtil; import apoc.vectordb.ChromaDb; import apoc.vectordb.VectorDb; +import apoc.vectordb.VectorDbTestUtil; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; - -import apoc.vectordb.VectorDbTestUtil; import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Before; @@ -85,7 +66,7 @@ public static void setUp() throws Exception { CHROMA_CONTAINER.start(); HOST = CHROMA_CONTAINER.getEndpoint(); - TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class, Prompt.class)); + TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class, Prompt.class); testCall( db, @@ -437,7 +418,7 @@ public void queryVectorsWithCreateRel() { @Test public void queryVectorsWithSystemDbStorage() { String keyConfig = "chroma-config-foo"; - String baseUrl = "http://" + HOST; + String baseUrl = HOST; Map mapping = map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo"); sysDb.executeTransactionally( @@ -480,28 +461,21 @@ public void queryVectorsWithSystemDbStorage() { public void queryVectorsWithRag() { String openAIKey = ragSetup(db); - Map conf = map(ALL_RESULTS_KEY, true, - HEADERS_KEY, READONLY_AUTHORIZATION, - MAPPING_KEY, map(NODE_LABEL, "Rag", - ENTITY_KEY, "readID", - METADATA_KEY, "foo") - ); - - testResult(db, - """ - CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector - WITH collect(node) as paths - CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value - RETURN value - """ - , + Map conf = map( + ALL_RESULTS_KEY, true, MAPPING_KEY, map(NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value", map( "host", HOST, "conf", conf, "collection", COLL_ID.get(), "confPrompt", map(API_KEY_CONF, openAIKey), - "attributes", List.of("city", "foo") - ), + "attributes", List.of("city", "foo")), VectorDbTestUtil::assertRagWithVectors); } } diff --git a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java index 7edca73f34..1c1370dd57 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java @@ -1,5 +1,6 @@ package apoc.full.it.vectordb; +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.MapUtil.map; import static apoc.util.TestUtil.testCall; @@ -32,7 +33,6 @@ import apoc.vectordb.Qdrant; import apoc.vectordb.VectorDb; import apoc.vectordb.VectorDbTestUtil; - import java.util.List; import java.util.Map; import org.assertj.core.api.Assertions; @@ -210,32 +210,35 @@ public void queryVectors() { @Test public void queryVectorsWithRag() { - String openAIKey = System.getenv("OPENAI_KEY");; + String openAIKey = System.getenv("OPENAI_KEY"); + ; Assume.assumeNotNull("No OPENAI_KEY environment configured", openAIKey); db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})"); - Map conf = map(ALL_RESULTS_KEY, true, - HEADERS_KEY, READONLY_AUTHORIZATION, - MAPPING_KEY, map(NODE_LABEL, "Rag", - ENTITY_KEY, "readID", - METADATA_KEY, "foo") - ); - - testResult(db, - """ - CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector - WITH collect(node) as paths - CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value - RETURN value - """ - , + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + READONLY_AUTHORIZATION, + MAPPING_KEY, + map(NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value", map( - "host", HOST, - "conf", conf, - "confPrompt", map(API_KEY_CONF, openAIKey), - "attributes", List.of("city", "foo") - ), + "host", + HOST, + "conf", + conf, + "confPrompt", + map(API_KEY_CONF, openAIKey), + "attributes", + List.of("city", "foo")), r -> { Map row = r.next(); Object value = row.get("value"); diff --git a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java index 5f550871af..f5d7941252 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java @@ -1,25 +1,6 @@ package apoc.full.it.vectordb; -import apoc.ml.Prompt; -import apoc.util.MapUtil; -import apoc.util.TestUtil; -import org.junit.AfterClass; -import org.junit.Assume; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.neo4j.dbms.api.DatabaseManagementService; -import org.neo4j.graphdb.Entity; -import org.neo4j.graphdb.GraphDatabaseService; -import org.neo4j.graphdb.ResourceIterator; -import org.neo4j.test.TestDatabaseManagementServiceBuilder; -import org.testcontainers.weaviate.WeaviateContainer; - -import java.util.List; -import java.util.Map; - +import static apoc.ml.Prompt.API_KEY_CONF; import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testCallEmpty; @@ -28,18 +9,17 @@ import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE; import static apoc.vectordb.VectorDbTestUtil.*; import static apoc.vectordb.VectorDbTestUtil.EntityType.*; -import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE; import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE; import static apoc.vectordb.VectorDbTestUtil.EntityType.REL; import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; -import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults; import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; import static apoc.vectordb.VectorDbTestUtil.ragSetup; +import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; @@ -591,28 +571,27 @@ public void queryVectorsWithRag() { String openAIKey = ragSetup(db); Map conf = MapUtil.map( - FIELDS_KEY, FIELDS, - ALL_RESULTS_KEY, true, - HEADERS_KEY, READONLY_AUTHORIZATION, - MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect", - NODE_LABEL, "Rag", - ENTITY_KEY, "readID", - METADATA_KEY, "foo") - ); - - testResult(db, - "CALL apoc.vectordb.weaviate.getAndUpdate($host, 'TestCollection', [$id1], $conf) YIELD score, node, metadata, id, vector\n" + - "WITH collect(node) as paths\n" + - "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + - "RETURN value" - , + FIELDS_KEY, + FIELDS, + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + READONLY_AUTHORIZATION, + MAPPING_KEY, + MapUtil.map(EMBEDDING_KEY, "vect", NODE_LABEL, "Rag", ENTITY_KEY, "readID", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.weaviate.getAndUpdate($host, 'TestCollection', [$id1], $conf) YIELD score, node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value", MapUtil.map( "host", HOST, "id1", ID_1, "conf", conf, "confPrompt", MapUtil.map(API_KEY_CONF, openAIKey), - "attributes", List.of("city", "foo") - ), + "attributes", List.of("city", "foo")), VectorDbTestUtil::assertRagWithVectors); } }