Skip to content

Commit

Permalink
[NOID] Fixes #4091: Update RAG docs with vector db examples (#4116)
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed Dec 5, 2024
1 parent ae7315d commit 85fbf0f
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ CALL apoc.vectordb.chroma.query($host, '<collection_id>',
----


[NOTE]
====
To optimize performance, we can choose what to `YIELD` with the `apoc.vectordb.chroma.query` and the `apoc.vectordb.chroma.get` procedures.
For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the REST API request will contain `{"include": ["metadatas", "documents", "distances"]}`,
so that we do not return the other values that we do not need.
====

It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follows:

[source,cypher]
----
CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, [<id1>, <id2>], $conf) YIELD node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value
RETURN value
----

.Delete vectors (it leverages https://docs.trychroma.com/usage-guide#deleting-data-from-a-collection[this API])
[source,cypher]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

== Qdrant

Here is a list of all available Qdrant procedures,
Expand Down Expand Up @@ -200,7 +199,15 @@ For example, by executing a `CALL apoc.vectordb.qdrant.query(...) YIELD metadata
so that we do not return the other values that we do not need.
====

It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follows:

[source,cypher]
----
CALL apoc.vectordb.qdrant.getAndUpdate($host, $collection, [<id1>, <id2>], $conf) YIELD node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value
RETURN value
----

.Delete vectors (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/points/operation/delete_vectors[this API])
[source,cypher]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

== Weaviate

Here is a list of all available Weaviate procedures,
Expand Down Expand Up @@ -213,7 +212,15 @@ For example, by executing a `CALL apoc.vectordb.weaviate.query(...) YIELD metada
so that we do not return the other values that we do not need.
====

It is possible to execute vector db procedures together with the xref::ml/rag.adoc[apoc.ml.rag] as follows:

[source,cypher]
----
CALL apoc.vectordb.weaviate.getAndUpdate($host, $collection, [<id1>, <id2>], $conf) YIELD score, node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, $question, $confPrompt) YIELD value
RETURN value
----

.Delete vectors (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/objects/delete/objects/\{className\}/\{id\}[this API])
[source,cypher]
Expand Down
58 changes: 55 additions & 3 deletions full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
package apoc.full.it.vectordb;

import apoc.ml.Prompt;
import apoc.util.TestUtil;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.test.TestDatabaseManagementServiceBuilder;
import org.testcontainers.chromadb.ChromaDBContainer;

import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import static apoc.ml.RestAPIConfig.HEADERS_KEY;
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand All @@ -10,7 +28,10 @@
import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated;
import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll;
import static apoc.vectordb.VectorDbTestUtil.ragSetup;
import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING;
import static apoc.vectordb.VectorDbTestUtil.getAuthHeader;
import static apoc.vectordb.VectorDbTestUtil.ragSetup;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
import static apoc.vectordb.VectorMappingConfig.*;
Expand All @@ -27,6 +48,8 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import apoc.vectordb.VectorDbTestUtil;
import org.assertj.core.api.Assertions;
import org.junit.AfterClass;
import org.junit.Before;
Expand Down Expand Up @@ -60,9 +83,9 @@ public static void setUp() throws Exception {
sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME);

CHROMA_CONTAINER.start();
HOST = CHROMA_CONTAINER.getEndpoint();

TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class);
HOST = CHROMA_CONTAINER.getEndpoint();
// Prompt is registered together with the vector db procedures; presumably it provides
// apoc.ml.rag, which queryVectorsWithRag calls -- TODO confirm.
TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class, Prompt.class);

testCall(
db,
Expand Down Expand Up @@ -414,7 +437,7 @@ public void queryVectorsWithCreateRel() {
@Test
public void queryVectorsWithSystemDbStorage() {
String keyConfig = "chroma-config-foo";
String baseUrl = HOST;
String baseUrl = "http://" + HOST;
Map<String, Object> mapping =
map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo");
sysDb.executeTransactionally(
Expand Down Expand Up @@ -452,4 +475,33 @@ public void queryVectorsWithSystemDbStorage() {

assertNodesCreated(db);
}

/**
 * Reads the vectors with ids '1' and '2' through {@code apoc.vectordb.chroma.getAndUpdate},
 * collects the yielded nodes and passes them to {@code apoc.ml.rag}; the outcome is checked by
 * {@code VectorDbTestUtil.assertRagWithVectors}.
 *
 * <p>The OpenAI key comes from {@code ragSetup(db)}.
 */
@Test
public void queryVectorsWithRag() {
    String openAIKey = ragSetup(db);

    // Mapping between the vector db entries and the :Rag nodes
    Map<String, Object> mappingConf = map(
            NODE_LABEL, "Rag",
            ENTITY_KEY, "readID",
            METADATA_KEY, "foo");

    Map<String, Object> conf = map(
            ALL_RESULTS_KEY, true,
            HEADERS_KEY, READONLY_AUTHORIZATION,
            MAPPING_KEY, mappingConf);

    // Plain concatenation (same form as WeaviateTest) instead of a text block,
    // so the query does not depend on Java 15+ syntax.
    String query = "CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector\n"
            + "WITH collect(node) as paths\n"
            + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n"
            + "RETURN value";

    Map<String, Object> params = map(
            "host", HOST,
            "conf", conf,
            "collection", COLL_ID.get(),
            "confPrompt", map(API_KEY_CONF, openAIKey),
            "attributes", List.of("city", "foo"));

    testResult(db, query, params, VectorDbTestUtil::assertRagWithVectors);
}
}
44 changes: 42 additions & 2 deletions full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,22 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.neo4j.configuration.GraphDatabaseSettings.*;

import apoc.ml.Prompt;
import apoc.util.TestUtil;
import apoc.util.Util;
import apoc.vectordb.Qdrant;
import apoc.vectordb.VectorDb;
import apoc.vectordb.VectorDbTestUtil;

import java.util.List;
import java.util.Map;
import org.assertj.core.api.Assertions;
import org.junit.AfterClass;
import org.junit.Assume;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
Expand Down Expand Up @@ -72,9 +77,9 @@ public static void setUp() throws Exception {
sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME);

QDRANT_CONTAINER.start();
HOST = QDRANT_CONTAINER.getHost() + ":" + QDRANT_CONTAINER.getMappedPort(6333);

TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class);
HOST = QDRANT_CONTAINER.getHost() + ":" + QDRANT_CONTAINER.getMappedPort(6333);
TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class, Prompt.class);

testCall(
db,
Expand Down Expand Up @@ -203,6 +208,41 @@ public void queryVectors() {
});
}

/**
 * Reads the points with ids 1 and 2 of 'test_collection' through
 * {@code apoc.vectordb.qdrant.getAndUpdate}, collects the yielded nodes and passes them to
 * {@code apoc.ml.rag}.
 *
 * <p>Setup and verification are delegated to the shared {@code VectorDbTestUtil} helpers,
 * as the Chroma and Weaviate variants of this test do, instead of duplicating them inline.
 */
@Test
public void queryVectorsWithRag() {
    // Skips the test when OPENAI_KEY is not configured and creates the two :Rag nodes
    String openAIKey = VectorDbTestUtil.ragSetup(db);

    Map<String, Object> mappingConf = map(
            NODE_LABEL, "Rag",
            ENTITY_KEY, "readID",
            METADATA_KEY, "foo");

    Map<String, Object> conf = map(
            ALL_RESULTS_KEY, true,
            HEADERS_KEY, READONLY_AUTHORIZATION,
            MAPPING_KEY, mappingConf);

    // Plain concatenation (same form as WeaviateTest) instead of a text block,
    // so the query does not depend on Java 15+ syntax.
    String query = "CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n"
            + "WITH collect(node) as paths\n"
            + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n"
            + "RETURN value";

    Map<String, Object> params = map(
            "host", HOST,
            "conf", conf,
            "confPrompt", map(API_KEY_CONF, openAIKey),
            "attributes", List.of("city", "foo"));

    // Same check as before: the first row's `value` must mention "Berlin"
    testResult(db, query, params, VectorDbTestUtil::assertRagWithVectors);
}

@Test
public void queryVectorsWithoutVectorResult() {
testResult(
Expand Down
65 changes: 64 additions & 1 deletion full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
package apoc.full.it.vectordb;

import apoc.ml.Prompt;
import apoc.util.MapUtil;
import apoc.util.TestUtil;
import org.junit.AfterClass;
import org.junit.Assume;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.test.TestDatabaseManagementServiceBuilder;
import org.testcontainers.weaviate.WeaviateContainer;

import java.util.List;
import java.util.Map;

import static apoc.ml.RestAPIConfig.HEADERS_KEY;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testCallEmpty;
Expand All @@ -9,7 +29,19 @@
import static apoc.vectordb.VectorDbTestUtil.*;
import static apoc.vectordb.VectorDbTestUtil.EntityType.*;
import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING;
import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE;
import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE;
import static apoc.vectordb.VectorDbTestUtil.EntityType.REL;
import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult;
import static apoc.vectordb.VectorDbTestUtil.assertLondonResult;
import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertReadOnlyProcWithMappingResults;
import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated;
import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll;
import static apoc.vectordb.VectorDbTestUtil.getAuthHeader;
import static apoc.vectordb.VectorDbTestUtil.ragSetup;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
import static apoc.vectordb.VectorMappingConfig.*;
import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -21,6 +53,7 @@
import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME;

import apoc.ml.Prompt;
import apoc.util.MapUtil;
import apoc.util.TestUtil;
import apoc.vectordb.VectorDb;
Expand Down Expand Up @@ -82,7 +115,7 @@ public static void setUp() throws Exception {
WEAVIATE_CONTAINER.start();
HOST = WEAVIATE_CONTAINER.getHttpHostAddress();

TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class);
TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class, Prompt.class);

testCall(
db,
Expand Down Expand Up @@ -552,4 +585,34 @@ public void queryVectorsWithSystemDbStorage() {

assertNodesCreated(db);
}

/**
 * Reads the object identified by {@code ID_1} from 'TestCollection' through
 * {@code apoc.vectordb.weaviate.getAndUpdate}, collects the yielded nodes and passes them to
 * {@code apoc.ml.rag}; the outcome is checked by {@code VectorDbTestUtil.assertRagWithVectors}.
 */
@Test
public void queryVectorsWithRag() {
    String openAIKey = ragSetup(db);

    // Mapping between the vector db entries and the :Rag nodes
    Map<String, Object> mappingConf = MapUtil.map(
            EMBEDDING_KEY, "vect",
            NODE_LABEL, "Rag",
            ENTITY_KEY, "readID",
            METADATA_KEY, "foo");

    Map<String, Object> conf = MapUtil.map(
            FIELDS_KEY, FIELDS,
            ALL_RESULTS_KEY, true,
            HEADERS_KEY, READONLY_AUTHORIZATION,
            MAPPING_KEY, mappingConf);

    // One Cypher clause per element, joined with newlines
    String query = String.join(
            "\n",
            "CALL apoc.vectordb.weaviate.getAndUpdate($host, 'TestCollection', [$id1], $conf) YIELD score, node, metadata, id, vector",
            "WITH collect(node) as paths",
            "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value",
            "RETURN value");

    Map<String, Object> params = MapUtil.map(
            "host", HOST,
            "id1", ID_1,
            "conf", conf,
            "confPrompt", MapUtil.map(API_KEY_CONF, openAIKey),
            "attributes", List.of("city", "foo"));

    testResult(db, query, params, VectorDbTestUtil::assertRagWithVectors);
}
}
38 changes: 38 additions & 0 deletions full/src/test/java/apoc/vectordb/VectorDbTestUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
package apoc.vectordb;

import apoc.util.MapUtil;
import org.junit.Assume;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.graphdb.Result;

import java.util.Map;

import static apoc.util.TestUtil.testResult;
import static apoc.util.Util.map;
import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -89,4 +98,33 @@ private static void assertBerlinProperties(Map props) {
/**
 * Builds a single-entry header map whose "Authorization" value is
 * the given key prefixed with "Bearer ".
 *
 * @param key the token to place after the "Bearer " prefix
 * @return a map containing only the "Authorization" entry
 */
public static Map<String, String> getAuthHeader(String key) {
    final String bearerValue = "Bearer " + key;
    return map("Authorization", bearerValue);
}

/**
 * Verifies that the result holds exactly two rows, in order: the entity found under the
 * {@code node} column must have {@code readID} as its only property, equal to "one" in the
 * first row and "two" in the second, and both rows must carry non-null "vector" and "id" columns.
 *
 * @param r    the result to consume
 * @param node the name of the column holding the entity
 */
public static void assertReadOnlyProcWithMappingResults(Result r, String node) {
    for (String expectedReadId : new String[] {"one", "two"}) {
        Map<String, Object> current = r.next();
        Entity entity = (Entity) current.get(node);
        assertEquals(MapUtil.map("readID", expectedReadId), entity.getAllProperties());
        assertNotNull(current.get("vector"));
        assertNotNull(current.get("id"));
    }

    // no rows beyond the two expected ones
    assertFalse(r.hasNext());
}

/**
 * Asserts that the first row of the result has a {@code value} column whose string form
 * contains "Berlin".
 *
 * <p>An empty result or a null {@code value} is reported as an assertion failure rather
 * than as a {@code NoSuchElementException} / {@code NullPointerException}.
 *
 * @param r the result to check; only its first row is consumed
 */
public static void assertRagWithVectors(Result r) {
    assertTrue("Expected at least one result row, but the result is empty", r.hasNext());
    Map<String, Object> row = r.next();
    Object value = row.get("value");
    // String.valueOf renders null as "null", so a missing value fails the assertion below
    assertTrue("The actual value is: " + value, String.valueOf(value).contains("Berlin"));
}

/**
 * Common setup for the RAG tests: reads the {@code OPENAI_KEY} environment variable,
 * skips the calling test (JUnit assumption) when it is not set, and creates two
 * {@code :Rag} nodes with {@code readID} 'one' and 'two'.
 *
 * @param db the database in which the {@code :Rag} nodes are created
 * @return the value of the {@code OPENAI_KEY} environment variable (never null)
 */
public static String ragSetup(GraphDatabaseService db) {
    String openAIKey = System.getenv("OPENAI_KEY");
    // Assume.assumeNotNull(Object...) has no message overload: the text would just be
    // treated as one more non-null object. assumeTrue(message, condition) reports it.
    Assume.assumeTrue("No OPENAI_KEY environment configured", openAIKey != null);
    db.executeTransactionally("CREATE (:Rag {readID: 'one'}), (:Rag {readID: 'two'})");
    return openAIKey;
}
}

0 comments on commit 85fbf0f

Please sign in to comment.