diff --git a/LICENSES.txt b/LICENSES.txt index 72979987a2..ca8b44bf06 100644 --- a/LICENSES.txt +++ b/LICENSES.txt @@ -48,9 +48,9 @@ Apache-2.0 curator-client-5.2.0.jar curator-framework-5.2.0.jar curator-recipes-5.2.0.jar - docker-java-api-3.2.13.jar - docker-java-transport-3.2.13.jar - docker-java-transport-zerodep-3.2.13.jar + docker-java-api-3.3.6.jar + docker-java-transport-3.3.6.jar + docker-java-transport-zerodep-3.3.6.jar ehcache-3.3.1.jar error_prone_annotations-2.18.0.jar failureaccess-1.0.1.jar @@ -133,6 +133,7 @@ Apache-2.0 jffi-1.2.16-native.jar jffi-1.2.16.jar jmespath-java-1.12.770.jar + jna-5.13.0.jar jna-5.9.0.jar jnr-constants-0.9.9.jar jnr-ffi-2.1.7.jar @@ -3045,6 +3046,7 @@ MIT bcutil-jdk18on-1.78.jar cassandra-1.17.6.jar checker-qual-3.42.0.jar + chromadb-1.19.7.jar couchbase-1.17.6.jar database-commons-1.17.6.jar duct-tape-1.0.8.jar @@ -3062,12 +3064,14 @@ MIT mysql-1.17.6.jar neo4j-1.17.6.jar postgresql-1.17.6.jar + qdrant-1.19.7.jar reactive-streams-1.0.4.jar slf4j-api-1.7.36.jar slf4j-api-2.0.11.jar slf4j-nop-1.7.30.jar slf4j-reload4j-1.7.36.jar - testcontainers-1.17.6.jar + testcontainers-1.19.7.jar + weaviate-1.19.7.jar ------------------------------------------------------------------------------ The MIT License diff --git a/NOTICE.txt b/NOTICE.txt index e33aa3db9f..17bf19ee60 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -78,9 +78,9 @@ Apache-2.0 curator-client-5.2.0.jar curator-framework-5.2.0.jar curator-recipes-5.2.0.jar - docker-java-api-3.2.13.jar - docker-java-transport-3.2.13.jar - docker-java-transport-zerodep-3.2.13.jar + docker-java-api-3.3.6.jar + docker-java-transport-3.3.6.jar + docker-java-transport-zerodep-3.3.6.jar ehcache-3.3.1.jar error_prone_annotations-2.18.0.jar failureaccess-1.0.1.jar @@ -163,6 +163,7 @@ Apache-2.0 jffi-1.2.16-native.jar jffi-1.2.16.jar jmespath-java-1.12.770.jar + jna-5.13.0.jar jna-5.9.0.jar jnr-constants-0.9.9.jar jnr-ffi-2.1.7.jar @@ -434,6 +435,7 @@ LGPL 2.1 javassist-3.25.0-GA.jar 
LGPL-2.1-or-later + jna-5.13.0.jar jna-5.9.0.jar MIT @@ -445,6 +447,7 @@ MIT bcutil-jdk18on-1.78.jar cassandra-1.17.6.jar checker-qual-3.42.0.jar + chromadb-1.19.7.jar couchbase-1.17.6.jar database-commons-1.17.6.jar duct-tape-1.0.8.jar @@ -462,12 +465,14 @@ MIT mysql-1.17.6.jar neo4j-1.17.6.jar postgresql-1.17.6.jar + qdrant-1.19.7.jar reactive-streams-1.0.4.jar slf4j-api-1.7.36.jar slf4j-api-2.0.11.jar slf4j-nop-1.7.30.jar slf4j-reload4j-1.7.36.jar - testcontainers-1.17.6.jar + testcontainers-1.19.7.jar + weaviate-1.19.7.jar MPL 1.1 javassist-3.25.0-GA.jar diff --git a/core/src/main/java/apoc/SystemLabels.java b/core/src/main/java/apoc/SystemLabels.java index 2e74c7fe72..269127d23b 100644 --- a/core/src/main/java/apoc/SystemLabels.java +++ b/core/src/main/java/apoc/SystemLabels.java @@ -29,5 +29,6 @@ public enum SystemLabels implements Label { ApocUuidMeta, ApocTriggerMeta, ApocTrigger, - DataVirtualizationCatalog + DataVirtualizationCatalog, + VectorDb } diff --git a/core/src/main/java/apoc/SystemPropertyKeys.java b/core/src/main/java/apoc/SystemPropertyKeys.java index 5271ba34da..18aea42968 100644 --- a/core/src/main/java/apoc/SystemPropertyKeys.java +++ b/core/src/main/java/apoc/SystemPropertyKeys.java @@ -46,5 +46,9 @@ public enum SystemPropertyKeys { label, addToSetLabel, addToExistingNodes, - propertyName; + propertyName, + + // vector db + host, + credentials } diff --git a/core/src/main/java/apoc/util/Util.java b/core/src/main/java/apoc/util/Util.java index 638f9b2f0c..54bda10227 100644 --- a/core/src/main/java/apoc/util/Util.java +++ b/core/src/main/java/apoc/util/Util.java @@ -1314,4 +1314,53 @@ public static ConstraintCategory getConstraintCategory(ConstraintType type) { return ConstraintCategory.NODE; } } + + public static void setProperties(Entity entity, Map props) { + for (var entry : props.entrySet()) { + entity.setProperty(entry.getKey(), entry.getValue()); + } + } + /** + * Transform a list like: [ {key1: valueFoo1, key2: valueFoo2}, {key1: 
valueBar1, key2: valueBar2} ] + * to a map like: { keyNew1: [valueFoo1, valueBar1], keyNew2: [valueFoo2, valueBar2] }, + * + * where mapKeys is e.g. {key1: keyNew1, key2: keyNew2} + */ + public static Map listOfMapToMapOfLists(Map mapKeys, List> vectors) { + Map additionalBodies = new HashMap(); + for (var vector : vectors) { + mapKeys.forEach((from, to) -> { + mapEntryToList(additionalBodies, vector, from, to); + }); + } + return additionalBodies; + } + + private static void mapEntryToList( + Map map, Map vector, Object keyFrom, Object keyTo) { + Object item = vector.get(keyFrom); + if (item == null) { + return; + } + + map.compute(keyTo, (k, v) -> { + if (v == null) { + List list = new ArrayList<>(); + list.add(item); + return list; + } + v.add(item); + return v; + }); + } + + public static float[] listOfNumbersToFloatArray(List embedding) { + float[] floats = new float[embedding.size()]; + int i = 0; + for (var item : embedding) { + floats[i] = item.floatValue(); + i++; + } + return floats; + } } diff --git a/docs/asciidoc/modules/ROOT/nav.adoc b/docs/asciidoc/modules/ROOT/nav.adoc index f24bd42cba..1b9e56188a 100644 --- a/docs/asciidoc/modules/ROOT/nav.adoc +++ b/docs/asciidoc/modules/ROOT/nav.adoc @@ -50,6 +50,7 @@ include::partial$generated-documentation/nav.adoc[] ** xref::database-integration/bolt-neo4j.adoc[] ** xref::database-integration/load-ldap.adoc[] ** xref::database-integration/redis.adoc[] + ** xref:database-integration/vectordb/index.adoc[] * xref:graph-updates/index.adoc[] ** xref::graph-updates/data-creation.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/index.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/index.adoc index e509948b30..74d6b45f1d 100644 --- a/docs/asciidoc/modules/ROOT/pages/database-integration/index.adoc +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/index.adoc @@ -18,4 +18,5 @@ For more information on how to use these procedures, see: * 
xref::database-integration/bolt-neo4j.adoc[] * xref::database-integration/load-ldap.adoc[] * xref::database-integration/redis.adoc[] +* xref:database-integration/vectordb/index.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc new file mode 100644 index 0000000000..a0fe25f804 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/chroma.adoc @@ -0,0 +1,157 @@ + +== ChromaDB + +Here is a list of all available ChromaDB procedures, +note that the list and the procedure signatures are consistent with the others, like the Qdrant ones: + +[opts=header, cols="1, 3"] +|=== +| name | description +| apoc.vectordb.chroma.createCollection(hostOrKey, collection, similarity, size, $config) | + Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`. + The default endpoint is `/api/v1/collections`. +| apoc.vectordb.chroma.deleteCollection(hostOrKey, collection, $config) | + Deletes a collection with the name specified in the 2nd parameter. + The default endpoint is `/api/v1/collections/`. +| apoc.vectordb.chroma.upsert(hostOrKey, collection, vectors, $config) | + Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', metadata: ''}]. + The default endpoint is `/api/v1/collections//upsert`. +| apoc.vectordb.chroma.delete(hostOrKey, collection, ids, $config) | + Deletes the vectors with the specified `ids`. + The default endpoint is `/api/v1/collections//delete`. +| apoc.vectordb.chroma.get(hostOrKey, collection, ids, $config) | + Gets the vectors with the specified `ids`. + The default endpoint is `/api/v1/collections//get`. 
+| apoc.vectordb.chroma.query(hostOrKey, collection, vector, filter, limit, $config) | + Retrieve closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter. + The default endpoint is `/api/v1/collections//query`. +| apoc.vectordb.chroma.getAndUpdate(hostOrKey, collection, ids, $config) | + Gets the vectors with the specified `ids`, and optionally creates/updates neo4j entities. + The default endpoint is `/api/v1/collections//get`. +| apoc.vectordb.chroma.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $config) | + Retrieve closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter, and optionally creates/updates neo4j entities. + The default endpoint is `/api/v1/collections//query`. +|=== + +where the 1st parameter can be a key defined by the apoc config `apoc.chroma..host=myHost`. +With hostOrKey=null, the default is 'http://localhost:8000'. 
+ +=== Examples + +.Create a collection (it leverages https://docs.trychroma.com/usage-guide#creating-inspecting-and-deleting-collections[this API]) +[source,cypher] +---- +CALL apoc.vectordb.chroma.createCollection($host, 'test_collection', 'Cosine', 4, {}) +---- + + +.Delete a collection (it leverages https://docs.trychroma.com/usage-guide#creating-inspecting-and-deleting-collections[this API]) +[source,cypher] +---- +CALL apoc.vectordb.chroma.deleteCollection($host, '', {}) +---- + + +.Upsert vectors (it leverages https://docs.trychroma.com/usage-guide#adding-data-to-a-collection[this API]) +[source,cypher] +---- +CALL apoc.vectordb.chroma.upsert($host, '', + [ + {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}, text: 'ajeje'}, + {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}, text: 'brazorf'} + ], + {}) +---- + + +.Get vectors (it leverages https://docs.trychroma.com/usage-guide#querying-a-collection[this API]) +[source,cypher] +---- +CALL apoc.vectordb.chroma.get($host, '', ['1','2'], {}), text +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| null | {city: "Berlin", foo: "one"} | null | null | null | null +| null | {city: "Berlin", foo: "two"} | null | null | null | null +| ... +|=== + + +.Get vectors with `{allResults: true}` +[source,cypher] +---- +CALL apoc.vectordb.chroma.get($host, '', ['1','2'], {allResults: true}), text +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| null | {city: "Berlin", foo: "one"} | 1 | [...] | ajeje | null +| null | {city: "Berlin", foo: "two"} | 2 | [...] | brazorf | null +| ... 
+|=== + + +.Query vectors (it leverages https://docs.trychroma.com/usage-guide#querying-a-collection[this API]) +[source,cypher] +---- +CALL apoc.vectordb.chroma.query($host, + '', + [0.2, 0.1, 0.9, 0.7], + {city: 'London'}, + 5, + {allResults: true, }), text +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text +| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | ajeje +| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | brazorf +| ... +|=== + + +[NOTE] +==== +To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.chroma.query and the `apoc.vectordb.chroma.get` procedures. +For example, by executing a `CALL apoc.vectordb.chroma.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"include": ["metadatas", "documents", "distances"]}, +so that we do not return the other values that we do not need. +==== + + +In the same way as other procedures, we can define a mapping, to fetch the associated nodes and relationships and optionally create them, +by leveraging the vector metadata. For example: + +.Query vectors +[source,cypher] +---- +CALL apoc.vectordb.chroma.query($host, '', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + + + +.Delete vectors (it leverages https://docs.trychroma.com/usage-guide#deleting-data-from-a-collection[this API]) +[source,cypher] +---- +CALL apoc.vectordb.chroma.delete($host, '', [1,2], {}) +---- + diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/custom.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/custom.adoc new file mode 100644 index 0000000000..730bad15e8 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/custom.adoc @@ -0,0 +1,99 @@ + +== Custom (i.e. other vector databases) + +We can also interface with other db vectors that do not (yet) have dedicated procedures. 
+For example, with https://docs.pinecone.io/guides/getting-started/overview[Pinecone], as we will see later. + +Here is a list of all available custom procedures: + +[opts=header, cols="1, 3"] +|=== +| name | description +| apoc.vectordb.custom.get(host, $embeddingConfig) | Customizable get / query procedure, +returning a result like the other `apoc.vectordb.*.get` ones +| apoc.vectordb.custom(host, $config) | Fully customizable procedure, returns generic object results. +|=== + + +=== Examples + + +The `apoc.vectordb.custom.get` can be used with every API that returns something like this +(note that the call does not need to return all keys): + +``` +[ + "": "value", + "": scoreValue, + "": [ ... ] + "": { .. }, + "": "..." +], +[ + ... +] +``` + +where we can customize idKey, scoreKey, vectorKey, metadataKey and textKey via the homonymous config parameters. + + +Let's look at some examples using https://docs.pinecone.io/guides/getting-started/overview[Pinecone]. + + +.apoc.vectordb.custom.get example +[source,cypher] +---- +CALL apoc.vectordb.custom.get('https://.svc.gcp-starter.pinecone.io/query', { + body: { + "namespace", namespace, + "vector", vector, + "topK", 3, + "includeValues", true, + "includeMetadata", true + }, + headers: {"Api-Key", apiKey}, + method: null, + jsonPath: "matches", + // the RestAPI return values as the key with values the vectors + vectorKey: 'values' +}), text +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text +| 1, | {a: 1} | 1 | [1,2,3,4] +| 0.1 | {a: 2} | 2 | [1,2,3,4] +| ... 
+|=== + + + +.apoc.vectordb.custom example +[source,cypher] +---- +CALL apoc.vectordb.custom('https://.svc.gcp-starter.pinecone.io/query', { + body: { + "namespace", namespace, + "vector", vector, + "topK", 3, + "includeValues", true, + "includeMetadata", true + }, + headers: {"Api-Key", apiKey}, + method: null, + jsonPath: "matches" +}) +---- + + +.Example results +[opts="header"] +|=== +| value +| {score: , metadata: , id: , vector: } +| {score: , metadata: , id: , vector: } +| ... +|=== diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/index.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/index.adoc new file mode 100644 index 0000000000..7f818d2bff --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/index.adoc @@ -0,0 +1,108 @@ +[[vectordb]] += Vector Databases +:description: This section describes procedures that can be used to interact with Vector Databases. + +APOC provides this set of procedures, which leverage the REST APIs, to interact with Vector Databases: + +- `apoc.vectordb.qdrant.*` (to interact with https://qdrant.tech/documentation/overview/[Qdrant]) +- `apoc.vectordb.chroma.*` (to interact with https://docs.trychroma.com/getting-started[Chroma]) +- `apoc.vectordb.weaviate.*` (to interact with https://weaviate.io/developers/weaviate[Weaviate]) +- `apoc.vectordb.custom.*` (to interact with other vector databases). +- `apoc.vectordb.configure` (to store host, credentials and mapping into the system database) + +All the procedures, except the `apoc.vectordb.configure` one, can have, as a final parameter, +a configuration map with these optional parameters: + +.config parameters + +|=== +| key | description +| headers | additional HTTP headers +| method | HTTP method +| endpoint | endpoint key, + can be used to override the default endpoint created via the 1st parameter of the procedures, + to handle potential endpoint changes. 
+| body | body HTTP request +| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] parsing of the response. The default is `null`. +|=== + + +Besides the above config, the `apoc.vectordb..get` and the `apoc.vectordb..query` procedures can have these additional parameters: + +.embeddingConfig parameters + +|=== +| key | description +| mapping | to fetch the associated entities and optionally create them. See examples below. +| allResults | if true, returns the vector, metadata and text (if present), otherwise returns null values for those columns. +| vectorKey, metadataKey, scoreKey, textKey | used with the `apoc.vectordb.custom.get` procedure. + To let the procedure know which key in the restAPI (if present) corresponds to the one that should be populated as respectively the vector/metadata/score/text result. + Defaults are "vector", "metadata", "score", "text". + See examples below. +|=== + + +== Ad-hoc procedures + +See the following pages for more details on specific vector db procedures + +- xref:./qdrant.adoc[Qdrant] +- xref:./chroma.adoc[ChromaDB] +- xref:./weaviate.adoc[Weaviate] + + +== Store Vector db info (i.e. `apoc.vectordb.configure`) + +We can save some info in the System Database to be reused later, that is the host, login credentials, and mapping, +to be used in `*.get` and `*.query` procedures, except for the `apoc.vectordb.custom.get` one. + +Therefore, to store the vector info, we can execute the `CALL apoc.vectordb.configure(vectorName, keyConfig, databaseName, $configMap)`, +where `vectorName` can be "QDRANT", "CHROMA" or "WEAVIATE", +that indicates info to be reused respectively by `apoc.vectordb.qdrant.*`, `apoc.vectordb.chroma.*` and `apoc.vectordb.weaviate.*`. 
+ +Then `keyConfig` is the configuration name, `databaseName` is the database where the config will be set, + +and finally the `configMap`, that can have: + +- `host` is the host base name +- `credentialsValue` is the API key +- `mapping` is a map that can be used by the `apoc.vectordb.\*.getAndUpdate` and `apoc.vectordb.*.queryAndUpdate` procedures + +NOTE:: this procedure is only executable by a user with admin permissions and against the system database + +For example: +[source,cypher] +---- +// -- within the system database or using the Cypher clause `USE SYSTEM ..` as a prefix +CALL apoc.vectordb.configure('QDRANT', 'qdrant-config-test', 'neo4j', + { + mapping: { embeddingKey: "vect", nodeLabel: "Test", entityKey: "myId", metadataKey: "foo" }, + host: 'custom-host-name', + credentials: '' +} +) +---- + +and then we can execute e.g. the following procedure (within the `neo4j` database): + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query('qdrant-config-test', 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5) +---- + +instead of: + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, +{ mapping: { + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + }, + headers: {Authorization: 'Bearer '}, + endpoint: 'custom-host-name' +}) +---- + diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc new file mode 100644 index 0000000000..073fa9d146 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/qdrant.adoc @@ -0,0 +1,209 @@ + +== Qdrant + +Here is a list of all available Qdrant procedures, +note that the list and the signature procedures are consistent with the others, like the ChromaDB ones: + +[opts=header, cols="1, 3"] +|=== +| name | description +| apoc.vectordb.qdrant.createCollection(hostOrKey, collection, similarity, size, 
$config) | + Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`. + The default endpoint is `/collections/`. +| apoc.vectordb.qdrant.deleteCollection(hostOrKey, collection, $config) | + Deletes a collection with the name specified in the 2nd parameter. + The default endpoint is `/collections/`. +| apoc.vectordb.qdrant.upsert(hostOrKey, collection, vectors, $config) | + Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', metadata: ''}]. + The default endpoint is `/collections//points`. +| apoc.vectordb.qdrant.delete(hostOrKey, collection, ids, $config) | + Deletes the vectors with the specified `ids`. + The default endpoint is `/collections//points/delete`. +| apoc.vectordb.qdrant.get(hostOrKey, collection, ids, $config) | + Gets the vectors with the specified `ids`. + The default endpoint is `/collections//points`. +| apoc.vectordb.qdrant.getAndUpdate(hostOrKey, collection, ids, $config) | + Gets the vectors with the specified `ids`, and optionally creates/updates neo4j entities. + The default endpoint is `/collections//points`. +| apoc.vectordb.qdrant.query(hostOrKey, collection, vector, filter, limit, $config) | + Retrieve closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter. + The default endpoint is `/collections//points/search`. +| apoc.vectordb.qdrant.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $config) | + Retrieve closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter, and optionally creates/updates neo4j entities. + The default endpoint is `/collections//points/search`. +|=== + +where the 1st parameter can be a key defined by the apoc config `apoc.qdrant..host=myHost`. +With hostOrKey=null, the default is 'http://localhost:6333'. 
+ + +=== Examples + +.Create a collection (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection[this API]) +[source,cypher] +---- +CALL apoc.vectordb.qdrant.createCollection($hostOrKey, 'test_collection', 'Cosine', 4, {}) +---- + + +.Delete a collection (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/delete_collection[this API]) +[source,cypher] +---- +CALL apoc.vectordb.qdrant.deleteCollection($hostOrKey, 'test_collection', {}) +---- + + +.Upsert vectors (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/points/operation/upsert_points[this API]) +[source,cypher] +---- +CALL apoc.vectordb.qdrant.upsert($hostOrKey, 'test_collection', + [ + {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}}, + {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}} + ], + {}) +---- + + +.Get vectors (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/points/operation/get_points[this API]) +[source,cypher] +---- +CALL apoc.vectordb.qdrant.get($hostOrKey, 'test_collection', [1,2], {}) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| null | {city: "Berlin", foo: "one"} | null | null | null | null +| null | {city: "Berlin", foo: "two"} | null | null | null | null +| ... +|=== + +.Get vectors with `{allResults: true}` +[source,cypher] +---- +CALL apoc.vectordb.qdrant.get($hostOrKey, 'test_collection', [1,2], {allResults: true, }) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| null | {city: "Berlin", foo: "one"} | 1 | [...] | null | null +| null | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| ... 
+|=== + +.Query vectors (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/points/operation/search_points[this API]) +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($hostOrKey, + 'test_collection', + [0.2, 0.1, 0.9, 0.7], + { must: + [ { key: "city", match: { value: "London" } } ] + }, + 5, + {allResults: true, }) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | null | null +| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| ... +|=== + + +[[mapping]] + + +We can define a mapping, to fetch the associated nodes and relationships and optionally create them, by leveraging the vector metadata. + +For example, if we have created 2 vectors with the above upsert procedures, +we can populate some existing nodes (i.e. `(:Test {myId: 'one'})` and `(:Test {myId: 'two'})`): + + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($hostOrKey, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which populates the two nodes as: `(:Test {myId: 'one', city: 'Berlin', vect: [vector1]})` and `(:Test {myId: 'two', city: 'London', vect: [vector2]})`, +which will be returned in the `entity` column result. + + +Or else, we can create a node if not exists, via `create: true`: + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($hostOrKey, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + create: true, + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which creates 2 new nodes as above. + +Or, we can populate an existing relationship (i.e. 
`(:Start)-[:TEST {myId: 'one'}]->(:End)` and `(:Start)-[:TEST {myId: 'two'}]->(:End)`): + + +[source,cypher] +---- +CALL apoc.vectordb.qdrant.query($hostOrKey, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { mapping: { + embeddingKey: "vect", + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin', vect: [vector1]}]-()` +and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, +which will be returned in the `entity` column result. + + +[NOTE] +==== +To optimize performances, we can choose what to `YIELD` with the apoc.vectordb.qdrant.query and the `apoc.vectordb.qdrant.get` procedures. + +For example, by executing a `CALL apoc.vectordb.qdrant.query(...) YIELD metadata, score, id`, the RestAPI request will have an {"with_payload": false, "with_vectors": false}, +so that we do not return the other values that we do not need. +==== + + + +.Delete vectors (it leverages https://qdrant.github.io/qdrant/redoc/index.html#tag/points/operation/delete_vectors[this API]) +[source,cypher] +---- +CALL apoc.vectordb.qdrant.delete($hostOrKey, 'test_collection', [1,2], {}) +---- diff --git a/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc new file mode 100644 index 0000000000..be672294aa --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/database-integration/vectordb/weaviate.adoc @@ -0,0 +1,222 @@ + +== Weaviate + +Here is a list of all available Weaviate procedures, +note that the list and the signature procedures are consistent with the others, like the Qdrant ones: + +[opts=header, cols="1, 3"] +|=== +| name | description +| apoc.vectordb.weaviate.createCollection(hostOrKey, collection, similarity, size, $config) | + Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`. 
+ The default endpoint is `/schema`. +| apoc.vectordb.weaviate.deleteCollection(hostOrKey, collection, $config) | + Deletes a collection with the name specified in the 2nd parameter. + The default endpoint is `/schema/`. +| apoc.vectordb.weaviate.upsert(hostOrKey, collection, vectors, $config) | + Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', metadata: ''}]. + The default endpoint is `/objects`. +| apoc.vectordb.weaviate.delete(hostOrKey, collection, ids, $config) | + Deletes the vectors with the specified `ids`. + The default endpoint is `/schema`. +| apoc.vectordb.weaviate.get(hostOrKey, collection, ids, $config) | + Gets the vectors with the specified `ids`. + The default endpoint is `/schema`. +| apoc.vectordb.weaviate.query(hostOrKey, collection, vector, filter, limit, $config) | + Retrieve closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter. + Note that, besides the common config parameters, this procedure requires a `fields: [listOfProperty]` config, to define which properties are to be retrieved from GraphQL running under-the-hood. + The default endpoint is `/graphql`. +| apoc.vectordb.weaviate.getAndUpdate(hostOrKey, collection, ids, $config) | + Gets the vectors with the specified `ids`, and optionally creates/updates neo4j entities. + The default endpoint is `/schema`. +| apoc.vectordb.weaviate.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $config) | + Retrieve closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter, and optionally creates/updates neo4j entities. + Note that, besides the common config parameters, this procedure requires a `fields: [listOfProperty]` config, to define which properties are to be retrieved from GraphQL running under-the-hood. + The default endpoint is `/graphql`. 
+|=== + +where the 1st parameter can be a key defined by the apoc config `apoc.weaviate..host=myHost`. +With hostOrKey=null, the default is 'http://localhost:8080/v1'. + +=== Examples + +.Create a collection (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/schema/post/schema[this API]) +[source,cypher] +---- +CALL apoc.vectordb.weaviate.createCollection($host, 'test_collection', 'Cosine', 4, {}) +---- + +.Create a collection against a remote connection using an API key (see https://weaviate.io/developers/weaviate/configuration/authentication[here]) +[source,cypher] +---- +CALL apoc.vectordb.weaviate.createCollection("https://.weaviate.network", + 'TestCollection', + 'cosine', + 4, + {headers: {Authorization: 'Bearer '}}) +---- + + + +.Delete a collection (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/schema/delete/schema/{className}[this API]) +[source,cypher] +---- +CALL apoc.vectordb.weaviate.deleteCollection($host, 'test_collection', {}) +---- + + +.Upsert vectors (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/objects/post/objects[this API]) +[source,cypher] +---- +CALL apoc.vectordb.weaviate.upsert($host, 'test_collection', + [ + {id: "8ef2b3a7-1e56-4ddd-b8c3-2ca8901ce308", vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}}, + {id: "9ef2b3a7-1e56-4ddd-b8c3-2ca8901ce308", vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}} + ], + {}) +---- + + +.Get vectors (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/objects/get/objects/\{className\}/\{id\}[this API]) +[source,cypher] +---- +CALL apoc.vectordb.weaviate.get($host, 'test_collection', [1,2], {}) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| null | {city: "Berlin", foo: "one"} | null | null | null | null +| null | {city: "Berlin", foo: "two"} | null | null | null | null +| ... 
+|=== + + +.Get vectors with `{allResults: true}` +[source,cypher] +---- +CALL apoc.vectordb.weaviate.get($host, 'test_collection', [1,2], {allResults: true, }) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text | entity +| null | {city: "Berlin", foo: "one"} | 1 | [...] | null | null +| null | {city: "Berlin", foo: "two"} | 2 | [...] | null | null +| ... +|=== + + +.Query vectors (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/graphql/post/graphql[here]) +[source,cypher] +---- +CALL apoc.vectordb.weaviate.query($host, + 'test_collection', + [0.2, 0.1, 0.9, 0.7], + '{operator: Equal, valueString: "London", path: ["city"]}', + 5, + {fields: ["city", "foo"], allResults: true, }) +---- + + +.Example results +[opts="header"] +|=== +| score | metadata | id | vector | text +| 1, | {city: "Berlin", foo: "one"} | 1 | [...] | null +| 0.1 | {city: "Berlin", foo: "two"} | 2 | [...] | null +| ... +|=== + + +We can define a mapping, to fetch the associated nodes and relationships and optionally create them, by leveraging the vector metadata. + +For example, if we have created 2 vectors with the above upsert procedures, +we can populate some existing nodes (i.e. `(:Test {myId: 'one'})` and `(:Test {myId: 'two'})`): + + +[source,cypher] +---- +CALL apoc.vectordb.weaviate.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { fields: ["city", "foo"], + mapping: { + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which populates the two nodes as: `(:Test {myId: 'one', city: 'Berlin', vect: [vector1]})` +and `(:Test {myId: 'two', city: 'London', vect: [vector2]})`, +which will be returned in the `entity` column result. 
+ + +Or else, we can create a node if it does not exist, via `create: true`: + +[source,cypher] +---- +CALL apoc.vectordb.weaviate.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { fields: ["city", "foo"], + mapping: { + create: true, + embeddingKey: "vect", + nodeLabel: "Test", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which creates 2 new nodes as above. + +Or, we can populate an existing relationship (i.e. `(:Start)-[:TEST {myId: 'one'}]->(:End)` and `(:Start)-[:TEST {myId: 'two'}]->(:End)`): + + +[source,cypher] +---- +CALL apoc.vectordb.weaviate.query($host, 'test_collection', + [0.2, 0.1, 0.9, 0.7], + {}, + 5, + { fields: ["city", "foo"], + mapping: { + embeddingKey: "vect", + relType: "TEST", + entityKey: "myId", + metadataKey: "foo" + } + }) +---- + +which populates the two relationships as: `()-[:TEST {myId: 'one', city: 'Berlin', vect: [vector1]}]-()` +and `()-[:TEST {myId: 'two', city: 'London', vect: [vector2]}]-()`, +which will be returned in the `entity` column result. + + +[NOTE] +==== +To optimize performance, we can choose what to `YIELD` with the `apoc.vectordb.weaviate.query` and the `apoc.vectordb.weaviate.get` procedures. + +For example, by executing a `CALL apoc.vectordb.weaviate.query(...) YIELD metadata, score, id`, the REST API request will include `{"with_payload": false, "with_vectors": false}`, +so that we do not return the other values that we do not need. 
+==== + + + +.Delete vectors (it leverages https://weaviate.io/developers/weaviate/api/rest#tag/objects/delete/objects/\{className\}/\{id\}[this API]) +[source,cypher] +---- +CALL apoc.vectordb.weaviate.delete($host, 'test_collection', [1,2], {}) +---- diff --git a/extended/src/main/java/apoc/agg/AggregationExtended.java b/extended/src/main/java/apoc/agg/AggregationExtended.java deleted file mode 100644 index 128131c5fe..0000000000 --- a/extended/src/main/java/apoc/agg/AggregationExtended.java +++ /dev/null @@ -1,68 +0,0 @@ -package apoc.agg; - -import apoc.Extended; -import apoc.util.collection.Iterables; -import apoc.util.collection.Iterators; -import org.neo4j.graphdb.GraphDatabaseService; -import org.neo4j.graphdb.Transaction; -import org.neo4j.procedure.Context; -import org.neo4j.procedure.Description; -import org.neo4j.procedure.Name; -import org.neo4j.procedure.UserAggregationFunction; -import org.neo4j.procedure.UserAggregationResult; -import org.neo4j.procedure.UserAggregationUpdate; - -import java.util.Map; -import java.util.function.BiPredicate; - -@Extended -public class AggregationExtended { - - @Context - public GraphDatabaseService db; - - @Context - public Transaction tx; - - @UserAggregationFunction("apoc.agg.row") - @Description("apoc.agg.row(element, predicate) - Returns index of the `element` that match the given `predicate`") - public RowFunction row() { - BiPredicate curr = (current, value) -> db.executeTransactionally("RETURN " + value, - Map.of("curr", current), - result -> Iterators.singleOrNull(result.columnAs(Iterables.single(result.columns())))); - return new RowFunction(curr); - } - - @UserAggregationFunction("apoc.agg.position") - @Description("apoc.agg.position(element, value) - Returns index of the `element` that match the given `value`") - public RowFunction position() { - return new RowFunction(Object::equals); - } - - public static class RowFunction { - private boolean found; - private final BiPredicate biPredicate; - private long 
index = -1L; - - public RowFunction(BiPredicate biPredicate) { - this.biPredicate = biPredicate; - } - - @UserAggregationUpdate - public void update(@Name("value") Object value, @Name("element") Object element) { - if (!found) { - try { - found = this.biPredicate.test(value, element); - } catch (Exception e) { - throw new RuntimeException("The predicate query has thrown the following exception: \n" + e.getMessage()); - } - index++; - } - } - - @UserAggregationResult - public Object result() { - return index; - } - } -} diff --git a/extended/src/main/java/apoc/systemdb/metadata/ExportFunction.java b/extended/src/main/java/apoc/systemdb/metadata/ExportFunction.java deleted file mode 100644 index 167c5687ac..0000000000 --- a/extended/src/main/java/apoc/systemdb/metadata/ExportFunction.java +++ /dev/null @@ -1,41 +0,0 @@ -package apoc.systemdb.metadata; - -import apoc.ExtendedSystemPropertyKeys; -import apoc.SystemPropertyKeys; -import apoc.custom.CypherProceduresUtil; -import apoc.export.util.ProgressReporter; -import org.apache.commons.lang3.tuple.Pair; -import org.neo4j.graphdb.Node; -import org.neo4j.internal.kernel.api.procs.FieldSignature; - -import java.util.List; -import java.util.stream.Collectors; - - -public class ExportFunction implements ExportMetadata { - - @Override - public List> export(Node node, ProgressReporter progressReporter) { - final String inputs = getSignature(node, ExtendedSystemPropertyKeys.inputs.name()); - - final String outputName = ExtendedSystemPropertyKeys.output.name(); - final String outputs = node.hasProperty(outputName) - ? 
(String) node.getProperty(outputName) - : getSignature(node, ExtendedSystemPropertyKeys.outputs.name()); - - String statement = String.format("CALL apoc.custom.declareFunction('%s(%s) :: %s', '%s', %s, '%s');", - node.getProperty(SystemPropertyKeys.name.name()), inputs, outputs, - node.getProperty(SystemPropertyKeys.statement.name()), - node.getProperty(ExtendedSystemPropertyKeys.forceSingle.name()), - node.getProperty(ExtendedSystemPropertyKeys.description.name())); - progressReporter.nextRow(); - return List.of(Pair.of(getFileName(node, Type.CypherFunction.name()), statement)); - } - - - static String getSignature(Node node, String name) { - return CypherProceduresUtil.deserializeSignatures((String) node.getProperty(name)) - .stream().map(FieldSignature::toString) - .collect(Collectors.joining(", ")); - } -} \ No newline at end of file diff --git a/extended/src/main/java/apoc/systemdb/metadata/ExportProcedure.java b/extended/src/main/java/apoc/systemdb/metadata/ExportProcedure.java deleted file mode 100644 index cb9e87d13f..0000000000 --- a/extended/src/main/java/apoc/systemdb/metadata/ExportProcedure.java +++ /dev/null @@ -1,40 +0,0 @@ -package apoc.systemdb.metadata; - -import apoc.ExtendedSystemPropertyKeys; -import apoc.SystemPropertyKeys; -import apoc.custom.CypherProceduresUtil; -import apoc.export.util.ProgressReporter; -import org.apache.commons.lang3.tuple.Pair; -import org.neo4j.graphdb.Node; -import org.neo4j.internal.kernel.api.procs.FieldSignature; - -import java.util.List; -import java.util.stream.Collectors; - -public class ExportProcedure implements ExportMetadata { - - @Override - public List> export(Node node, ProgressReporter progressReporter) { - final String inputs = getSignature(node, ExtendedSystemPropertyKeys.inputs.name()); - - final String outputName = ExtendedSystemPropertyKeys.output.name(); - final String outputs = node.hasProperty(outputName) - ? 
(String) node.getProperty(outputName) - : getSignature(node, ExtendedSystemPropertyKeys.outputs.name()); - - String statement = String.format("CALL apoc.custom.declareProcedure('%s(%s) :: (%s)', '%s', '%s', '%s');", - node.getProperty(SystemPropertyKeys.name.name()), inputs, outputs, - node.getProperty(SystemPropertyKeys.statement.name()), - node.getProperty(ExtendedSystemPropertyKeys.mode.name()), - node.getProperty(ExtendedSystemPropertyKeys.description.name())); - progressReporter.nextRow(); - return List.of(Pair.of(getFileName(node, Type.CypherProcedure.name()), statement)); - } - - - static String getSignature(Node node, String name) { - return CypherProceduresUtil.deserializeSignatures((String) node.getProperty(name)) - .stream().map(FieldSignature::toString) - .collect(Collectors.joining(", ")); - } -} \ No newline at end of file diff --git a/extended/src/main/java/apoc/util/UtilsExtended.java b/extended/src/main/java/apoc/util/UtilsExtended.java deleted file mode 100644 index fe9d4960f1..0000000000 --- a/extended/src/main/java/apoc/util/UtilsExtended.java +++ /dev/null @@ -1,15 +0,0 @@ -package apoc.util; - -import apoc.Extended; - -import org.neo4j.procedure.*; - -@Extended -public class UtilsExtended { - - @UserFunction("apoc.util.hashCode") - @Description("apoc.util.hashCode(value) - Returns the java.lang.Object#hashCode() of the value") - public long hashCode(@Name("value") Object value) { - return value.hashCode(); - } -} diff --git a/extended/src/test/java/apoc/agg/AggregationExtendedTest.java b/extended/src/test/java/apoc/agg/AggregationExtendedTest.java deleted file mode 100644 index 456d37b49d..0000000000 --- a/extended/src/test/java/apoc/agg/AggregationExtendedTest.java +++ /dev/null @@ -1,66 +0,0 @@ -package apoc.agg; - -import apoc.util.TestUtil; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.neo4j.test.rule.DbmsRule; -import org.neo4j.test.rule.ImpermanentDbmsRule; - 
-import static apoc.util.TestUtil.testCall; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -public class AggregationExtendedTest { - - @ClassRule - public static DbmsRule db = new ImpermanentDbmsRule(); - - @BeforeClass - public static void setUp() { - TestUtil.registerProcedure(db, AggregationExtended.class); - db.executeTransactionally("UNWIND range(0,20) AS id CREATE (:Person {id: 'index' + id})"); - - db.executeTransactionally("UNWIND [{date: datetime('1999'), other: 5}, " + - "{date: datetime('2000'), other: 10}, " + - "{date: datetime('2000'), other: 15} ] AS prop " + - " CREATE (n:Date) SET n = prop"); - } - - @AfterClass - public static void tearDown() { - db.shutdown(); - } - - @Test - public void testAggRow() { - testCall(db, "MATCH (n:Person) RETURN apoc.agg.row(n.id, '$curr = \"index10\"') AS row", - (row) -> assertEquals(10L, row.get("row"))); - } - - @Test - public void testAggRowWithComplexPredicate() { - testCall(db, "MATCH (n:Date) RETURN apoc.agg.row(n, '$curr.date <> datetime(\"1999\") AND $curr.other > 11 ') AS row", - (row) -> assertEquals(2L, row.get("row"))); - } - - @Test - public void testPredicateShouldReturnABoolean() { - try { - testCall(db, "MATCH (n:Person) RETURN apoc.agg.row(n.id, '1') AS row", - (row) -> fail()); - } catch (Exception e) { - assertTrue(e.getMessage().contains("The predicate query has thrown the following exception: \n" + - "class java.lang.Long cannot be cast to class java.lang.Boolean")); - } - } - - @Test - public void testPosition() { - testCall(db, "MATCH (n:Person) RETURN apoc.agg.position(n.id, 'index10') AS row", - (row) -> assertEquals(10L, row.get("row"))); - - } -} diff --git a/extended/src/test/java/apoc/map/MapsExtendedTest.java b/extended/src/test/java/apoc/map/MapsExtendedTest.java deleted file mode 100644 index 44863269a1..0000000000 --- a/extended/src/test/java/apoc/map/MapsExtendedTest.java +++ /dev/null @@ -1,111 +0,0 
@@ -package apoc.map; - -import apoc.util.TestUtil; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.neo4j.test.rule.DbmsRule; -import org.neo4j.test.rule.ImpermanentDbmsRule; - -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; - -public class MapsExtendedTest { - - private static final String OLD_KEY = "depends_on"; - private static final String NEW_KEY = "new_name"; - - @Rule - public DbmsRule db = new ImpermanentDbmsRule(); - - @Before - public void setUp() { - TestUtil.registerProcedure(db, MapsExtended.class); - } - - @Test - public void testRenameKeyNotExistent() { - Map map = Map.of( - "testKey", List.of(Map.of("testKey", "some_value")), - "name", "Mario" - ); - TestUtil.testCall(db, "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo) as value", - Map.of("map", map, "keyFrom", "notExistent", "keyTo", "something"), - (r) -> assertEquals(map, r.get("value"))); - } - - @Test - public void testRenameKeyInnerKey() { - Map map = Map.of( - "testKey", List.of(Map.of("innerKey", "some_value")), - "name", "Mario" - ); - TestUtil.testCall(db, "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo) as value", - Map.of("map", map, "keyFrom", "innerKey", "keyTo", "otherKey"), - (r) -> { - Map expected = Map.of( - "testKey", List.of(Map.of("otherKey", "some_value")), - "name", "Mario" - );; - assertEquals(expected, r.get("value")); - }); - } - - @Test - public void testRenameKey() { - Map map = Map.of( - OLD_KEY, List.of(Map.of(OLD_KEY, "some_value")), - "name", "Mario" - );; - TestUtil.testCall(db, "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo) as value", - Map.of("map", map, "keyFrom", OLD_KEY, "keyTo", NEW_KEY), - (r) -> { - Map expected = Map.of(NEW_KEY, List.of(Map.of(NEW_KEY, "some_value")), - "name", "Mario" - ); - Map> actual = (Map) r.get("value"); - assertEquals(expected, actual); - }); - } - - @Test - public void testRenameKeyComplexMap() { - Map map = Map.of( - OLD_KEY, List.of(1L, 
"test", Map.of(OLD_KEY, "some_value")), - "otherKey", List.of(1L, List.of("test", Map.of(OLD_KEY, "some_value"))), - "name", Map.of(OLD_KEY, "some_value"), - "other", "key" - ); - TestUtil.testCall(db, "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo) as value", - Map.of("map", map, "keyFrom", OLD_KEY, "keyTo", NEW_KEY), - (r) -> { - Map expected = Map.of( - NEW_KEY, List.of(1L, "test", Map.of(NEW_KEY, "some_value")), - "otherKey", List.of(1L, List.of("test", Map.of(NEW_KEY, "some_value"))), - "name", Map.of(NEW_KEY, "some_value"), - "other", "key" - ); - Map> actual = (Map) r.get("value"); - assertEquals(expected, actual); - }); - } - - @Test - public void testRenameKeyRecursiveFalse() { - Map map = Map.of( - OLD_KEY, List.of(Map.of(OLD_KEY, "some_value")), - "name", "Mario" - );; - TestUtil.testCall(db, "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo, {recursive: false}) as value", - Map.of("map", map, "keyFrom", OLD_KEY, "keyTo", NEW_KEY), - (r) -> { - Map> actual = (Map) r.get("value"); - Map expected = Map.of(NEW_KEY, List.of(Map.of(OLD_KEY, "some_value")), - "name", "Mario" - ); - assertEquals(expected, actual); - }); - } -} diff --git a/extended/src/test/java/apoc/util/ExtendedTestContainerUtil.java b/extended/src/test/java/apoc/util/ExtendedTestContainerUtil.java deleted file mode 100644 index a0cd36b423..0000000000 --- a/extended/src/test/java/apoc/util/ExtendedTestContainerUtil.java +++ /dev/null @@ -1,86 +0,0 @@ -package apoc.util; - -import java.io.File; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; -import java.util.function.Consumer; - -import org.apache.commons.io.filefilter.IOFileFilter; -import org.apache.commons.io.filefilter.WildcardFileFilter; -import org.neo4j.driver.AuthToken; -import org.neo4j.driver.AuthTokens; -import org.neo4j.driver.Driver; -import org.neo4j.driver.GraphDatabase; -import org.neo4j.driver.Session; -import org.neo4j.driver.SessionConfig; - -import static 
apoc.util.TestContainerUtil.copyFilesToPlugin; -import static apoc.util.TestContainerUtil.executeGradleTasks; -import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; - -public class ExtendedTestContainerUtil -{ - public static TestcontainersCausalCluster createEnterpriseCluster( List apocPackages, int numOfCoreInstances, int numberOfReadReplica, Map neo4jConfig, Map envSettings) { - return TestcontainersCausalCluster.create(apocPackages, numOfCoreInstances, numberOfReadReplica, Duration.ofMinutes(4), neo4jConfig, envSettings); - } - - public static T singleResultFirstColumn(Session session, String cypher) { - return (T) session.executeWrite(tx -> tx.run(cypher).single().fields().get(0).value().asObject()); - } - - public static void testCallInReadTransaction(Session session, String call, Consumer> consumer) { - TestContainerUtil.testCallInReadTransaction(session, call, null, consumer); - } - - public static void addExtraDependencies() { - File extraDepsDir = new File(TestContainerUtil.baseDir, "extra-dependencies"); - // build the extra-dependencies - executeGradleTasks(extraDepsDir, "buildDependencies"); - - // add all extra deps to the plugin docker folder - final File directory = new File(extraDepsDir, "build/allJars"); - final IOFileFilter instance = new WildcardFileFilter("*.jar"); - copyFilesToPlugin(directory, instance, TestContainerUtil.pluginsFolder); - } - - /** - * Open a `neo4j://` routing session for each cluster member against system db - */ - public static void routingSessionForEachMembers(List members, - BiConsumer sessionConsumer) { - - for (Neo4jContainerExtension container: members) { - // Bolt (routing) url - String neo4jUrl = "neo4j://localhost:" + container.getMappedPort(7687); - - AuthToken authToken = AuthTokens.basic("neo4j", container.getAdminPassword()); - try (Driver driver = GraphDatabase.driver(neo4jUrl, authToken); - Session session = driver.session(SessionConfig.forDatabase(SYSTEM_DATABASE_NAME))) { - 
sessionConsumer.accept(session, container); - } - } - } - - public static Driver getDriverIfNotReplica(Neo4jContainerExtension container) { - final String readReplica = TestcontainersCausalCluster.ClusterInstanceType.READ_REPLICA.toString(); - final Driver driver = container.getDriver(); - if (readReplica.equals(container.getEnvMap().get("NEO4J_dbms_mode")) || driver == null) { - return null; - } - return driver; - } - - public static String getBoltAddress(Neo4jContainerExtension instance) { - return instance.getEnvMap().get("NEO4J_dbms_connector_bolt_advertised__address"); - } - - public static boolean dbIsWriter(String dbName, Session session, String boltAddress) { - return session.run( "SHOW DATABASE $dbName WHERE address = $boltAddress", - Map.of("dbName", dbName, "boltAddress", boltAddress) ) - .single().get("writer") - .asBoolean(); - } - -} diff --git a/full-it/src/test/java/apoc/full/it/UUIDClusterRoutingTest.java b/full-it/src/test/java/apoc/full/it/UUIDClusterRoutingTest.java index 740a1cdc5d..7754508ee7 100644 --- a/full-it/src/test/java/apoc/full/it/UUIDClusterRoutingTest.java +++ b/full-it/src/test/java/apoc/full/it/UUIDClusterRoutingTest.java @@ -48,6 +48,7 @@ import org.junit.Ignore; import org.junit.Test; import org.neo4j.driver.*; +import org.neo4j.driver.Record; import org.neo4j.internal.helpers.collection.Iterators; @Ignore diff --git a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java new file mode 100644 index 0000000000..b9b8cbaccf --- /dev/null +++ b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java @@ -0,0 +1,455 @@ +package apoc.full.it.vectordb; + +import static apoc.util.MapUtil.map; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.vectordb.VectorDbHandler.Type.CHROMA; +import static apoc.vectordb.VectorDbTestUtil.EntityType.*; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; 
+import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; +import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static apoc.vectordb.VectorMappingConfig.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; +import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; + +import apoc.util.TestUtil; +import apoc.vectordb.ChromaDb; +import apoc.vectordb.VectorDb; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.test.TestDatabaseManagementServiceBuilder; +import org.testcontainers.chromadb.ChromaDBContainer; + +public class ChromaDbTest { + private static final AtomicReference COLL_ID = new AtomicReference<>(); + private static final ChromaDBContainer CHROMA_CONTAINER = new ChromaDBContainer("chromadb/chroma:0.4.25.dev137"); + + private static String HOST; + + @ClassRule + public static TemporaryFolder storeDir = new TemporaryFolder(); + + private static GraphDatabaseService sysDb; + private static GraphDatabaseService db; + private static DatabaseManagementService databaseManagementService; + + @BeforeClass + 
public static void setUp() throws Exception { + databaseManagementService = + new TestDatabaseManagementServiceBuilder(storeDir.getRoot().toPath()).build(); + db = databaseManagementService.database(DEFAULT_DATABASE_NAME); + sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); + + CHROMA_CONTAINER.start(); + + HOST = "localhost:" + CHROMA_CONTAINER.getMappedPort(8000); + TestUtil.registerProcedure(db, ChromaDb.class, VectorDb.class); + + testCall( + db, + "CALL apoc.vectordb.chroma.createCollection($host, 'test_collection', 'cosine', 4)", + map("host", HOST), + r -> { + Map value = (Map) r.get("value"); + COLL_ID.set((String) value.get("id")); + }); + + testCall( + db, + "CALL apoc.vectordb.chroma.upsert($host, $collection,\n" + " [\n" + + " {id: '1', vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}, text: 'ajeje'},\n" + + " {id: '2', vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}, text: 'brazorf'}\n" + + " ])", + map("host", HOST, "collection", COLL_ID.get()), + r -> { + assertNull(r.get("value")); + }); + } + + @AfterClass + public static void tearDown() throws Exception { + testCall(db, "CALL apoc.vectordb.chroma.deleteCollection($host, 'test_collection')", map("host", HOST), r -> { + Map value = (Map) r.get("value"); + assertNull(value); + }); + + databaseManagementService.shutdown(); + CHROMA_CONTAINER.stop(); + } + + @Before + public void before() { + dropAndDeleteAll(db); + } + + @Test + public void getVectors() { + testResult( + db, + "CALL apoc.vectordb.chroma.get($host, $collection, ['1'], $conf) ", + map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, FALSE); + assertNotNull(row.get("vector")); + assertEquals("ajeje", row.get("text")); + }); + } + + @Test + public void getVectorsWithoutVectorResult() { + testResult( + db, + "CALL apoc.vectordb.chroma.get($host, $collection, ['1'])", + map("host", 
HOST, "collection", COLL_ID.get()), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void deleteVector() { + testCall( + db, + "CALL apoc.vectordb.chroma.upsert($host, $collection,\n" + "[\n" + + " {id: 3, embedding: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" + + "])", + map("host", HOST, "collection", COLL_ID.get()), + r -> { + assertNull(r.get("value")); + }); + + testCall( + db, + "CALL apoc.vectordb.chroma.delete($host, $collection, [3]) ", + map("host", HOST, "collection", COLL_ID.get()), + r -> { + assertEquals(List.of("3"), r.get("value")); + }); + } + + @Test + public void createAndDeleteVector() { + testResult( + db, + "CALL apoc.vectordb.chroma.get($host, $collection, ['1'], $conf) ", + map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, FALSE); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void queryVectors() { + testResult( + db, + "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void queryVectorsWithoutVectorResult() { + testResult( + db, + "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5) " + + " YIELD score, vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "collection", COLL_ID.get()), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + 
assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + + row = r.next(); + assertEquals(Map.of("city", "London", "foo", "two"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void queryVectorsWithYield() { + testResult( + db, + "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf) YIELD metadata, id", + map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), + r -> { + assertBerlinResult(r.next(), FALSE); + assertLondonResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithFilter() { + testResult( + db, + "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {city: 'London'}, 5, $conf) YIELD metadata, id", + map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), + r -> { + assertLondonResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithLimit() { + testResult( + db, + "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id", + map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), + r -> { + assertBerlinResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithCreateNode() { + Map conf = map( + ALL_RESULTS_KEY, + true, + MAPPING_KEY, + map( + EMBEDDING_KEY, + "vect", + NODE_LABEL, + "Test", + ENTITY_KEY, + "myId", + METADATA_KEY, + "foo", + CREATE_KEY, + true)); + + testResult( + db, + "CALL apoc.vectordb.chroma.queryAndUpdate($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + 
assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + + testResult( + db, + "CALL apoc.vectordb.chroma.queryAndUpdate($host, $collection, [0.22, 0.11, 0.99, 0.17], {}, 5, $conf) " + + " YIELD score, vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + MAPPING_KEY, + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getReadOnlyVectorsWithMapping() { + Map conf = map(ALL_RESULTS_KEY, true, MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + + try { + testCall( + db, + "CALL apoc.vectordb.chroma.get($host, $collection, [1, 2], $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> fail()); + } catch (RuntimeException e) { + Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); + } + } + + @Test + public void queryVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + MAPPING_KEY, + 
map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + testResult( + db, + "CALL apoc.vectordb.chroma.queryAndUpdate($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void queryReadOnlyVectorsWithMapping() { + Map conf = map(ALL_RESULTS_KEY, true, MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + + try { + testCall( + db, + "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> fail()); + } catch (RuntimeException e) { + Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); + } + } + + @Test + public void queryVectorsWithCreateRel() { + + db.executeTransactionally( + "CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + MAPPING_KEY, + map(EMBEDDING_KEY, "vect", REL_TYPE, "TEST", ENTITY_KEY, "myId", METADATA_KEY, "foo", "create", true)); + testResult( + db, + "CALL apoc.vectordb.chroma.queryAndUpdate($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "collection", COLL_ID.get(), "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertRelsCreated(db); + } + + @Test + public void queryVectorsWithSystemDbStorage() { + String keyConfig = "chroma-config-foo"; + String baseUrl = "http://" + HOST; + 
Map mapping = + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo"); + sysDb.executeTransactionally( + "CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)", + map( + "vectorName", + CHROMA.toString(), + "keyConfig", + keyConfig, + "databaseName", + DEFAULT_DATABASE_NAME, + "conf", + map( + "host", baseUrl, + "credentials", null, + "mapping", mapping))); + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + testResult( + db, + "CALL apoc.vectordb.chroma.queryAndUpdate($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", keyConfig, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } +} diff --git a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java new file mode 100644 index 0000000000..1854ab39e4 --- /dev/null +++ b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java @@ -0,0 +1,491 @@ +package apoc.full.it.vectordb; + +import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.util.MapUtil.map; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.vectordb.VectorDbHandler.Type.QDRANT; +import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE; +import static apoc.vectordb.VectorDbTestUtil.EntityType.REL; +import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult; +import static apoc.vectordb.VectorDbTestUtil.assertLondonResult; +import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated; +import static 
apoc.vectordb.VectorDbTestUtil.assertRelsCreated; +import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll; +import static apoc.vectordb.VectorDbTestUtil.getAuthHeader; +import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; +import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static apoc.vectordb.VectorMappingConfig.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.neo4j.configuration.GraphDatabaseSettings.*; + +import apoc.util.TestUtil; +import apoc.util.Util; +import apoc.vectordb.Qdrant; +import apoc.vectordb.VectorDb; +import apoc.vectordb.VectorDbTestUtil; +import java.util.Map; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.test.TestDatabaseManagementServiceBuilder; +import org.testcontainers.qdrant.QdrantContainer; + +public class QdrantTest { + private static final String ADMIN_KEY = "my_admin_api_key"; + private static final String READONLY_KEY = "my_readonly_api_key"; + + private static final QdrantContainer QDRANT_CONTAINER = new QdrantContainer("qdrant/qdrant:v1.7.4") + .withEnv("QDRANT__SERVICE__API_KEY", ADMIN_KEY) + .withEnv("QDRANT__SERVICE__READ_ONLY_API_KEY", READONLY_KEY); + + private static final Map ADMIN_AUTHORIZATION = getAuthHeader(ADMIN_KEY); + private static final Map READONLY_AUTHORIZATION = getAuthHeader(READONLY_KEY); + private static final Map ADMIN_HEADER_CONF = map(HEADERS_KEY, ADMIN_AUTHORIZATION); + + private static String HOST; + + 
@ClassRule + public static TemporaryFolder storeDir = new TemporaryFolder(); + + private static GraphDatabaseService sysDb; + private static GraphDatabaseService db; + private static DatabaseManagementService databaseManagementService; + + @BeforeClass + public static void setUp() throws Exception { + databaseManagementService = + new TestDatabaseManagementServiceBuilder(storeDir.getRoot().toPath()).build(); + db = databaseManagementService.database(DEFAULT_DATABASE_NAME); + sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); + + QDRANT_CONTAINER.start(); + + HOST = "localhost:" + QDRANT_CONTAINER.getMappedPort(6333); + TestUtil.registerProcedure(db, Qdrant.class, VectorDb.class); + + testCall( + db, + "CALL apoc.vectordb.qdrant.createCollection($host, 'test_collection', 'Cosine', 4, $conf)", + map("host", HOST, "conf", ADMIN_HEADER_CONF), + r -> { + Map value = (Map) r.get("value"); + assertEquals("ok", value.get("status")); + }); + + testCall( + db, + "CALL apoc.vectordb.qdrant.upsert($host, 'test_collection',\n" + "[\n" + + " {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}},\n" + + " {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}}\n" + + "],\n" + + "$conf)", + map("host", HOST, "conf", ADMIN_HEADER_CONF), + r -> { + Map value = (Map) r.get("value"); + assertEquals("ok", value.get("status")); + }); + } + + @AfterClass + public static void tearDown() throws Exception { + testCall( + db, + "CALL apoc.vectordb.qdrant.deleteCollection($host, 'test_collection', $conf)", + map("host", HOST, "conf", ADMIN_HEADER_CONF), + r -> { + Map value = (Map) r.get("value"); + assertEquals(true, value.get("result")); + }); + + databaseManagementService.shutdown(); + QDRANT_CONTAINER.stop(); + } + + @Before + public void before() { + dropAndDeleteAll(db); + } + + @Test + public void getVectorsWithReadOnlyApiKey() { + testResult( + db, + "CALL apoc.vectordb.qdrant.get($host, 'test_collection', 
[1], $conf) ", + map("host", HOST, "conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, READONLY_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertBerlinResult(row, FALSE); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void writeOperationWithReadOnlyUser() { + try { + testCall( + db, + "CALL apoc.vectordb.qdrant.deleteCollection($host, 'test_collection', $conf)", + Util.map("host", HOST, "conf", Util.map(HEADERS_KEY, READONLY_AUTHORIZATION)), + r -> fail()); + } catch (Exception e) { + assertThat(e.getMessage()).contains("HTTP response code: 403"); + } + } + + @Test + public void getVectorsWithoutVectorResult() { + testResult( + db, + "CALL apoc.vectordb.qdrant.get($host, 'test_collection', [1], $conf) ", + map("host", HOST, "conf", ADMIN_HEADER_CONF), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void deleteVector() { + testCall( + db, + "CALL apoc.vectordb.qdrant.upsert($host, 'test_collection',\n" + "[\n" + + " {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}},\n" + + " {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" + + "],\n" + + "$conf)", + map("host", HOST, "conf", ADMIN_HEADER_CONF), + r -> { + Map value = (Map) r.get("value"); + assertEquals("ok", value.get("status")); + }); + + testCall( + db, + "CALL apoc.vectordb.qdrant.delete($host, 'test_collection', [3, 4], $conf) ", + map("host", HOST, "conf", ADMIN_HEADER_CONF), + r -> { + Map value = (Map) r.get("value"); + assertEquals("ok", value.get("status")); + }); + } + + @Test + public void queryVectors() { + testResult( + db, + "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertBerlinResult(row, FALSE); + 
assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void queryVectorsWithoutVectorResult() { + testResult( + db, + "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", map(HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + + row = r.next(); + assertEquals(Map.of("city", "London", "foo", "two"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void queryVectorsWithYield() { + testResult( + db, + "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf) YIELD metadata, id", + map("host", HOST, "conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertBerlinResult(r.next(), FALSE); + assertLondonResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithFilter() { + testResult( + db, + "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7],\n" + "{ must:\n" + + " [ { key: \"city\", match: { value: \"London\" } } ]\n" + + "},\n" + + "5, $conf) YIELD metadata, id", + map("host", HOST, "conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertLondonResult(r.next(), FALSE); + }); + } + + @Test + public void queryVectorsWithLimit() { + testResult( + db, + "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id", + map("host", HOST, "conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertBerlinResult(r.next(), FALSE); + }); + 
} + + @Test + public void queryVectorsWithCreateNode() { + + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map( + EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId", + METADATA_KEY, "foo", + CREATE_KEY, true)); + testResult( + db, + "CALL apoc.vectordb.qdrant.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + + testResult( + db, + "MATCH (n:Test) RETURN properties(n) AS props ORDER BY n.myId", + VectorDbTestUtil::vectorEntityAssertions); + + testResult( + db, + "CALL apoc.vectordb.qdrant.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) " + + "YIELD vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("vector")); + + row = r.next(); 
+ assertLondonResult(row, NODE); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getReadOnlyVectorsWithMapping() { + Map conf = map(ALL_RESULTS_KEY, true, MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + + try { + testCall( + db, + "CALL apoc.vectordb.qdrant.get($host, 'test_collection', [1, 2], $conf)", + map("host", HOST, "conf", conf), + r -> fail()); + } catch (RuntimeException e) { + Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); + } + } + + @Test + public void queryVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.qdrant.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void queryVectorsWithCreateRel() { + + db.executeTransactionally( + "CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map( + EMBEDDING_KEY, "vect", + REL_TYPE, "TEST", + ENTITY_KEY, "myId", + METADATA_KEY, "foo")); + testResult( + db, + "CALL apoc.vectordb.qdrant.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, REL); + assertNotNull(row.get("score")); + 
assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertRelsCreated(db); + } + + @Test + public void queryReadOnlyVectorsWithMapping() { + Map conf = map(ALL_RESULTS_KEY, true, MAPPING_KEY, map(EMBEDDING_KEY, "vect")); + + try { + testCall( + db, + "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", HOST, "conf", conf), + r -> fail()); + } catch (RuntimeException e) { + Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); + } + } + + @Test + public void queryVectorsWithSystemDbStorage() { + String keyConfig = "qdrant-config-foo"; + String baseUrl = "http://" + HOST; + Map mapping = + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo"); + sysDb.executeTransactionally( + "CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)", + map( + "vectorName", + QDRANT.toString(), + "keyConfig", + keyConfig, + "databaseName", + DEFAULT_DATABASE_NAME, + "conf", + map( + "host", baseUrl, + "credentials", ADMIN_KEY, + "mapping", mapping))); + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + testResult( + db, + "CALL apoc.vectordb.qdrant.queryAndUpdate($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + map("host", keyConfig, "conf", map(ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } +} diff --git a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java new file mode 100644 index 0000000000..14b76f4ef9 --- /dev/null +++ 
b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java @@ -0,0 +1,555 @@ +package apoc.full.it.vectordb; + +import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testCallEmpty; +import static apoc.util.TestUtil.testResult; +import static apoc.util.Util.map; +import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE; +import static apoc.vectordb.VectorDbTestUtil.*; +import static apoc.vectordb.VectorDbTestUtil.EntityType.*; +import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; +import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; +import static apoc.vectordb.VectorMappingConfig.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; +import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME; + +import apoc.util.MapUtil; +import apoc.util.TestUtil; +import apoc.vectordb.VectorDb; +import apoc.vectordb.VectorDbTestUtil; +import apoc.vectordb.Weaviate; +import java.util.List; +import java.util.Map; +import org.assertj.core.api.Assertions; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.neo4j.dbms.api.DatabaseManagementService; +import org.neo4j.graphdb.Entity; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.ResourceIterator; +import org.neo4j.test.TestDatabaseManagementServiceBuilder; +import org.testcontainers.weaviate.WeaviateContainer; + +public class WeaviateTest { + private static final List FIELDS = 
List.of("city", "foo"); + private static final String ADMIN_KEY = "jane-secret-key"; + private static final String READONLY_KEY = "ian-secret-key"; + + private static final WeaviateContainer WEAVIATE_CONTAINER = new WeaviateContainer( + "semitechnologies/weaviate:1.24.5") + .withEnv("AUTHENTICATION_APIKEY_ENABLED", "true") + .withEnv("AUTHENTICATION_APIKEY_ALLOWED_KEYS", ADMIN_KEY + "," + READONLY_KEY) + .withEnv("AUTHENTICATION_APIKEY_USERS", "jane@doe.com,ian-smith") + .withEnv("AUTHORIZATION_ADMINLIST_ENABLED", "true") + .withEnv("AUTHORIZATION_ADMINLIST_USERS", "jane@doe.com,john@doe.com") + .withEnv("AUTHORIZATION_ADMINLIST_READONLY_USERS", "ian-smith,roberta@doe.com"); + + private static final Map ADMIN_AUTHORIZATION = getAuthHeader(ADMIN_KEY); + private static final Map READONLY_AUTHORIZATION = getAuthHeader(READONLY_KEY); + private static final Map ADMIN_HEADER_CONF = map(HEADERS_KEY, ADMIN_AUTHORIZATION); + + private static final String ID_1 = "8ef2b3a7-1e56-4ddd-b8c3-2ca8901ce308"; + private static final String ID_2 = "9ef2b3a7-1e56-4ddd-b8c3-2ca8901ce308"; + + private static String HOST; + + @ClassRule + public static TemporaryFolder storeDir = new TemporaryFolder(); + + private static GraphDatabaseService sysDb; + private static GraphDatabaseService db; + private static DatabaseManagementService databaseManagementService; + + @BeforeClass + public static void setUp() throws Exception { + databaseManagementService = + new TestDatabaseManagementServiceBuilder(storeDir.getRoot().toPath()).build(); + db = databaseManagementService.database(DEFAULT_DATABASE_NAME); + sysDb = databaseManagementService.database(SYSTEM_DATABASE_NAME); + + WEAVIATE_CONTAINER.start(); + HOST = WEAVIATE_CONTAINER.getHttpHostAddress(); + + TestUtil.registerProcedure(db, Weaviate.class, VectorDb.class); + + testCall( + db, + "CALL apoc.vectordb.weaviate.createCollection($host, 'TestCollection', 'cosine', 4, $conf)", + MapUtil.map("host", HOST, "conf", ADMIN_HEADER_CONF), + r -> { + 
Map value = (Map) r.get("value"); + assertEquals("TestCollection", value.get("class")); + }); + + testResult( + db, + "CALL apoc.vectordb.weaviate.upsert($host, 'TestCollection', [\n" + + " {id: $id1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}},\n" + + " {id: $id2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}},\n" + + " {id: '7ef2b3a7-1e56-4ddd-b8c3-2ca8901ce308', vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}},\n" + + " {id: '7ef2b3a7-1e56-4ddd-b8c3-2ca8901ce309', vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" + + "], $conf)", + MapUtil.map("host", HOST, "id1", ID_1, "id2", ID_2, "conf", ADMIN_HEADER_CONF), + r -> { + ResourceIterator values = r.columnAs("value"); + assertEquals("TestCollection", values.next().get("class")); + assertEquals("TestCollection", values.next().get("class")); + assertEquals("TestCollection", values.next().get("class")); + assertEquals("TestCollection", values.next().get("class")); + assertFalse(values.hasNext()); + }); + + // -- delete vector + testCall( + db, + "CALL apoc.vectordb.weaviate.delete($host, 'TestCollection', " + + "['7ef2b3a7-1e56-4ddd-b8c3-2ca8901ce308', '7ef2b3a7-1e56-4ddd-b8c3-2ca8901ce309']" + + ", $conf) ", + map("host", HOST, "conf", ADMIN_HEADER_CONF), + r -> { + List value = (List) r.get("value"); + assertEquals( + List.of("7ef2b3a7-1e56-4ddd-b8c3-2ca8901ce308", "7ef2b3a7-1e56-4ddd-b8c3-2ca8901ce309"), + value); + }); + } + + @AfterClass + public static void tearDown() throws Exception { + testCallEmpty( + db, + "CALL apoc.vectordb.weaviate.deleteCollection($host, 'TestCollection', $conf)", + MapUtil.map("host", HOST, "conf", ADMIN_HEADER_CONF)); + + WEAVIATE_CONTAINER.stop(); + databaseManagementService.shutdown(); + } + + @Before + public void before() { + dropAndDeleteAll(db); + } + + @Test + public void getVectorsWithReadOnlyApiKey() { + testResult( + db, + "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', 
[$id1], $conf)", + map("host", HOST, "id1", ID_1, "conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, READONLY_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertBerlinResult(row, ID_1, FALSE); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void writeOperationWithReadOnlyUser() { + try { + testCall( + db, + "CALL apoc.vectordb.weaviate.deleteCollection($host, 'TestCollection', $conf)", + map("host", HOST, "conf", map(HEADERS_KEY, READONLY_AUTHORIZATION)), + r -> fail()); + } catch (Exception e) { + assertThat(e.getMessage()).contains("HTTP response code: 403"); + } + } + + @Test + public void getVectorsWithoutVectorResult() { + testResult( + db, + "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', [$id1], $conf)", + map("host", HOST, "id1", ID_1, "conf", map(HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void queryVectors() { + testResult( + db, + "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata RETURN * ORDER BY id", + map( + "host", + HOST, + "conf", + map(ALL_RESULTS_KEY, true, "fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = r.next(); + assertBerlinResult(row, ID_1, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, ID_2, FALSE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + } + + @Test + public void queryVectorsWithoutVectorResult() { + testResult( + db, + "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", map("fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + Map row = 
r.next(); + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + + row = r.next(); + assertEquals(Map.of("city", "London", "foo", "two"), row.get("metadata")); + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + assertNull(row.get("id")); + }); + } + + @Test + public void queryVectorsWithYield() { + testResult( + db, + "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + "YIELD metadata, id RETURN * ORDER BY id", + map( + "host", + HOST, + "conf", + map(ALL_RESULTS_KEY, true, "fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertBerlinResult(r.next(), ID_1, FALSE); + assertLondonResult(r.next(), ID_2, FALSE); + }); + } + + @Test + public void queryVectorsWithFilter() { + testResult( + db, + "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7],\n" + + " '{operator: Equal, valueString: \"London\", path: [\"city\"]}',\n" + + " 5, $conf) YIELD metadata, id RETURN * ORDER BY id", + map( + "host", + HOST, + "conf", + map(ALL_RESULTS_KEY, true, "fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertLondonResult(r.next(), ID_2, FALSE); + }); + } + + @Test + public void queryVectorsWithLimit() { + testResult( + db, + "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 1, $conf) YIELD metadata, id RETURN * ORDER BY id", + map( + "host", + HOST, + "conf", + map(ALL_RESULTS_KEY, true, "fields", FIELDS, HEADERS_KEY, ADMIN_AUTHORIZATION)), + r -> { + assertBerlinResult(r.next(), ID_1, FALSE); + }); + } + + @Test + public void queryVectorsWithCreateNode() { + + Map conf = map( + ALL_RESULTS_KEY, + true, + "fields", + FIELDS, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map( + EMBEDDING_KEY, + "vect", + NODE_LABEL, + "Test", + ENTITY_KEY, + "myId", + METADATA_KEY, + "foo", + CREATE_KEY, + 
true)); + testResult( + db, + "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + "YIELD score, vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, ID_1, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, ID_2, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + + testResult( + db, + "MATCH (n:Test) RETURN properties(n) AS props ORDER BY n.myId", + VectorDbTestUtil::vectorEntityAssertions); + + testResult( + db, + "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, ID_1, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, ID_2, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void queryVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + "fields", + FIELDS, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + testResult( + db, + "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata, node RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, ID_1, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + 
assertLondonResult(row, ID_2, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getVectorsWithCreateNodeUsingExistingNode() { + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + Map conf = MapUtil.map( + ALL_RESULTS_KEY, + true, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + MapUtil.map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + + testResult( + db, + "CALL apoc.vectordb.weaviate.getAndUpdate($host, 'TestCollection', [$id1, $id2], $conf)", + map("host", HOST, "id1", ID_1, "id2", ID_2, "conf", conf), + r -> { + Map row = r.next(); + assertBerlinResult(row, ID_1, NODE); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, ID_2, NODE); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } + + @Test + public void getReadOnlyVectorsWithMapping() { + Map conf = MapUtil.map(ALL_RESULTS_KEY, true, MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect")); + + try { + testCall( + db, + "CALL apoc.vectordb.weaviate.get($host, 'TestCollection', [1, 2], $conf)", + map("host", HOST, "conf", conf), + r -> fail()); + } catch (RuntimeException e) { + Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); + } + } + + @Test + public void queryVectorsWithCreateRel() { + + db.executeTransactionally( + "CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); + + Map conf = map( + ALL_RESULTS_KEY, + true, + "fields", + FIELDS, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map(EMBEDDING_KEY, "vect", REL_TYPE, "TEST", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + testResult( + db, + "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata, rel RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> { + Map row = 
r.next(); + assertBerlinResult(row, ID_1, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, ID_2, REL); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertRelsCreated(db); + } + + @Test + public void queryReadOnlyVectorsWithMapping() { + Map conf = MapUtil.map(ALL_RESULTS_KEY, true, MAPPING_KEY, MapUtil.map(EMBEDDING_KEY, "vect")); + + try { + testCall( + db, + "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + MapUtil.map("host", HOST, "conf", conf), + r -> fail()); + } catch (RuntimeException e) { + Assertions.assertThat(e.getMessage()).contains(ERROR_READONLY_MAPPING); + } + } + + @Test + public void queryVectorsWithCreateRelWithoutVectorResult() { + + db.executeTransactionally( + "CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)"); + + Map conf = map( + "fields", + FIELDS, + HEADERS_KEY, + ADMIN_AUTHORIZATION, + MAPPING_KEY, + map(REL_TYPE, "TEST", ENTITY_KEY, "myId", METADATA_KEY, "foo")); + testResult( + db, + "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf) " + + " YIELD score, vector, id, metadata, rel RETURN * ORDER BY id", + map("host", HOST, "conf", conf), + r -> { + Map row = r.next(); + Map props = ((Entity) row.get("rel")).getAllProperties(); + assertEquals("Berlin", props.get("city")); + assertEquals("one", props.get("myId")); + assertNull(props.get("vect")); + + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + + row = r.next(); + props = ((Entity) row.get("rel")).getAllProperties(); + assertEquals("London", props.get("city")); + assertEquals("two", props.get("myId")); + assertNull(props.get("vect")); + + assertNotNull(row.get("score")); + assertNull(row.get("vector")); + }); + } + + @Test + public void queryVectorsWithSystemDbStorage() { + String keyConfig = "weaviate-config-foo"; + 
String baseUrl = "http://" + HOST + "/v1"; + Map mapping = + map(EMBEDDING_KEY, "vect", NODE_LABEL, "Test", ENTITY_KEY, "myId", METADATA_KEY, "foo"); + sysDb.executeTransactionally( + "CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)", + map( + "vectorName", + WEAVIATE.toString(), + "keyConfig", + keyConfig, + "databaseName", + DEFAULT_DATABASE_NAME, + "conf", + map( + "host", baseUrl, + "credentials", ADMIN_KEY, + "mapping", mapping))); + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + testResult( + db, + "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", + map("host", keyConfig, "conf", map("fields", FIELDS, ALL_RESULTS_KEY, true)), + r -> { + Map row = r.next(); + assertBerlinResult(row, ID_1, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, ID_2, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } +} diff --git a/extended/src/main/java/apoc/diff/DiffExtended.java b/full/src/main/java/apoc/diff/DiffExtended.java similarity index 95% rename from extended/src/main/java/apoc/diff/DiffExtended.java rename to full/src/main/java/apoc/diff/DiffExtended.java index 7039675b84..beadd09adb 100644 --- a/extended/src/main/java/apoc/diff/DiffExtended.java +++ b/full/src/main/java/apoc/diff/DiffExtended.java @@ -1,19 +1,17 @@ package apoc.diff; - import apoc.Extended; import apoc.util.Util; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; import org.neo4j.graphdb.Relationship; import org.neo4j.graphdb.Transaction; +import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Name; import org.neo4j.procedure.UserFunction; -import org.neo4j.procedure.Context; - -import java.util.HashMap; -import java.util.Map; 
-import java.util.Objects; -import java.util.stream.Collectors; @Extended public class DiffExtended { @@ -23,7 +21,8 @@ public class DiffExtended { @UserFunction("apoc.diff.relationships") @Description("Returns a Map detailing the property differences between the two given relationships") - public Map relationships(@Name("leftRel") Relationship leftRel, @Name("rightRel") Relationship rightRel) { + public Map relationships( + @Name("leftRel") Relationship leftRel, @Name("rightRel") Relationship rightRel) { leftRel = Util.rebind(tx, leftRel); rightRel = Util.rebind(tx, rightRel); Map allLeftProperties = leftRel.getAllProperties(); diff --git a/extended/src/main/java/apoc/graph/GraphsExtended.java b/full/src/main/java/apoc/graph/GraphsExtended.java similarity index 76% rename from extended/src/main/java/apoc/graph/GraphsExtended.java rename to full/src/main/java/apoc/graph/GraphsExtended.java index 09bf5e86f8..5683661e95 100644 --- a/extended/src/main/java/apoc/graph/GraphsExtended.java +++ b/full/src/main/java/apoc/graph/GraphsExtended.java @@ -4,7 +4,13 @@ import apoc.result.GraphResult; import apoc.result.VirtualNode; import apoc.result.VirtualRelationship; -import apoc.util.collection.Iterables; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.neo4j.driver.internal.util.Iterables; import org.neo4j.graphdb.Label; import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Path; @@ -17,13 +23,6 @@ import org.neo4j.procedure.UserAggregationResult; import org.neo4j.procedure.UserAggregationUpdate; -import java.util.Collection; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.stream.Stream; - @Extended public class GraphsExtended { @@ -32,15 +31,17 @@ public class GraphsExtended { "CALL apoc.graph.filterProperties(anyEntityObject, nodePropertiesToRemove, relPropertiesToRemove) 
YIELD nodes, relationships - returns a set of virtual nodes and relationships without the properties defined in nodePropertiesToRemove and relPropertiesToRemove") public Stream fromData( @Name("value") Object value, - @Name(value = "nodePropertiesToRemove", defaultValue = "{}") Map> nodePropertiesToRemove, - @Name(value = "relPropertiesToRemove", defaultValue = "{}") Map> relPropertiesToRemove) { - + @Name(value = "nodePropertiesToRemove", defaultValue = "{}") + Map> nodePropertiesToRemove, + @Name(value = "relPropertiesToRemove", defaultValue = "{}") + Map> relPropertiesToRemove) { + VirtualGraphExtractor extractor = new VirtualGraphExtractor(nodePropertiesToRemove, relPropertiesToRemove); extractor.extract(value); - GraphResult result = new GraphResult( extractor.nodes(), extractor.rels() ); + GraphResult result = new GraphResult(extractor.nodes(), extractor.rels()); return Stream.of(result); } - + @UserAggregationFunction("apoc.graph.filterProperties") @Description( "apoc.graph.filterProperties(anyEntityObject, nodePropertiesToRemove, relPropertiesToRemove) - aggregation function which returns an object {node: [virtual nodes], relationships: [virtual relationships]} without the properties defined in nodePropertiesToRemove and relPropertiesToRemove") @@ -57,9 +58,11 @@ public static class GraphFunction { @UserAggregationUpdate public void filterProperties( @Name("value") Object value, - @Name(value = "nodePropertiesToRemove", defaultValue = "{}") Map> nodePropertiesToRemove, - @Name(value = "relPropertiesToRemove", defaultValue = "{}") Map> relPropertiesToRemove) { - + @Name(value = "nodePropertiesToRemove", defaultValue = "{}") + Map> nodePropertiesToRemove, + @Name(value = "relPropertiesToRemove", defaultValue = "{}") + Map> relPropertiesToRemove) { + if (virtualGraphExtractor == null) { virtualGraphExtractor = new VirtualGraphExtractor(nodePropertiesToRemove, relPropertiesToRemove); } @@ -72,20 +75,20 @@ public Object result() { Collection relationships = 
virtualGraphExtractor.rels(); return Map.of( NODES, nodes, - RELATIONSHIPS, relationships - ); + RELATIONSHIPS, relationships); } } public static class VirtualGraphExtractor { private static final String ALL_FILTER = "_all"; - - private final Map nodes; - private final Map rels; + + private final Map nodes; + private final Map rels; private final Map> nodePropertiesToRemove; private final Map> relPropertiesToRemove; - public VirtualGraphExtractor(Map> nodePropertiesToRemove, Map> relPropertiesToRemove) { + public VirtualGraphExtractor( + Map> nodePropertiesToRemove, Map> relPropertiesToRemove) { this.nodes = new HashMap<>(); this.rels = new HashMap<>(); this.nodePropertiesToRemove = nodePropertiesToRemove; @@ -96,26 +99,31 @@ public void extract(Object value) { if (value == null) { return; } - if (value instanceof Node node) { + if (value instanceof Node) { + Node node = (Node) value; addVirtualNode(node); - - } else if (value instanceof Relationship rel) { + + } else if (value instanceof Relationship) { + Relationship rel = (Relationship) value; addVirtualRel(rel); - - } else if (value instanceof Path path) { + + } else if (value instanceof Path) { + Path path = (Path) value; path.nodes().forEach(this::addVirtualNode); path.relationships().forEach(this::addVirtualRel); - + } else if (value instanceof Iterable) { ((Iterable) value).forEach(this::extract); - - } else if (value instanceof Map map) { + + } else if (value instanceof Map) { + Map map = (Map) value; map.values().forEach(this::extract); - + } else if (value instanceof Iterator) { ((Iterator) value).forEachRemaining(this::extract); - - } else if (value instanceof Object[] array) { + + } else if (value instanceof Object[]) { + Object[] array = (Object[]) value; for (Object i : array) { extract(i); } @@ -123,20 +131,20 @@ public void extract(Object value) { } /** - * We can use the elementId as a unique key for virtual nodes/relations, + * We can use the elementId as a unique key for virtual nodes/relations, 
* as it is the same as the analogue for real nodes/relations. */ private void addVirtualRel(Relationship rel) { - rels.putIfAbsent(rel.getElementId(), createVirtualRel(rel)); + rels.putIfAbsent(rel.getId(), createVirtualRel(rel)); } private void addVirtualNode(Node node) { - nodes.putIfAbsent(node.getElementId(), createVirtualNode(node)); + nodes.putIfAbsent(node.getId(), createVirtualNode(node)); } private Node createVirtualNode(Node startNode) { List props = Iterables.asList(startNode.getPropertyKeys()); - nodePropertiesToRemove.forEach((k,v) -> { + nodePropertiesToRemove.forEach((k, v) -> { if (k.equals(ALL_FILTER) || startNode.hasLabel(Label.label(k))) { props.removeAll(v); } @@ -147,14 +155,14 @@ private Node createVirtualNode(Node startNode) { private Relationship createVirtualRel(Relationship rel) { Node startNode = rel.getStartNode(); - startNode = nodes.putIfAbsent(startNode.getElementId(), createVirtualNode(startNode)); + startNode = nodes.putIfAbsent(startNode.getId(), createVirtualNode(startNode)); Node endNode = rel.getEndNode(); - endNode = nodes.putIfAbsent(endNode.getElementId(), createVirtualNode(endNode)); - + endNode = nodes.putIfAbsent(endNode.getId(), createVirtualNode(endNode)); + Map props = rel.getAllProperties(); - - relPropertiesToRemove.forEach((k,v) -> { + + relPropertiesToRemove.forEach((k, v) -> { if (k.equals(ALL_FILTER) || rel.isType(RelationshipType.withName(k))) { v.forEach(props.keySet()::remove); } diff --git a/extended/src/main/java/apoc/map/MapsExtended.java b/full/src/main/java/apoc/map/MapsExtended.java similarity index 53% rename from extended/src/main/java/apoc/map/MapsExtended.java rename to full/src/main/java/apoc/map/MapsExtended.java index 7dcbaa7e88..76e40ec5dd 100644 --- a/extended/src/main/java/apoc/map/MapsExtended.java +++ b/full/src/main/java/apoc/map/MapsExtended.java @@ -2,23 +2,23 @@ import apoc.Extended; import apoc.util.Util; -import org.neo4j.procedure.Description; -import org.neo4j.procedure.Name; -import 
org.neo4j.procedure.UserFunction; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.UserFunction; @Extended public class MapsExtended { @UserFunction("apoc.map.renameKey") @Description("Rename the given key(s) in the `MAP`.") - public Map renameKeyRecursively(@Name("map") Map map, - @Name("keyFrom") String keyFrom, - @Name("keyTo") String keyTo, - @Name(value = "config", defaultValue = "{}") Map config) { + public Map renameKeyRecursively( + @Name("map") Map map, + @Name("keyFrom") String keyFrom, + @Name("keyTo") String keyTo, + @Name(value = "config", defaultValue = "{}") Map config) { boolean recursive = Util.toBoolean(config.getOrDefault("recursive", true)); if (recursive) { return (Map) renameKeyRecursively(map, keyFrom, keyTo); @@ -32,22 +32,21 @@ public Map renameKeyRecursively(@Name("map") Map private Object renameKeyRecursively(Object object, String keyFrom, String keyTo) { if (object instanceof Map) { - return ((Map) object).entrySet() - .stream() - .collect(Collectors.toMap( - e -> { - String key = e.getKey(); - return key.equals(keyFrom) ? keyTo : key; - }, - e -> renameKeyRecursively(e.getValue(), keyFrom, keyTo)) - ); + return ((Map) object) + .entrySet().stream() + .collect(Collectors.toMap( + e -> { + String key = e.getKey(); + return key.equals(keyFrom) ? 
keyTo : key; + }, + e -> renameKeyRecursively(e.getValue(), keyFrom, keyTo))); } - if (object instanceof List subList) { + if (object instanceof List) { + List subList = (List) object; return subList.stream() .map(v -> renameKeyRecursively(v, keyFrom, keyTo)) - .toList(); + .collect(Collectors.toList()); } return object; } - } diff --git a/full/src/main/java/apoc/ml/RestAPIConfig.java b/full/src/main/java/apoc/ml/RestAPIConfig.java new file mode 100644 index 0000000000..d560e9adb8 --- /dev/null +++ b/full/src/main/java/apoc/ml/RestAPIConfig.java @@ -0,0 +1,100 @@ +package apoc.ml; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +// TODO - could be moved to `apoc.util` package? +public class RestAPIConfig { + public static final String HEADERS_KEY = "headers"; + public static final String METHOD_KEY = "method"; + public static final String ENDPOINT_KEY = "endpoint"; + public static final String JSON_PATH_KEY = "jsonPath"; + public static final String BODY_KEY = "body"; + public static final String BASE_URL_KEY = "baseUrl"; + + // used internally to handle multiple endpoints, like in `apoc.vectordb.weaviate.get` + // the config documented is the `endpoint` one + private final String baseUrl; + + private Map headers; + private Map body; + private String endpoint; + private String jsonPath; + + public RestAPIConfig(Map config) { + this(config, Map.of(), Map.of()); + } + + public RestAPIConfig(Map config, Map additionalHeaders, Map additionalBodies) { + if (config == null) { + config = Collections.emptyMap(); + } + + String httpMethod = (String) config.getOrDefault(METHOD_KEY, "POST"); + + this.headers = populateHeaders(config, additionalHeaders, httpMethod); + + this.endpoint = (String) config.get(ENDPOINT_KEY); + this.baseUrl = (String) config.get(BASE_URL_KEY); + + this.jsonPath = (String) config.get(JSON_PATH_KEY); + this.body = populateBody(config, additionalBodies); + } + + private static Map populateHeaders( + Map config, Map 
additionalHeaders, String httpMethod) { + Map headerConf = (Map) config.getOrDefault(HEADERS_KEY, new HashMap<>()); + headerConf.putIfAbsent("content-type", "application/json"); + headerConf.putIfAbsent(METHOD_KEY, httpMethod); + additionalHeaders.forEach(headerConf::putIfAbsent); + return headerConf; + } + + private static Map populateBody(Map config, Map additionalBodies) { + Map bodyConf = (Map) config.getOrDefault(BODY_KEY, new HashMap<>()); + + // if we force body to be null, e.g. with Http GET operations that doesn't allow payloads, + // we skip additional body addition + if (bodyConf != null) { + additionalBodies.forEach(bodyConf::putIfAbsent); + } + return bodyConf; + } + + public Map getHeaders() { + return headers; + } + + public Map getBody() { + return body; + } + + public String getEndpoint() { + return endpoint; + } + + public String getBaseUrl() { + return baseUrl; + } + + public String getJsonPath() { + return jsonPath; + } + + public void setHeaders(Map headers) { + this.headers = headers; + } + + public void setBody(Map body) { + this.body = body; + } + + public void setEndpoint(String endpoint) { + this.endpoint = endpoint; + } + + public void setJsonPath(String jsonPath) { + this.jsonPath = jsonPath; + } +} diff --git a/full/src/main/java/apoc/vectordb/ChromaDb.java b/full/src/main/java/apoc/vectordb/ChromaDb.java new file mode 100644 index 0000000000..e8b6947ed2 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/ChromaDb.java @@ -0,0 +1,261 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.*; +import static apoc.util.MapUtil.map; +import static apoc.util.Util.listOfMapToMapOfLists; +import static apoc.vectordb.VectorDb.executeRequest; +import static apoc.vectordb.VectorDb.getEmbeddingResultStream; +import static apoc.vectordb.VectorDbHandler.Type.CHROMA; +import static apoc.vectordb.VectorDbUtil.*; +import static apoc.vectordb.VectorEmbeddingConfig.*; + +import apoc.Extended; +import apoc.ml.RestAPIConfig; +import 
apoc.result.ListResult; +import apoc.result.MapResult; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.commons.collections4.CollectionUtils; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +@Extended +public class ChromaDb { + public static final VectorDbHandler DB_HANDLER = CHROMA.get(); + + @Context + public ProcedureCallContext procedureCallContext; + + @Context + public Transaction tx; + + @Context + public GraphDatabaseService db; + + @Procedure("apoc.vectordb.chroma.createCollection") + @Description( + "apoc.vectordb.chroma.createCollection(hostOrKey, collection, similarity, size, $configuration) - Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`") + public Stream createCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("similarity") String similarity, + @Name("size") Long size, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + String url = "%s/api/v1/collections"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map metadata = Map.of("hnsw:space", similarity, "size", size); + Map additionalBodies = Map.of("name", collection, "metadata", metadata); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.chroma.deleteCollection") + @Description( + 
"apoc.vectordb.chroma.deleteCollection(hostOrKey, collection, $configuration) - Deletes a collection with the name specified in the 2nd parameter") + public Stream deleteCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + String url = "%s/api/v1/collections/%s"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "DELETE"); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), Map.of()); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.chroma.upsert") + @Description( + "apoc.vectordb.chroma.upsert(hostOrKey, collection, vectors, $configuration) - Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', medatada: ''}]") + public Stream upsert( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List> vectors, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + String url = "%s/api/v1/collections/%s/upsert"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + Map mapKeys = + Map.of("id", "ids", "vector", "embeddings", "metadata", "metadatas", "text", "documents"); + + // transform to format digestible by RestAPI, + // that is from [{id: , vector: ,,,}, {id: , vector: ,,,}] + // to {ids: [, ], vectors: [, ]} + Map additionalBodies = listOfMapToMapOfLists(mapKeys, vectors); + additionalBodies.compute("ids", (k, v) -> getStringIds(v)); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure(value = "apoc.vectordb.chroma.delete") + @Description( + "apoc.vectordb.chroma.delete(hostOrKey, collection, ids, 
$configuration) - Deletes the vectors with the specified `ids`") + public Stream delete( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + String url = "%s/api/v1/collections/%s/delete"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, getStringIds(ids)); + return executeRequest(apiConfig.getApiConfig()).map(v -> (List) v).map(ListResult::new); + } + + @Procedure(value = "apoc.vectordb.chroma.get") + @Description( + "apoc.vectordb.chroma.get(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`") + public Stream get( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return getCommon(hostOrKey, collection, ids, configuration, true); + } + + @Procedure(value = "apoc.vectordb.chroma.getAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.chroma.getAndUpdate(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`, and optionally creates/updates neo4j entities") + public Stream getAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return getCommon(hostOrKey, collection, ids, configuration, false); + } + + private Stream getCommon( + String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) + throws Exception { + String url = "%s/api/v1/collections/%s/get"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + if (readOnly) { + checkMappingConf(configuration, 
"apoc.vectordb.chroma.getAndUpdate"); + } + + VectorEmbeddingConfig apiConfig = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids); + return getEmbeddingResultStream(apiConfig, procedureCallContext, tx, v -> listToMap((Map) v).stream()); + } + + @Procedure(value = "apoc.vectordb.chroma.query") + @Description( + "apoc.vectordb.chroma.query(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter") + public Stream query( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + } + + @Procedure(value = "apoc.vectordb.chroma.queryAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.chroma.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter, and optionally creates/updates neo4j entities") + public Stream queryAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + } + + private Stream queryCommon( + String hostOrKey, + String collection, + List vector, + Map filter, + long 
limit, + Map configuration, + boolean readOnly) + throws Exception { + String url = "%s/api/v1/collections/%s/query"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.chroma.queryAndUpdate"); + } + + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + return getEmbeddingResultStream(apiConfig, procedureCallContext, tx, v -> listOfListsToMap((Map) v).stream()); + } + + private Map getVectorDbInfo( + String hostOrKey, String collection, Map configuration, String templateUrl) { + return getCommonVectorDbInfo(hostOrKey, collection, configuration, templateUrl, DB_HANDLER); + } + + private static List listOfListsToMap(Map startMap) { + List distances = startMap.get("distances") == null ? null : ((List) startMap.get("distances")).get(0); + List metadatas = startMap.get("metadatas") == null ? null : ((List) startMap.get("metadatas")).get(0); + List documents = startMap.get("documents") == null ? null : ((List) startMap.get("documents")).get(0); + List embeddings = startMap.get("embeddings") == null ? 
null : ((List) startMap.get("embeddings")).get(0); + + List ids = ((List) startMap.get("ids")).get(0); + + return getMaps(distances, metadatas, documents, embeddings, ids); + } + + private static List listToMap(Map startMap) { + List distances = (List) startMap.get("distances"); + List metadatas = (List) startMap.get("metadatas"); + List documents = (List) startMap.get("documents"); + List embeddings = (List) startMap.get("embeddings"); + + List ids = (List) startMap.get("ids"); + + return getMaps(distances, metadatas, documents, embeddings, ids); + } + + private static List getMaps(List distances, List metadatas, List documents, List embeddings, List ids) { + final List result = new ArrayList<>(); + for (int i = 0; i < ids.size(); i++) { + Map map = map(DEFAULT_ID, ids.get(i)); + if (CollectionUtils.isNotEmpty(distances)) { + map.put(DEFAULT_SCORE, distances.get(i)); + } + if (CollectionUtils.isNotEmpty(metadatas)) { + map.put(DEFAULT_METADATA, metadatas.get(i)); + } + if (CollectionUtils.isNotEmpty(documents)) { + map.put(DEFAULT_TEXT, documents.get(i)); + } + if (CollectionUtils.isNotEmpty(embeddings)) { + map.put(DEFAULT_VECTOR, embeddings.get(i)); + } + result.add(map); + } + + return result; + } + + private List getStringIds(List ids) { + return ids.stream().map(Object::toString).collect(Collectors.toList()); + } +} diff --git a/full/src/main/java/apoc/vectordb/ChromaHandler.java b/full/src/main/java/apoc/vectordb/ChromaHandler.java new file mode 100644 index 0000000000..6442f5ee99 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/ChromaHandler.java @@ -0,0 +1,84 @@ +package apoc.vectordb; + +import static apoc.util.MapUtil.map; + +import apoc.util.UrlResolver; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; + +public class ChromaHandler implements VectorDbHandler { + @Override + public String getUrl(String hostOrKey) { + 
return new UrlResolver("http", "localhost", 8000).getUrl("chroma", hostOrKey); + } + + @Override + public VectorEmbeddingHandler getEmbedding() { + return new ChromaEmbeddingHandler(); + } + + @Override + public String getLabel() { + return "Chroma"; + } + + // -- embedding handler + static class ChromaEmbeddingHandler implements VectorEmbeddingHandler { + + @Override + public VectorEmbeddingConfig fromGet( + Map config, ProcedureCallContext procedureCallContext, List ids) { + + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + + VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config); + Map additionalBodies = map("ids", ids); + + return getVectorEmbeddingConfig(conf, fields, additionalBodies); + } + + @Override + public VectorEmbeddingConfig fromQuery( + Map config, + ProcedureCallContext procedureCallContext, + List vector, + Object filter, + long limit, + String collection) { + + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + + VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config); + Map additionalBodies = + map("query_embeddings", List.of(vector), "where", filter, "n_results", limit); + + return getVectorEmbeddingConfig(conf, fields, additionalBodies); + } + + // "include": [metadatas, embeddings, ...] return the metadata/embeddings/... 
if included in the list + // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not metadata/embedding + private static VectorEmbeddingConfig getVectorEmbeddingConfig( + VectorEmbeddingConfig config, List fields, Map additionalBodies) { + ArrayList include = new ArrayList<>(); + if (fields.contains("metadata")) { + include.add("metadatas"); + } + if (fields.contains("text") && config.isAllResults()) { + include.add("documents"); + } + if (fields.contains("vector") && config.isAllResults()) { + include.add("embeddings"); + } + if (fields.contains("score")) { + include.add("distances"); + } + + additionalBodies.put("include", include); + + return VectorEmbeddingHandler.populateApiBodyRequest(config, additionalBodies); + } + } +} diff --git a/full/src/main/java/apoc/vectordb/Qdrant.java b/full/src/main/java/apoc/vectordb/Qdrant.java new file mode 100644 index 0000000000..b697e299a4 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/Qdrant.java @@ -0,0 +1,217 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.vectordb.VectorDb.executeRequest; +import static apoc.vectordb.VectorDb.getEmbeddingResultStream; +import static apoc.vectordb.VectorDbHandler.Type.QDRANT; +import static apoc.vectordb.VectorDbUtil.*; + +import apoc.Extended; +import apoc.ml.RestAPIConfig; +import apoc.result.MapResult; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +@Extended +public class Qdrant { + public static final VectorDbHandler DB_HANDLER = QDRANT.get(); + + @Context + public 
ProcedureCallContext procedureCallContext; + + @Context + public Transaction tx; + + @Context + public GraphDatabaseService db; + + @Procedure("apoc.vectordb.qdrant.createCollection") + @Description( + "apoc.vectordb.qdrant.createCollection(hostOrKey, collection, similarity, size, $configuration) - Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`") + public Stream createCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("similarity") String similarity, + @Name("size") Long size, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + String url = "%s/collections/%s"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "PUT"); + + Map additionalBodies = Map.of( + "vectors", + Map.of( + "size", size, + "distance", similarity)); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.qdrant.deleteCollection") + @Description( + "apoc.vectordb.qdrant.deleteCollection(hostOrKey, collection, $configuration) - Deletes a collection with the name specified in the 2nd parameter") + public Stream deleteCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/collections/%s"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "DELETE"); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.qdrant.upsert") + @Description( + "apoc.vectordb.qdrant.upsert(hostOrKey, collection, vectors, $configuration) - Upserts, 
in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', medatada: ''}]") + public Stream upsert( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List> vectors, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/collections/%s/points"; + + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "PUT"); + + List> point = vectors.stream() + .map(i -> { + Map map = new HashMap<>(i); + map.putIfAbsent("vector", map.remove("vector")); + map.putIfAbsent("payload", map.remove("metadata")); + return map; + }) + .collect(Collectors.toList()); + Map additionalBodies = Map.of("points", point); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.qdrant.delete") + @Description( + "apoc.vectordb.qdrant.delete(hostOrKey, collection, ids, $configuration) - Deletes the vectors with the specified `ids`") + public Stream delete( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + String url = "%s/collections/%s/points/delete"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map additionalBodies = Map.of("points", ids); + RestAPIConfig apiConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + return executeRequest(apiConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure(value = "apoc.vectordb.qdrant.get") + @Description( + "apoc.vectordb.qdrant.get(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`") + public Stream get( + @Name("hostOrKey") String 
hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return getCommon(hostOrKey, collection, ids, configuration, true); + } + + @Procedure(value = "apoc.vectordb.qdrant.getAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.qdrant.getAndUpdate(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`, and optionally creates/updates neo4j entities") + public Stream getAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return getCommon(hostOrKey, collection, ids, configuration, false); + } + + private Stream getCommon( + String hostOrKey, String collection, List ids, Map configuration, boolean readOnly) + throws Exception { + String url = "%s/collections/%s/points"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.qdrant.getAndUpdate"); + } + + VectorEmbeddingConfig apiConfig = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids); + return getEmbeddingResultStream(apiConfig, procedureCallContext, tx); + } + + @Procedure(value = "apoc.vectordb.qdrant.query") + @Description( + "apoc.vectordb.qdrant.query(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter") + public Stream query( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map 
configuration) + throws Exception { + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + } + + @Procedure(value = "apoc.vectordb.qdrant.queryAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.chroma.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter, and optionally creates/updates neo4j entities") + public Stream queryAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + } + + private Stream queryCommon( + String hostOrKey, + String collection, + List vector, + Map filter, + long limit, + Map configuration, + boolean readOnly) + throws Exception { + String url = "%s/collections/%s/points/search"; + Map config = getVectorDbInfo(hostOrKey, collection, configuration, url); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.qdrant.queryAndUpdate"); + } + + VectorEmbeddingConfig apiConfig = + DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + return getEmbeddingResultStream(apiConfig, procedureCallContext, tx); + } + + private Map getVectorDbInfo( + String hostOrKey, String collection, Map configuration, String templateUrl) { + return getCommonVectorDbInfo(hostOrKey, collection, configuration, templateUrl, DB_HANDLER); + } +} diff --git a/full/src/main/java/apoc/vectordb/QdrantHandler.java b/full/src/main/java/apoc/vectordb/QdrantHandler.java new file mode 100644 index 0000000000..1a9f9954d1 --- 
/dev/null +++ b/full/src/main/java/apoc/vectordb/QdrantHandler.java @@ -0,0 +1,76 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.JSON_PATH_KEY; +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.MapUtil.map; +import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY; + +import apoc.util.UrlResolver; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; + +public class QdrantHandler implements VectorDbHandler { + + @Override + public String getUrl(String hostOrKey) { + return new UrlResolver("http", "localhost", 6333).getUrl("qdrant", hostOrKey); + } + + @Override + public VectorEmbeddingHandler getEmbedding() { + return new QdrantEmbeddingHandler(); + } + + @Override + public String getLabel() { + return "Qdrant"; + } + + // -- embedding handler + static class QdrantEmbeddingHandler implements VectorEmbeddingHandler { + + @Override + public VectorEmbeddingConfig fromGet( + Map config, ProcedureCallContext procedureCallContext, List ids) { + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map additionalBodies = map("ids", ids); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + @Override + public VectorEmbeddingConfig fromQuery( + Map config, + ProcedureCallContext procedureCallContext, + List vector, + Object filter, + long limit, + String collection) { + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + + Map additionalBodies = map("vector", vector, "filter", filter, "limit", limit); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + // "with_payload": and "with_vectors": return the metadata and vector, if true + // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not 
metadata/embedding + private static VectorEmbeddingConfig getVectorEmbeddingConfig( + Map config, List fields, Map additionalBodies) { + config.putIfAbsent(VECTOR_KEY, "vector"); + config.putIfAbsent(METADATA_KEY, "payload"); + config.putIfAbsent(JSON_PATH_KEY, "result"); + + VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config); + additionalBodies.put("with_payload", fields.contains("metadata")); + additionalBodies.put("with_vectors", fields.contains("vector") && conf.isAllResults()); + + return VectorEmbeddingHandler.populateApiBodyRequest(conf, additionalBodies); + } + } +} diff --git a/full/src/main/java/apoc/vectordb/VectorDb.java b/full/src/main/java/apoc/vectordb/VectorDb.java new file mode 100644 index 0000000000..64f4c8640d --- /dev/null +++ b/full/src/main/java/apoc/vectordb/VectorDb.java @@ -0,0 +1,286 @@ +package apoc.vectordb; + +import static apoc.util.JsonUtil.OBJECT_MAPPER; +import static apoc.util.SystemDbUtil.withSystemDb; +import static apoc.util.Util.listOfNumbersToFloatArray; +import static apoc.util.Util.setProperties; +import static apoc.vectordb.VectorDbUtil.*; +import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; + +import apoc.Extended; +import apoc.SystemPropertyKeys; +import apoc.ml.RestAPIConfig; +import apoc.result.ObjectResult; +import apoc.util.JsonUtil; +import apoc.util.SystemDbUtil; +import apoc.util.Util; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.commons.collections4.MapUtils; +import org.neo4j.graphdb.Entity; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Label; +import org.neo4j.graphdb.MultipleFoundException; +import org.neo4j.graphdb.Node; +import org.neo4j.graphdb.Relationship; +import org.neo4j.graphdb.RelationshipType; +import 
org.neo4j.graphdb.Transaction; +import org.neo4j.internal.helpers.collection.Pair; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.kernel.api.procedure.SystemProcedure; +import org.neo4j.procedure.Admin; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +/** + * Base class + */ +@Extended +public class VectorDb { + + @Context + public GraphDatabaseService db; + + @Context + public Transaction tx; + + @Context + public ProcedureCallContext procedureCallContext; + + /** + * We can use this procedure with every API that return something like this: + * ``` + * [ + * "idKey": "idValue", + * "scoreKey": 1, + * "vectorKey": [ ] + * "metadataKey": { .. }, + * "textKey": "..." + * ], + * [ + * ... + * ] + * ``` + * + * Otherwise, if the result is different (e.g. the Chroma result), we have to leverage the apoc.vectordb.custom, + * which return an Object, but we can't use it to filter result via `ProcedureCallContext procedureCallContext` + * and mapping data to fetch the associated nodes and relationships and optionally create them + */ + @Procedure(value = "apoc.vectordb.custom.get", mode = Mode.WRITE) + @Description( + "apoc.vectordb.custom.get(host, $configuration) - Customizable get / query procedure, which retrieves vectors from the host and the configuration map") + public Stream get( + @Name("host") String host, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + getEndpoint(configuration, host); + VectorEmbeddingConfig restAPIConfig = new VectorEmbeddingConfig(configuration); + return getEmbeddingResultStream(restAPIConfig, procedureCallContext, tx); + } + + public static Stream getEmbeddingResultStream( + VectorEmbeddingConfig conf, ProcedureCallContext procedureCallContext, Transaction tx) throws Exception { + return getEmbeddingResultStream(conf, 
procedureCallContext, tx, v -> ((List) v).stream()); + } + + public static Stream getEmbeddingResultStream( + VectorEmbeddingConfig conf, + ProcedureCallContext procedureCallContext, + Transaction tx, + Function> objectMapper) + throws Exception { + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + + boolean hasVector = fields.contains("vector") && conf.isAllResults(); + boolean hasMetadata = fields.contains("metadata"); + + VectorMappingConfig mapping = conf.getMapping(); + + return executeRequest(conf.getApiConfig()) + .flatMap(objectMapper) + .map(m -> getEmbeddingResult(conf, tx, hasVector, hasMetadata, mapping, m)); + } + + public static EmbeddingResult getEmbeddingResult( + VectorEmbeddingConfig conf, + Transaction tx, + boolean hasEmbedding, + boolean hasMetadata, + VectorMappingConfig mapping, + Map m) { + Object id = conf.isAllResults() ? m.get(conf.getIdKey()) : null; + List embedding = hasEmbedding ? (List) m.get(conf.getVectorKey()) : null; + Map metadata = hasMetadata ? (Map) m.get(conf.getMetadataKey()) : null; + // in case of get operation, e.g. http://localhost:52798/collections/{coll_name}/points with Qdrant db, + // score is not present + Double score = Util.toDouble(m.get(conf.getScoreKey())); + String text = conf.isAllResults() ? (String) m.get(conf.getTextKey()) : null; + + Entity entity = handleMapping(tx, mapping, metadata, embedding); + if (entity != null) entity = Util.rebind(tx, entity); + return new EmbeddingResult( + id, + score, + embedding, + metadata, + text, + mapping.getNodeLabel() == null ? null : (Node) entity, + mapping.getNodeLabel() != null ? null : (Relationship) entity); + } + + private static Entity handleMapping( + Transaction tx, VectorMappingConfig mapping, Map metadata, List embedding) { + if (mapping.getEntityKey() == null) { + return null; + } + if (MapUtils.isEmpty(metadata)) { + throw new RuntimeException( + "To use mapping config, the metadata should not be empty. 
Make sure you execute `YIELD metadata` on the procedure"); + } + Map metaProps = new HashMap<>(metadata); + if (mapping.getNodeLabel() != null) { + return handleMappingNode(tx, mapping, metaProps, embedding); + } else if (mapping.getRelType() != null) { + return handleMappingRel(tx, mapping, metaProps, embedding); + } else { + throw new RuntimeException("Mapping conf has to contain either label or type key"); + } + } + + private static Entity handleMappingNode( + Transaction transaction, + VectorMappingConfig mapping, + Map metaProps, + List embedding) { + try { + Node node; + Object propValue = metaProps.get(mapping.getMetadataKey()); + node = transaction.findNode(Label.label(mapping.getNodeLabel()), mapping.getEntityKey(), propValue); + if (node == null && mapping.isCreate()) { + node = transaction.createNode(Label.label(mapping.getNodeLabel())); + node.setProperty(mapping.getEntityKey(), propValue); + } + if (node != null) { + setProperties(node, metaProps); + setVectorProp(mapping, embedding, node); + } + + return node; + } catch (MultipleFoundException e) { + throw new RuntimeException("Multiple nodes found", e); + } + } + + private static Entity handleMappingRel( + Transaction transaction, + VectorMappingConfig mapping, + Map metaProps, + List embedding) { + try { + // in this case we cannot auto-create the rel, since we should have to define start and end node as well + Relationship rel; + Object propValue = metaProps.get(mapping.getMetadataKey()); + rel = transaction.findRelationship( + RelationshipType.withName(mapping.getRelType()), mapping.getEntityKey(), propValue); + if (rel != null) { + setProperties(rel, metaProps); + setVectorProp(mapping, embedding, rel); + } + + return rel; + } catch (MultipleFoundException e) { + throw new RuntimeException("Multiple relationships found", e); + } + } + + private static void setVectorProp( + VectorMappingConfig mapping, List embedding, T entity) { + if (mapping.getEmbeddingKey() == null) { + return; + } + + if (embedding 
== null) { + String embeddingErrMsg = String.format( + "The embedding value is null. Make sure you execute `YIELD vector` on the procedure and you configured `%s: true`", + ALL_RESULTS_KEY); + throw new RuntimeException(embeddingErrMsg); + } + + float[] floats = listOfNumbersToFloatArray(embedding); + entity.setProperty(mapping.getEmbeddingKey(), floats); + } + + // TODO - evaluate. It could be renamed e.g. to `apoc.util.restapi.custom` or `apoc.restapi.custom`, + // since it can potentially be used as a generic method to call any RestAPI + @Procedure("apoc.vectordb.custom") + @Description( + "apoc.vectordb.custom(host, $configuration) - fully customizable procedure, returns generic object results") + public Stream custom( + @Name("host") String host, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + + getEndpoint(configuration, host); + RestAPIConfig restAPIConfig = new RestAPIConfig(configuration); + return executeRequest(restAPIConfig).map(ObjectResult::new); + } + + public static Stream executeRequest(RestAPIConfig apiConfig) throws Exception { + Map headers = apiConfig.getHeaders(); + Map configBody = apiConfig.getBody(); + String body = configBody == null ? 
null : OBJECT_MAPPER.writeValueAsString(configBody); + + String endpoint = apiConfig.getEndpoint(); + if (endpoint == null) { + throw new RuntimeException("Endpoint must be specified"); + } + + return JsonUtil.loadJson(endpoint, headers, body, apiConfig.getJsonPath(), true, List.of()); + } + + @Admin + @SystemProcedure + @Procedure(name = "apoc.vectordb.configure") + @Description( + "CALL apoc.vectordb.configure(vectorName, configKey, databaseName, $config) - To configure, given the vector defined by the 1st parameter and the `configKey` name, the `host`, `credentials` and `mapping` of the `$config` map into the system db") + public void vectordb( + @Name("vectorName") String vectorName, + @Name("configKey") String configKey, + @Name("databaseName") String databaseName, + @Name(value = "config", defaultValue = "{}") Map config) { + SystemDbUtil.checkInSystemLeader(db); + SystemDbUtil.checkTargetDatabase(tx, databaseName, "Vector DB configuration"); + + VectorDbHandler.Type type = VectorDbHandler.Type.valueOf(vectorName.toUpperCase(java.util.Locale.ROOT)); + + withSystemDb(transaction -> { + Label label = Label.label(type.get().getLabel()); + Node node = Util.mergeNode(transaction, label, null, Pair.of(SystemPropertyKeys.name.name(), configKey)); + + Map mapping = (Map) config.get("mapping"); + String host = (String) config.get("host"); + Object credentials = config.get("credentials"); + + if (host != null) { + node.setProperty(SystemPropertyKeys.host.name(), host); + } + + if (credentials != null) { + node.setProperty(SystemPropertyKeys.credentials.name(), Util.toJson(credentials)); + } + + if (mapping != null) { + node.setProperty(MAPPING_KEY, Util.toJson(mapping)); + } + }); + } +} diff --git a/full/src/main/java/apoc/vectordb/VectorDbHandler.java b/full/src/main/java/apoc/vectordb/VectorDbHandler.java new file mode 100644 index 0000000000..894b805646 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/VectorDbHandler.java @@ -0,0 +1,37 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.HEADERS_KEY; + +import 
java.util.HashMap; +import java.util.Map; + +public interface VectorDbHandler { + default Map getCredentials(Object credentialsObj, Map config) { + Map headers = (Map) config.getOrDefault(HEADERS_KEY, new HashMap<>()); + headers.putIfAbsent("Authorization", "Bearer " + credentialsObj); + config.put(HEADERS_KEY, headers); + return config; + } + + String getUrl(String hostOrKey); + + VectorEmbeddingHandler getEmbedding(); + + String getLabel(); + + enum Type { + CHROMA(new ChromaHandler()), + QDRANT(new QdrantHandler()), + WEAVIATE(new WeaviateHandler()); + + private final VectorDbHandler handler; + + Type(VectorDbHandler handler) { + this.handler = handler; + } + + public VectorDbHandler get() { + return handler; + } + } +} diff --git a/full/src/main/java/apoc/vectordb/VectorDbUtil.java b/full/src/main/java/apoc/vectordb/VectorDbUtil.java new file mode 100644 index 0000000000..48218a1bf4 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -0,0 +1,113 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.BASE_URL_KEY; +import static apoc.ml.RestAPIConfig.ENDPOINT_KEY; +import static apoc.util.SystemDbUtil.withSystemDb; +import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; + +import apoc.SystemPropertyKeys; +import apoc.util.Util; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.commons.collections.MapUtils; +import org.neo4j.graphdb.Label; +import org.neo4j.graphdb.Node; +import org.neo4j.graphdb.Relationship; + +public class VectorDbUtil { + + public static final String ERROR_READONLY_MAPPING = + "The mapping is not possible with this procedure, as it is read-only."; + + /** + * we can configure the endpoint via config map or via hostOrKey parameter, + * to handle potential endpoint changes. + * For example, in Qdrant `BASE_URL/collections/COLLECTION_NAME/points` could change in the future. 
+ */ + public static void getEndpoint(Map config, String endpoint) { + config.putIfAbsent(ENDPOINT_KEY, endpoint); + } + + /** + * Result of `apoc.vectordb.*.get` and `apoc.vectordb.*.query` procedures + */ + public static final class EmbeddingResult { + public final Object id; + public final Double score; + public final List vector; + public final Map metadata; + public final String text; + public final Node node; + public final Relationship rel; + + public EmbeddingResult( + Object id, + Double score, + List vector, + Map metadata, + String text, + Node node, + Relationship rel) { + this.id = id; + this.score = score; + this.vector = vector; + this.metadata = metadata; + this.text = text; + this.node = node; + this.rel = rel; + } + } + + public static Map getCommonVectorDbInfo( + String hostOrKey, + String collection, + Map configuration, + String templateUrl, + VectorDbHandler handler) { + Map config = new HashMap<>(configuration); + + Map props = withSystemDb(transaction -> { + Label label = Label.label(handler.getLabel()); + Node node = transaction.findNode(label, SystemPropertyKeys.name.name(), hostOrKey); + return node == null ? 
Map.of() : node.getAllProperties(); + }); + + String url = getUrl(hostOrKey, handler, props); + config.put(BASE_URL_KEY, url); + + Map mappingConfVal = (Map) config.get(MAPPING_KEY); + if (MapUtils.isEmpty(mappingConfVal)) { + String mappingStoreVal = (String) props.get(MAPPING_KEY); + if (mappingStoreVal != null) { + config.put(MAPPING_KEY, Util.fromJson(mappingStoreVal, Map.class)); + } + } + + String credentials = (String) props.get(SystemPropertyKeys.credentials.name()); + if (credentials != null) { + Object credentialsObj = Util.fromJson(credentials, Object.class); + + config = handler.getCredentials(credentialsObj, config); + } + + String endpoint = String.format(templateUrl, url, collection); + getEndpoint(config, endpoint); + + return config; + } + + private static String getUrl(String hostOrKey, VectorDbHandler handler, Map props) { + if (props.get(SystemPropertyKeys.host.name()) == null) { + return handler.getUrl(hostOrKey); + } + return (String) props.get(SystemPropertyKeys.host.name()); + } + + public static void checkMappingConf(Map configuration, String procName) { + if (configuration.containsKey(MAPPING_KEY)) { + throw new RuntimeException( + ERROR_READONLY_MAPPING + "\n" + "Try the equivalent procedure, which is the " + procName); + } + } +} diff --git a/full/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java b/full/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java new file mode 100644 index 0000000000..6c937e87af --- /dev/null +++ b/full/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java @@ -0,0 +1,76 @@ +package apoc.vectordb; + +import apoc.ml.RestAPIConfig; +import apoc.util.Util; +import java.util.Map; + +public class VectorEmbeddingConfig { + public static final String FIELDS_KEY = "fields"; + public static final String VECTOR_KEY = "vectorKey"; + public static final String METADATA_KEY = "metadataKey"; + public static final String SCORE_KEY = "scoreKey"; + public static final String TEXT_KEY = "textKey"; + public static final String ID_KEY = "idKey"; + public 
static final String MAPPING_KEY = "mapping"; + + public static final String DEFAULT_ID = "id"; + public static final String DEFAULT_TEXT = "text"; + public static final String DEFAULT_VECTOR = "vector"; + public static final String DEFAULT_METADATA = "metadata"; + public static final String DEFAULT_SCORE = "score"; + public static final String ALL_RESULTS_KEY = "allResults"; + + private final String idKey; + private final String textKey; + private final String vectorKey; + private final String metadataKey; + private final String scoreKey; + private final boolean allResults; + + private final VectorMappingConfig mapping; + private final RestAPIConfig apiConfig; + + public VectorEmbeddingConfig(Map config) { + this.vectorKey = (String) config.getOrDefault(VECTOR_KEY, DEFAULT_VECTOR); + this.metadataKey = (String) config.getOrDefault(METADATA_KEY, DEFAULT_METADATA); + this.scoreKey = (String) config.getOrDefault(SCORE_KEY, DEFAULT_SCORE); + this.idKey = (String) config.getOrDefault(ID_KEY, DEFAULT_ID); + this.textKey = (String) config.getOrDefault(TEXT_KEY, DEFAULT_TEXT); + this.allResults = Util.toBoolean(config.get(ALL_RESULTS_KEY)); + this.mapping = new VectorMappingConfig((Map) config.getOrDefault(MAPPING_KEY, Map.of())); + + this.apiConfig = new RestAPIConfig(config); + } + + public String getIdKey() { + return idKey; + } + + public String getVectorKey() { + return vectorKey; + } + + public String getMetadataKey() { + return metadataKey; + } + + public String getScoreKey() { + return scoreKey; + } + + public String getTextKey() { + return textKey; + } + + public boolean isAllResults() { + return allResults; + } + + public VectorMappingConfig getMapping() { + return mapping; + } + + public RestAPIConfig getApiConfig() { + return apiConfig; + } +} diff --git a/full/src/main/java/apoc/vectordb/VectorEmbeddingHandler.java b/full/src/main/java/apoc/vectordb/VectorEmbeddingHandler.java new file mode 100644 index 0000000000..e58eb45453 --- /dev/null +++ 
b/full/src/main/java/apoc/vectordb/VectorEmbeddingHandler.java @@ -0,0 +1,32 @@ +package apoc.vectordb; + +import static apoc.vectordb.VectorEmbeddingConfig.*; + +import apoc.ml.RestAPIConfig; +import java.util.List; +import java.util.Map; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; + +public interface VectorEmbeddingHandler { + + VectorEmbeddingConfig fromGet( + Map config, ProcedureCallContext procedureCallContext, List ids); + + VectorEmbeddingConfig fromQuery( + Map config, + ProcedureCallContext procedureCallContext, + List vector, + Object filter, + long limit, + String collection); + + static VectorEmbeddingConfig populateApiBodyRequest( + VectorEmbeddingConfig config, Map additionalBodies) { + + RestAPIConfig apiConfig = config.getApiConfig(); + Map body = apiConfig.getBody(); + if (body != null) additionalBodies.forEach(body::putIfAbsent); + apiConfig.setBody(body); + return config; + } +} diff --git a/full/src/main/java/apoc/vectordb/VectorMappingConfig.java b/full/src/main/java/apoc/vectordb/VectorMappingConfig.java new file mode 100644 index 0000000000..d925517a80 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/VectorMappingConfig.java @@ -0,0 +1,69 @@ +package apoc.vectordb; + +import apoc.util.Util; +import java.util.Collections; +import java.util.Map; + +public class VectorMappingConfig { + public static final String METADATA_KEY = "metadataKey"; + public static final String ENTITY_KEY = "entityKey"; + public static final String NODE_LABEL = "nodeLabel"; + public static final String REL_TYPE = "relType"; + public static final String EMBEDDING_KEY = "embeddingKey"; + public static final String SIMILARITY_KEY = "similarity"; + public static final String CREATE_KEY = "create"; + + private final String metadataKey; + private final String entityKey; + + private final String nodeLabel; + private final String relType; + private final String embeddingKey; + private final String similarity; + + private final boolean create; + + 
public VectorMappingConfig(Map mapping) { + if (mapping == null) { + mapping = Collections.emptyMap(); + } + this.metadataKey = (String) mapping.get(METADATA_KEY); + this.entityKey = (String) mapping.get(ENTITY_KEY); + + this.nodeLabel = (String) mapping.get(NODE_LABEL); + this.relType = (String) mapping.get(REL_TYPE); + this.embeddingKey = (String) mapping.get(EMBEDDING_KEY); + + this.similarity = (String) mapping.getOrDefault(SIMILARITY_KEY, "cosine"); + + this.create = Util.toBoolean(mapping.get(CREATE_KEY)); + } + + public String getMetadataKey() { + return metadataKey; + } + + public String getEntityKey() { + return entityKey; + } + + public String getNodeLabel() { + return nodeLabel; + } + + public String getRelType() { + return relType; + } + + public String getEmbeddingKey() { + return embeddingKey; + } + + public boolean isCreate() { + return create; + } + + public String getSimilarity() { + return similarity; + } +} diff --git a/full/src/main/java/apoc/vectordb/Weaviate.java b/full/src/main/java/apoc/vectordb/Weaviate.java new file mode 100644 index 0000000000..c56eb2b982 --- /dev/null +++ b/full/src/main/java/apoc/vectordb/Weaviate.java @@ -0,0 +1,275 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.Util.map; +import static apoc.vectordb.VectorDb.executeRequest; +import static apoc.vectordb.VectorDb.getEmbeddingResult; +import static apoc.vectordb.VectorDb.getEmbeddingResultStream; +import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE; +import static apoc.vectordb.VectorDbUtil.*; + +import apoc.Extended; +import apoc.ml.RestAPIConfig; +import apoc.result.ListResult; +import apoc.result.MapResult; +import apoc.util.UrlResolver; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Transaction; +import 
org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +@Extended +public class Weaviate { + public static final VectorDbHandler DB_HANDLER = WEAVIATE.get(); + + @Context + public ProcedureCallContext procedureCallContext; + + @Context + public Transaction tx; + + @Context + public GraphDatabaseService db; + + @Procedure("apoc.vectordb.weaviate.createCollection") + @Description( + "apoc.vectordb.weaviate.createCollection(hostOrKey, collection, similarity, size, $configuration) - Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`") + public Stream createCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("similarity") String similarity, + @Name("size") Long size, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + var config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema"); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map additionalBodies = + Map.of("class", collection, "vectorIndexConfig", Map.of("distance", similarity, "size", size)); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies); + + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.weaviate.deleteCollection") + @Description( + "apoc.vectordb.weaviate.deleteCollection(hostOrKey, collection, $configuration) - Deletes a collection with the name specified in the 2nd parameter") + public Stream deleteCollection( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + var config = getVectorDbInfo(hostOrKey, collection, configuration, 
"%s/schema/%s"); + config.putIfAbsent(METHOD_KEY, "DELETE"); + + RestAPIConfig restAPIConfig = new RestAPIConfig(config); + return executeRequest(restAPIConfig).map(v -> (Map) v).map(MapResult::new); + } + + @Procedure("apoc.vectordb.weaviate.upsert") + @Description( + "apoc.vectordb.weaviate.upsert(hostOrKey, collection, vectors, $configuration) - Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '', metadata: ''}]") + public Stream upsert( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("vectors") List> vectors, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + var config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/objects"); + config.putIfAbsent(METHOD_KEY, "POST"); + + Map body = new HashMap<>(); + body.put("class", collection); + RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), body); + + return vectors.stream() + .flatMap(vector -> { + try { + Map configBody = new HashMap<>(restAPIConfig.getBody()); + configBody.putAll(vector); + configBody.put("properties", configBody.remove("metadata")); + restAPIConfig.setBody(configBody); + + Stream objectStream = executeRequest(restAPIConfig); + return objectStream; + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .map(v -> (Map) v) + .map(MapResult::new); + } + + @Procedure(value = "apoc.vectordb.weaviate.delete") + @Description( + "apoc.vectordb.weaviate.delete(hostOrKey, collection, ids, $configuration) - Deletes the vectors with the specified `ids`") + public Stream delete( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + var config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema"); + config.putIfAbsent(METHOD_KEY, "DELETE"); + + RestAPIConfig restAPIConfig = 
new RestAPIConfig(config, map(), map()); + + List objects = ids.stream() + .peek(id -> { + String endpoint = String.format("%s/objects/%s/%s", restAPIConfig.getBaseUrl(), collection, id); + restAPIConfig.setEndpoint(endpoint); + try { + executeRequest(restAPIConfig); + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + return Stream.of(new ListResult(objects)); + } + + @Procedure(value = "apoc.vectordb.weaviate.getAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.weaviate.getAndUpdate(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`, and optionally creates/updates neo4j entities") + public Stream getAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) { + return getCommon(hostOrKey, collection, ids, configuration, false); + } + + @Procedure(value = "apoc.vectordb.weaviate.get") + @Description( + "apoc.vectordb.weaviate.get(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`") + public Stream get( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) { + return getCommon(hostOrKey, collection, ids, configuration, true); + } + + private Stream getCommon( + String hostOrKey, + String collection, + List ids, + Map configuration, + boolean readOnly) { + Map config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema"); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.weaviate.getAndUpdate"); + } + + /** + * TODO: we put method: null as a workaround, it should be "GET": https://weaviate.io/developers/weaviate/api/rest#tag/objects/get/objects/{className}/{id} + * Since with `method: GET` the {@link apoc.util.Util#openUrlConnection(URL, Map)} has a `setChunkedStreamingMode` + * that makes the 
request to respond with an error 405 Method Not Allowed + */ + config.putIfAbsent(METHOD_KEY, null); + + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids); + boolean hasEmbedding = fields.contains("vector") && conf.isAllResults(); + boolean hasMetadata = fields.contains("metadata"); + VectorMappingConfig mapping = conf.getMapping(); + + String suffix = hasEmbedding ? "?include=vector" : ""; + + return ids.stream().flatMap(id -> { + String endpoint = + String.format("%s/objects/%s/%s", conf.getApiConfig().getBaseUrl(), collection, id) + suffix; + conf.getApiConfig().setEndpoint(endpoint); + try { + return executeRequest(conf.getApiConfig()) + .map(v -> (Map) v) + .map(m -> getEmbeddingResult(conf, tx, hasEmbedding, hasMetadata, mapping, m)); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + @Procedure(value = "apoc.vectordb.weaviate.query") + @Description( + "apoc.vectordb.weaviate.query(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter") + public Stream query( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "null") Object filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + checkMappingConf(configuration, "apoc.vectordb.weaviate.queryAndUpdate"); + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, true); + } + + @Procedure(value = "apoc.vectordb.weaviate.queryAndUpdate", mode = Mode.WRITE) + @Description( + "apoc.vectordb.weaviate.queryAndUpdate(hostOrKey, collection, vector, filter, limit, 
$configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter") + public Stream queryAndUpdate( + @Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "null") Object filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + return queryCommon(hostOrKey, collection, vector, filter, limit, configuration, false); + } + + private Stream queryCommon( + String hostOrKey, + String collection, + List vector, + Object filter, + long limit, + Map configuration, + boolean readOnly) + throws Exception { + Map config = getVectorDbInfo(hostOrKey, collection, configuration, "%s/graphql"); + + if (readOnly) { + checkMappingConf(configuration, "apoc.vectordb.weaviate.queryAndUpdate"); + } + + VectorEmbeddingConfig conf = + DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); + return getEmbeddingResultStream(conf, procedureCallContext, tx, v -> { + Object getValue = ((Map) v).get("data").get("Get"); + Object collectionValue = ((Map) getValue).get(collection); + return ((List) collectionValue).stream().map(i -> { + Map additional = (Map) i.remove("_additional"); + + Map map = new HashMap<>(); + map.put(conf.getMetadataKey(), i); + map.put(conf.getScoreKey(), additional.get("distance")); + map.put(conf.getIdKey(), additional.get("id")); + map.put(conf.getVectorKey(), additional.get("vector")); + return map; + }); + }); + } + + private Map getVectorDbInfo( + String hostOrKey, String collection, Map configuration, String templateUrl) { + return getCommonVectorDbInfo(hostOrKey, collection, configuration, templateUrl, DB_HANDLER); + } + + protected String getWeaviateUrl(String hostOrKey) { + String baseUrl = new 
UrlResolver("http", "localhost", 8000).getUrl("weaviate", hostOrKey); + return baseUrl + "/v1"; + } +} diff --git a/full/src/main/java/apoc/vectordb/WeaviateHandler.java b/full/src/main/java/apoc/vectordb/WeaviateHandler.java new file mode 100644 index 0000000000..b58ffd79af --- /dev/null +++ b/full/src/main/java/apoc/vectordb/WeaviateHandler.java @@ -0,0 +1,84 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.BODY_KEY; +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.MapUtil.map; +import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY; + +import apoc.util.UrlResolver; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; + +public class WeaviateHandler implements VectorDbHandler { + + @Override + public String getUrl(String hostOrKey) { + String url = new UrlResolver("http", "localhost", 8000).getUrl("weaviate", hostOrKey); + return url + "/v1"; + } + + @Override + public VectorEmbeddingHandler getEmbedding() { + return new WeaviateEmbeddingHandler(); + } + + @Override + public String getLabel() { + return "Weaviate"; + } + + // -- embedding handler + static class WeaviateEmbeddingHandler implements VectorEmbeddingHandler { + + @Override + public VectorEmbeddingConfig fromGet( + Map config, ProcedureCallContext procedureCallContext, List ids) { + config.putIfAbsent(BODY_KEY, null); + return VectorEmbeddingHandler.populateApiBodyRequest(getVectorEmbeddingConfig(config), Map.of()); + } + + @Override + public VectorEmbeddingConfig fromQuery( + Map config, + ProcedureCallContext procedureCallContext, + List vector, + Object filter, + long limit, + String collection) { + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); + config.putIfAbsent(METHOD_KEY, "POST"); + VectorEmbeddingConfig vectorEmbeddingConfig = 
getVectorEmbeddingConfig(config); + + List list = (List) config.get("fields"); + if (list == null) { + throw new RuntimeException("You have to define `field` list of parameter to be returned"); + } + Object fieldList = String.join("\n", list); + + filter = filter == null ? "" : ", where: " + filter; + + String includeVector = (fields.contains("vector") && vectorEmbeddingConfig.isAllResults()) ? ",vector" : ""; + String additional = "_additional {id, distance " + includeVector + "}"; + String query = String.format( + "{\n" + " Get {\n" + + " %s(limit: %s, nearVector: {vector: %s } %s) {%s %s}\n" + + " }\n" + + "}", + collection, limit, vector, filter, fieldList, additional); + + Map additionalBodies = map("query", query); + + return VectorEmbeddingHandler.populateApiBodyRequest(vectorEmbeddingConfig, additionalBodies); + } + + private static VectorEmbeddingConfig getVectorEmbeddingConfig(Map config) { + config.putIfAbsent(VECTOR_KEY, "vector"); + config.putIfAbsent(METADATA_KEY, "properties"); + + return new VectorEmbeddingConfig(config); + } + } +} diff --git a/full/src/main/resources/extended.txt b/full/src/main/resources/extended.txt index f52194c086..56aac858c2 100644 --- a/full/src/main/resources/extended.txt +++ b/full/src/main/resources/extended.txt @@ -204,3 +204,30 @@ apoc.trigger.propertiesByKey apoc.trigger.toNode apoc.trigger.toRelationship apoc.ttl.config +apoc.vectordb.chroma.createCollection +apoc.vectordb.chroma.deleteCollection +apoc.vectordb.chroma.upsert +apoc.vectordb.chroma.delete +apoc.vectordb.chroma.get +apoc.vectordb.chroma.getAndUpdate +apoc.vectordb.chroma.query +apoc.vectordb.chroma.queryAndUpdate +apoc.vectordb.qdrant.createCollection +apoc.vectordb.qdrant.deleteCollection +apoc.vectordb.qdrant.upsert +apoc.vectordb.qdrant.delete +apoc.vectordb.qdrant.get +apoc.vectordb.qdrant.getAndUpdate +apoc.vectordb.qdrant.query +apoc.vectordb.qdrant.queryAndUpdate +apoc.vectordb.weaviate.createCollection +apoc.vectordb.weaviate.deleteCollection 
+apoc.vectordb.weaviate.upsert +apoc.vectordb.weaviate.delete +apoc.vectordb.weaviate.get +apoc.vectordb.weaviate.getAndUpdate +apoc.vectordb.weaviate.query +apoc.vectordb.weaviate.queryAndUpdate +apoc.vectordb.custom.get +apoc.vectordb.custom +apoc.vectordb.configure \ No newline at end of file diff --git a/extended/src/test/java/apoc/diff/DiffExtendedTest.java b/full/src/test/java/apoc/diff/DiffExtendedTest.java similarity index 84% rename from extended/src/test/java/apoc/diff/DiffExtendedTest.java rename to full/src/test/java/apoc/diff/DiffExtendedTest.java index b7116ced5e..c1b6b282aa 100644 --- a/extended/src/test/java/apoc/diff/DiffExtendedTest.java +++ b/full/src/test/java/apoc/diff/DiffExtendedTest.java @@ -1,8 +1,15 @@ package apoc.diff; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + import apoc.create.Create; import apoc.util.TestUtil; -import apoc.util.collection.Iterators; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -12,18 +19,10 @@ import org.neo4j.graphdb.Relationship; import org.neo4j.graphdb.RelationshipType; import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.helpers.collection.Iterators; import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - public class DiffExtendedTest { @ClassRule @@ -68,7 +67,7 @@ public void relationshipsWithList() { TestUtil.testCall( db, "CREATE ()-[rel1:REL $propRel1]->(), ()-[rel2:REL $propRel2]->()\n" - + "RETURN apoc.diff.relationships(rel1, rel2) AS result", 
+ + "RETURN apoc.diff.relationships(rel1, rel2) AS result", Map.of( "propRel1", Map.of("name", "Charlie", "alpha", "one", "born", 1999, "grocery_list", list), @@ -93,34 +92,31 @@ public void relationshipsSame() { @Test public void relationshipsDiffering() { - String query = "MATCH (leftNode:Node2Start)-[rel2]->(), (rightNode:Node3Start)-[rel3]->() " + - "RETURN apoc.diff.relationships(rel2, rel3) as diff"; + String query = "MATCH (leftNode:Node2Start)-[rel2]->(), (rightNode:Node3Start)-[rel3]->() " + + "RETURN apoc.diff.relationships(rel2, rel3) as diff"; commonAssertionDifferentRels(query); } @Test public void shouldBeDiffWithVirtualRelationships() { - String query = """ - MATCH (start1:Node1Start)-[rel1]->(end1), (start2:Node2Start)-[rel2]->(end2) - WITH apoc.create.vRelationship(start1, type(rel1), {prop1: 'val1', prop2: 2, prop4: 'four'}, end1) AS relA, - apoc.create.vRelationship(start2, type(rel2), {prop1: 'val1', prop3: '3', prop4: 'for'}, end2) AS relB - RETURN apoc.diff.relationships(relA, relB) as diff"""; + String query = "MATCH (start1:Node1Start)-[rel1]->(end1), (start2:Node2Start)-[rel2]->(end2)\n" + + "WITH apoc.create.vRelationship(start1, type(rel1), {prop1: 'val1', prop2: 2, prop4: 'four'}, end1) AS relA,\n" + + " apoc.create.vRelationship(start2, type(rel2), {prop1: 'val1', prop3: '3', prop4: 'for'}, end2) AS relB\n" + + "RETURN apoc.diff.relationships(relA, relB) as diff"; commonAssertionDifferentRels(query); } @Test public void shouldBeSameWithVirtualRelationships() { - String query = "MATCH (start:Node1Start)-[rel]->(end)" + - "WITH apoc.create.vRelationship(start, type(rel), {prop1: 'val1', prop2: 2}, end) AS rel " - + "RETURN apoc.diff.relationships(rel, rel) as diff"; + String query = "MATCH (start:Node1Start)-[rel]->(end)" + + "WITH apoc.create.vRelationship(start, type(rel), {prop1: 'val1', prop2: 2}, end) AS rel " + + "RETURN apoc.diff.relationships(rel, rel) as diff"; commonAssertionSameRels(query); } private void 
commonAssertionDifferentRels(String query) { - Map result = db.executeTransactionally( - query, - Map.of(), - r -> Iterators.single(r.columnAs("diff"))); + Map result = + db.executeTransactionally(query, Map.of(), r -> Iterators.single(r.columnAs("diff"))); assertNotNull(result); HashMap leftOnly = (HashMap) result.get("leftOnly"); diff --git a/extended/src/test/java/apoc/graph/GraphsExtendedTest.java b/full/src/test/java/apoc/graph/GraphsExtendedTest.java similarity index 52% rename from extended/src/test/java/apoc/graph/GraphsExtendedTest.java rename to full/src/test/java/apoc/graph/GraphsExtendedTest.java index af6cc99c67..79d238b004 100644 --- a/extended/src/test/java/apoc/graph/GraphsExtendedTest.java +++ b/full/src/test/java/apoc/graph/GraphsExtendedTest.java @@ -1,8 +1,19 @@ package apoc.graph; +import static apoc.util.TestUtil.*; +import static apoc.util.Util.map; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + import apoc.create.Create; import apoc.map.Maps; import apoc.util.TestUtil; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; @@ -14,87 +25,86 @@ import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; -import java.util.Comparator; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -import static apoc.util.TestUtil.*; -import static apoc.util.Util.map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; - - public class GraphsExtendedTest { @ClassRule public static DbmsRule db = new ImpermanentDbmsRule(); - private static final Map propsPerson1 = map("name", "foo", "plotEmbedding", "22", "posterEmbedding", "3", "plot", "4", "bio", "5", "idNode", 1L); - private 
static final Map propsPerson2 = map("name", "bar", "plotEmbedding", "22", "posterEmbedding", "3", "plot", "4", "bio", "5", "idNode", 3L); - private static final Map propsMovie1 = map("title", "1", "tmdbId", "ajeje", "idNode", 2L, "posterEmbedding", "33"); - private static final Map propsMovie2 = map("title", "1", "tmdbId", "brazorf", "idNode", 4L, "posterEmbedding", "44"); + private static final Map propsPerson1 = + map("name", "foo", "plotEmbedding", "22", "posterEmbedding", "3", "plot", "4", "bio", "5", "idNode", 1L); + private static final Map propsPerson2 = + map("name", "bar", "plotEmbedding", "22", "posterEmbedding", "3", "plot", "4", "bio", "5", "idNode", 3L); + private static final Map propsMovie1 = + map("title", "1", "tmdbId", "ajeje", "idNode", 2L, "posterEmbedding", "33"); + private static final Map propsMovie2 = + map("title", "1", "tmdbId", "brazorf", "idNode", 4L, "posterEmbedding", "44"); private static final Map propsRel1 = map("idRel", 1L); private static final Map propsRel2 = map("idRel", 2L); - + @BeforeClass public static void setUp() { TestUtil.registerProcedure(db, GraphsExtended.class, Create.class, Maps.class, Graphs.class); - - db.executeTransactionally(""" - CREATE (:Person $propsPerson1)-[:REL $propsRel1]->(:Movie $propsMovie1), - (:Person $propsPerson2)-[:REL $propsRel2]->(:Movie $propsMovie2)""", - map("propsPerson1", propsPerson1, - "propsPerson2", propsPerson2, - "propsMovie1", propsMovie1, - "propsMovie2", propsMovie2, - "propsRel1", propsRel1, - "propsRel2", propsRel2)); db.executeTransactionally( - """ - CREATE (a:Foo {idNode: 11, remove: 1})-[r1:MY_REL {idRel: 11, remove: 1}]->(b:Bar {idNode: 22, remove: 1})-[r2:ANOTHER_REL {idRel: 22, remove: 1}]->(c:Baz {idNode: 33, remove: 1})\s - WITH b, c\s - CREATE (b)-[:REL_TWO {idRel: 33, remove: 1}]->(c), (b)-[:REL_THREE {idRel: 44, remove: 1}]->(c), (b)-[:REL_FOUR {idRel: 55, remove: 1}]->(c)"""); + "CREATE (:Person $propsPerson1)-[:REL $propsRel1]->(:Movie $propsMovie1), (:Person 
$propsPerson2)-[:REL $propsRel2]->(:Movie $propsMovie2)", + map( + "propsPerson1", + propsPerson1, + "propsPerson2", + propsPerson2, + "propsMovie1", + propsMovie1, + "propsMovie2", + propsMovie2, + "propsRel1", + propsRel1, + "propsRel2", + propsRel2)); + + db.executeTransactionally( + "CREATE (a:Foo {idNode: 11, remove: 1})-[r1:MY_REL {idRel: 11, remove: 1}]->(b:Bar {idNode: 22, remove: 1})-[r2:ANOTHER_REL {idRel: 22, remove: 1}]->(c:Baz {idNode: 33, remove: 1})" + + " WITH b, c " + + "CREATE (b)-[:REL_TWO {idRel: 33, remove: 1}]->(c), (b)-[:REL_THREE {idRel: 44, remove: 1}]->(c), (b)-[:REL_FOUR {idRel: 55, remove: 1}]->(c)"); db.executeTransactionally( "CREATE (a:Foo {idNode: 44, remove: 1})-[r1:MY_REL {idRel: 66, remove: 1}]->(b:Bar {idNode: 55, remove: 1})-[r2:ANOTHER_REL {idRel: 77, remove: 1}]->(c:Baz {idNode: 66, remove: 1})"); db.executeTransactionally( - "CREATE (a:One {idNode: 77, remove: 1})-[r1:MY_REL {idRel: 88, remove: 1}]->(b:Two {idNode: 88, remove: 1}), " + - "(:Two {idNode: 100, remove: 1})-[r2:ANOTHER_REL {idRel: 99, remove: 1}]->(c:Three {idNode: 99, remove: 1})"); + "CREATE (a:One {idNode: 77, remove: 1})-[r1:MY_REL {idRel: 88, remove: 1}]->(b:Two {idNode: 88, remove: 1}), " + + "(:Two {idNode: 100, remove: 1})-[r2:ANOTHER_REL {idRel: 99, remove: 1}]->(c:Three {idNode: 99, remove: 1})"); } - + @Test public void testFilterPropertiesConsistentWithManualFilteringAndDoesNotChangeOriginalEntities() { - // check that the apoc.graph.filterProperties and the query used here: https://github.com/neo4j-contrib/neo4j-apoc-procedures/issues/3937 + // check that the apoc.graph.filterProperties and the query used here: + // https://github.com/neo4j-contrib/neo4j-apoc-procedures/issues/3937 // produce the same result - testCall(db, """ - match path=(:Person)-[:REL]->(:Movie) - with collect(path) as paths - call apoc.graph.fromPaths(paths,"results",{}) yield graph - with graph.nodes as nodes, graph.relationships as rels - with rels, apoc.map.fromPairs([n in 
nodes | [coalesce(n.tmdbId, n.name), apoc.create.vNode(labels(n), apoc.map.removeKeys(properties(n), ['plotEmbedding', 'posterEmbedding', 'plot', 'bio'] ) )]]) as nodes - return apoc.map.values(nodes, keys(nodes)) AS nodes, - [r in rels | apoc.create.vRelationship(nodes[coalesce(startNode(r).tmdbId,startNode(r).name)], type(r), properties(r), nodes[coalesce(endNode(r).tmdbId,endNode(r).name)])] AS relationships""", + testCall( + db, + "match path=(:Person)-[:REL]->(:Movie)\n" + "with collect(path) as paths\n" + + "call apoc.graph.fromPaths(paths,\"results\",{}) yield graph\n" + + "with graph.nodes as nodes, graph.relationships as rels\n" + + "with rels, apoc.map.fromPairs([n in nodes | [coalesce(n.tmdbId, n.name), apoc.create.vNode(labels(n), apoc.map.removeKeys(properties(n), ['plotEmbedding', 'posterEmbedding', 'plot', 'bio'] ) )]]) as nodes\n" + + "return apoc.map.values(nodes, keys(nodes)) AS nodes,\n" + + " [r in rels | apoc.create.vRelationship(nodes[coalesce(startNode(r).tmdbId,startNode(r).name)], type(r), properties(r), nodes[coalesce(endNode(r).tmdbId,endNode(r).name)])] AS relationships", this::commonFilterPropertiesAssertions); - - testCall(db, """ - MATCH path=(:Person)-[:REL]->(:Movie) - WITH apoc.graph.filterProperties(path, {_all: ['plotEmbedding', 'posterEmbedding', 'plot', 'bio']}) as graph - RETURN graph.nodes AS nodes, graph.relationships AS relationships""", + + testCall( + db, + "MATCH path=(:Person)-[:REL]->(:Movie)\n" + + "WITH apoc.graph.filterProperties(path, {_all: ['plotEmbedding', 'posterEmbedding', 'plot', 'bio']}) as graph\n" + + "RETURN graph.nodes AS nodes, graph.relationships AS relationships", this::commonFilterPropertiesAssertions); - + // check that original nodes haven't changed testResult(db, "MATCH path=(n:Person)-[:REL]->(:Movie) RETURN path ORDER BY n.id", r -> { Iterator row = r.columnAs("path"); Path path = row.next(); Map propsStart = path.startNode().getAllProperties(); Map propsEnd = path.endNode().getAllProperties(); 
- Map propsRel = path.relationships().iterator().next().getAllProperties(); - + Map propsRel = + path.relationships().iterator().next().getAllProperties(); + assertEquals(propsPerson1, propsStart); assertEquals(propsMovie1, propsEnd); assertEquals(propsRel1, propsRel); @@ -106,28 +116,28 @@ with collect(path) as paths assertEquals(propsPerson2, propsStart); assertEquals(propsMovie2, propsEnd); assertEquals(propsRel2, propsRel); - + assertFalse(row.hasNext()); }); } @Test public void testFilterPropertiesProcedure() { - - testCall(db, """ - MATCH path=(:Person)-[:REL]->(:Movie) - WITH collect(path) AS paths - CALL apoc.graph.filterProperties(paths, {_all: ['plotEmbedding', 'posterEmbedding', 'plot', 'bio']}) - YIELD nodes, relationships - RETURN nodes, relationships""", + + testCall( + db, + "MATCH path=(:Person)-[:REL]->(:Movie)\n" + "WITH collect(path) AS paths\n" + + "CALL apoc.graph.filterProperties(paths, {_all: ['plotEmbedding', 'posterEmbedding', 'plot', 'bio']})\n" + + "YIELD nodes, relationships\n" + + "RETURN nodes, relationships", this::commonFilterPropertiesAssertions); - - testCall(db, """ - MATCH path=(:Person)-[:REL]->(:Movie) - WITH collect(path) AS paths - CALL apoc.graph.filterProperties(paths, {Movie: ['posterEmbedding'], Person: ['posterEmbedding', 'plotEmbedding', 'plot', 'bio']}) - YIELD nodes, relationships - RETURN nodes, relationships""", + + testCall( + db, + "MATCH path=(:Person)-[:REL]->(:Movie)\n" + "WITH collect(path) AS paths\n" + + "CALL apoc.graph.filterProperties(paths, {Movie: ['posterEmbedding'], Person: ['posterEmbedding', 'plotEmbedding', 'plot', 'bio']})\n" + + "YIELD nodes, relationships\n" + + "RETURN nodes, relationships", this::commonFilterPropertiesAssertions); } @@ -165,19 +175,20 @@ private void commonFilterPropertiesAssertions(Map r) { public void filterPropertiesWithPathsWithMultipleRels() { Set expectedIdNodes = Set.of(11L, 22L, 33L, 44L, 55L, 66L); Set expectedIdRels = Set.of(11L, 22L, 33L, 44L, 55L, 66L, 77L); - - 
testCall(db, """ - MATCH path=(:Foo)--(:Bar)--(:Baz) - WITH collect(path) AS paths - CALL apoc.graph.filterProperties(paths, {_all: ['remove']}, {_all: ['remove']}) - YIELD nodes, relationships - RETURN nodes, relationships""", + + testCall( + db, + "MATCH path=(:Foo)--(:Bar)--(:Baz)\n" + "WITH collect(path) AS paths\n" + + "CALL apoc.graph.filterProperties(paths, {_all: ['remove']}, {_all: ['remove']})\n" + + "YIELD nodes, relationships\n" + + "RETURN nodes, relationships", r -> assertNodeAndRelIdProps(r, expectedIdNodes, expectedIdRels)); - - testCall(db, """ - MATCH path=(:Foo)--(:Bar)--(:Baz) - WITH apoc.graph.filterProperties(path, {_all: ['remove']}, {_all: ['remove']}) as graph - RETURN graph.nodes AS nodes, graph.relationships AS relationships""", + + testCall( + db, + "MATCH path=(:Foo)--(:Bar)--(:Baz)\n" + + "WITH apoc.graph.filterProperties(path, {_all: ['remove']}, {_all: ['remove']}) as graph\n" + + "RETURN graph.nodes AS nodes, graph.relationships AS relationships", r -> assertNodeAndRelIdProps(r, expectedIdNodes, expectedIdRels)); } @@ -185,57 +196,57 @@ WITH collect(path) AS paths public void testWithCompositeDataTypes() { Set expectedIdNodes = Set.of(100L, 99L, 88L, 77L); Set expectedIdRels = Set.of(99L, 88L); - - testCall(db, """ - MATCH p1=(:One)--(:Two), p2=(:Two)--(:Three) - CALL apoc.graph.filterProperties([p1, p2], {_all: ['remove']}, {_all: ['remove']}) - YIELD nodes, relationships - RETURN nodes, relationships""", + + testCall( + db, + "MATCH p1=(:One)--(:Two), p2=(:Two)--(:Three)\n" + + "CALL apoc.graph.filterProperties([p1, p2], {_all: ['remove']}, {_all: ['remove']})\n" + + "YIELD nodes, relationships\n" + + "RETURN nodes, relationships", r -> assertNodeAndRelIdProps(r, expectedIdNodes, expectedIdRels)); - - testCall(db, """ - MATCH p1=(:One)--(:Two), p2=(:Two)--(:Three) - CALL apoc.graph.filterProperties([{key1: p1, key2: [p1, p2]}], {_all: ['remove']}, {_all: ['remove']}) - YIELD nodes, relationships - RETURN nodes, relationships""", + 
+ testCall( + db, + "MATCH p1=(:One)--(:Two), p2=(:Two)--(:Three)\n" + + "CALL apoc.graph.filterProperties([{key1: p1, key2: [p1, p2]}], {_all: ['remove']}, {_all: ['remove']})\n" + + "YIELD nodes, relationships\n" + + "RETURN nodes, relationships", r -> assertNodeAndRelIdProps(r, expectedIdNodes, expectedIdRels)); - - testCall(db, """ - MATCH p1=(:One)--(:Two), p2=(:Two)--(:Three) - CALL apoc.graph.filterProperties([{key2: {subKey: [p1, p2]}}], {_all: ['remove']}, {_all: ['remove']}) - YIELD nodes, relationships - RETURN nodes, relationships""", + + testCall( + db, + "MATCH p1=(:One)--(:Two), p2=(:Two)--(:Three)\n" + + "CALL apoc.graph.filterProperties([{key2: {subKey: [p1, p2]}}], {_all: ['remove']}, {_all: ['remove']})\n" + + "YIELD nodes, relationships\n" + + "RETURN nodes, relationships", r -> assertNodeAndRelIdProps(r, expectedIdNodes, expectedIdRels)); } - private void assertNodeAndRelIdProps(Map r, Set expectedIdNodes, Set expectedIdRels) { + private void assertNodeAndRelIdProps( + Map r, Set expectedIdNodes, Set expectedIdRels) { Set actualIdNodes = ((List) r.get("nodes")) - .stream() - .map(i -> i.getProperty("idNode")) - .collect(Collectors.toSet()); + .stream().map(i -> i.getProperty("idNode")).collect(Collectors.toSet()); assertEquals(expectedIdNodes, actualIdNodes); Set actualIdRels = ((List) r.get("relationships")) - .stream() - .map(i -> i.getProperty("idRel")) - .collect(Collectors.toSet()); + .stream().map(i -> i.getProperty("idRel")).collect(Collectors.toSet()); assertEquals(expectedIdRels, actualIdRels); } @Test public void testFilterPropertiesWithEmptyNodeAndRelPropertiesToRemove() { - testCall(db, """ - MATCH path=(:Person)-[:REL]->(:Movie) - WITH collect(path) AS paths - CALL apoc.graph.filterProperties(paths) - YIELD nodes, relationships - RETURN nodes, relationships""", + testCall( + db, + "MATCH path=(:Person)-[:REL]->(:Movie)\n" + "WITH collect(path) AS paths\n" + + "CALL apoc.graph.filterProperties(paths)\n" + + "YIELD nodes, 
relationships\n" + + "RETURN nodes, relationships", this::assertEmptyFilter); - testCall(db, """ - MATCH path=(:Person)-[:REL]->(:Movie) - WITH apoc.graph.filterProperties(path) as graph - RETURN graph.nodes AS nodes, graph.relationships AS relationships""", + testCall( + db, + "MATCH path=(:Person)-[:REL]->(:Movie)\n" + "WITH apoc.graph.filterProperties(path) as graph\n" + + "RETURN graph.nodes AS nodes, graph.relationships AS relationships", this::assertEmptyFilter); } diff --git a/full/src/test/java/apoc/map/MapsExtendedTest.java b/full/src/test/java/apoc/map/MapsExtendedTest.java new file mode 100644 index 0000000000..e62dbdd5c7 --- /dev/null +++ b/full/src/test/java/apoc/map/MapsExtendedTest.java @@ -0,0 +1,113 @@ +package apoc.map; + +import static org.junit.Assert.assertEquals; + +import apoc.util.TestUtil; +import java.util.List; +import java.util.Map; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +public class MapsExtendedTest { + + private static final String OLD_KEY = "depends_on"; + private static final String NEW_KEY = "new_name"; + + @Rule + public DbmsRule db = new ImpermanentDbmsRule(); + + @Before + public void setUp() { + TestUtil.registerProcedure(db, MapsExtended.class); + } + + @Test + public void testRenameKeyNotExistent() { + Map map = Map.of("testKey", List.of(Map.of("testKey", "some_value")), "name", "Mario"); + TestUtil.testCall( + db, + "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo) as value", + Map.of("map", map, "keyFrom", "notExistent", "keyTo", "something"), + (r) -> assertEquals(map, r.get("value"))); + } + + @Test + public void testRenameKeyInnerKey() { + Map map = Map.of("testKey", List.of(Map.of("innerKey", "some_value")), "name", "Mario"); + TestUtil.testCall( + db, + "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo) as value", + Map.of("map", map, "keyFrom", "innerKey", "keyTo", "otherKey"), + (r) -> { + Map 
expected = + Map.of("testKey", List.of(Map.of("otherKey", "some_value")), "name", "Mario"); + ; + assertEquals(expected, r.get("value")); + }); + } + + @Test + public void testRenameKey() { + Map map = Map.of(OLD_KEY, List.of(Map.of(OLD_KEY, "some_value")), "name", "Mario"); + ; + TestUtil.testCall( + db, + "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo) as value", + Map.of("map", map, "keyFrom", OLD_KEY, "keyTo", NEW_KEY), + (r) -> { + Map expected = + Map.of(NEW_KEY, List.of(Map.of(NEW_KEY, "some_value")), "name", "Mario"); + Map> actual = (Map) r.get("value"); + assertEquals(expected, actual); + }); + } + + @Test + public void testRenameKeyComplexMap() { + Map map = Map.of( + OLD_KEY, + List.of(1L, "test", Map.of(OLD_KEY, "some_value")), + "otherKey", + List.of(1L, List.of("test", Map.of(OLD_KEY, "some_value"))), + "name", + Map.of(OLD_KEY, "some_value"), + "other", + "key"); + TestUtil.testCall( + db, + "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo) as value", + Map.of("map", map, "keyFrom", OLD_KEY, "keyTo", NEW_KEY), + (r) -> { + Map expected = Map.of( + NEW_KEY, + List.of(1L, "test", Map.of(NEW_KEY, "some_value")), + "otherKey", + List.of(1L, List.of("test", Map.of(NEW_KEY, "some_value"))), + "name", + Map.of(NEW_KEY, "some_value"), + "other", + "key"); + Map> actual = (Map) r.get("value"); + assertEquals(expected, actual); + }); + } + + @Test + public void testRenameKeyRecursiveFalse() { + Map map = Map.of(OLD_KEY, List.of(Map.of(OLD_KEY, "some_value")), "name", "Mario"); + ; + TestUtil.testCall( + db, + "RETURN apoc.map.renameKey($map, $keyFrom, $keyTo, {recursive: false}) as value", + Map.of("map", map, "keyFrom", OLD_KEY, "keyTo", NEW_KEY), + (r) -> { + Map> actual = (Map) r.get("value"); + Map expected = + Map.of(NEW_KEY, List.of(Map.of(OLD_KEY, "some_value")), "name", "Mario"); + assertEquals(expected, actual); + }); + } +} diff --git a/full/src/test/java/apoc/util/UtilsExtendedTest.java b/full/src/test/java/apoc/util/UtilsExtendedTest.java 
index 926d5649ea..93d41fc086 100644 --- a/full/src/test/java/apoc/util/UtilsExtendedTest.java +++ b/full/src/test/java/apoc/util/UtilsExtendedTest.java @@ -3,6 +3,7 @@ import static apoc.util.TestUtil.testCall; import static org.junit.Assert.assertTrue; +import org.junit.Assume; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; @@ -26,4 +27,10 @@ public void testMultipleCharsetsCompressionWithDifferentResults() { "RETURN apoc.util.hashCode(rand()) AS hashCode", r -> assertTrue(r.get("hashCode") instanceof Long)); } + + public static String checkEnvVar(String envKey) { + String value = System.getenv(envKey); + Assume.assumeNotNull(String.format("No %s environment configured", envKey), value); + return value; + } } diff --git a/full/src/test/java/apoc/vectordb/PineconeTest.java b/full/src/test/java/apoc/vectordb/PineconeTest.java new file mode 100644 index 0000000000..b12fbf5403 --- /dev/null +++ b/full/src/test/java/apoc/vectordb/PineconeTest.java @@ -0,0 +1,98 @@ +package apoc.vectordb; + +import static apoc.ml.RestAPIConfig.BODY_KEY; +import static apoc.ml.RestAPIConfig.HEADERS_KEY; +import static apoc.ml.RestAPIConfig.JSON_PATH_KEY; +import static apoc.ml.RestAPIConfig.METHOD_KEY; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.util.Util.map; +import static apoc.util.UtilsExtendedTest.checkEnvVar; +import static apoc.vectordb.VectorEmbeddingConfig.VECTOR_KEY; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import apoc.util.TestUtil; +import java.net.URL; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +/** + * It leverages `apoc.vectordb.custom*` procedures + * * + * * + * Example of Pinecone RestAPI: + * PINECONE_HOST: 
`https://INDEX-ID.svc.gcp-starter.pinecone.io` + * PINECONE_KEY: `API Key` + * PINECONE_NAMESPACE: `the one to be specified in body: {.. "ns": NAMESPACE}` + * PINECONE_DIMENSION: vector dimension + */ +public class PineconeTest { + private static String apiKey; + private static String host; + private static String size; + private static String namespace; + + @ClassRule + public static DbmsRule db = new ImpermanentDbmsRule(); + + @BeforeClass + public static void setUp() throws Exception { + apiKey = checkEnvVar("PINECONE_KEY"); + host = checkEnvVar("PINECONE_HOST"); + size = checkEnvVar("PINECONE_DIMENSION"); + namespace = checkEnvVar("PINECONE_NAMESPACE"); + + TestUtil.registerProcedure(db, VectorDb.class); + } + + @Test + public void callQueryEndpointViaCustomGetProc() { + + Map conf = getConf(); + conf.put(VECTOR_KEY, "values"); + + testResult(db, "CALL apoc.vectordb.custom.get($host, $conf)", map("host", host + "/query", "conf", conf), r -> { + r.forEachRemaining(i -> { + assertNotNull(i.get("score")); + assertNotNull(i.get("metadata")); + assertNotNull(i.get("id")); + assertNotNull(i.get("vector")); + }); + }); + } + + @Test + public void callQueryEndpointViaCustomProc() { + testCall(db, "CALL apoc.vectordb.custom($host, $conf)", map("host", host + "/query", "conf", getConf()), r -> { + List value = (List) r.get("value"); + value.forEach(i -> { + assertTrue(i.containsKey("score")); + assertTrue(i.containsKey("metadata")); + assertTrue(i.containsKey("id")); + }); + }); + } + + /** + * TODO: "method" is null as a workaround. 
+ * Since with `method: POST` the {@link apoc.util.Util#openUrlConnection(URL, Map)} has a `setChunkedStreamingMode` + * that makes the request to respond 200 OK, but returns an empty result + */ + private static Map getConf() { + List vector = Collections.nCopies(Integer.parseInt(size), 0.1); + + Map body = map( + "namespace", namespace, "vector", vector, "topK", 3, "includeValues", true, "includeMetadata", true); + + Map header = map("Api-Key", apiKey); + + return map(BODY_KEY, body, HEADERS_KEY, header, METHOD_KEY, null, JSON_PATH_KEY, "matches"); + } +} diff --git a/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java b/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java new file mode 100644 index 0000000000..d949adf64d --- /dev/null +++ b/full/src/test/java/apoc/vectordb/VectorDbTestUtil.java @@ -0,0 +1,92 @@ +package apoc.vectordb; + +import static apoc.util.TestUtil.testResult; +import static apoc.util.Util.map; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.Map; +import org.neo4j.graphdb.Entity; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.ResourceIterator; +import org.neo4j.graphdb.Result; + +public class VectorDbTestUtil { + + public enum EntityType { + NODE, + REL, + FALSE + } + + public static void dropAndDeleteAll(GraphDatabaseService db) { + db.executeTransactionally("MATCH (n) DETACH DELETE n"); + } + + public static void assertBerlinResult(Map row, EntityType entityType) { + assertBerlinResult(row, "1", entityType); + } + + public static void assertBerlinResult(Map row, String id, EntityType entityType) { + assertEquals(Map.of("city", "Berlin", "foo", "one"), row.get("metadata")); + assertEquals(id, row.get("id").toString()); + if (!entityType.equals(EntityType.FALSE)) { + String entity = entityType.equals(EntityType.NODE) ? 
"node" : "rel"; + Map props = ((Entity) row.get(entity)).getAllProperties(); + assertBerlinProperties(props); + } + } + + public static void assertLondonResult(Map row, EntityType entityType) { + assertLondonResult(row, "2", entityType); + } + + public static void assertLondonResult(Map row, String id, EntityType entityType) { + assertEquals(Map.of("city", "London", "foo", "two"), row.get("metadata")); + assertEquals(id, row.get("id").toString()); + if (!entityType.equals(EntityType.FALSE)) { + String entity = entityType.equals(EntityType.NODE) ? "node" : "rel"; + Map props = ((Entity) row.get(entity)).getAllProperties(); + assertLondonProperties(props); + } + } + + public static void assertNodesCreated(GraphDatabaseService db) { + testResult( + db, + "MATCH (n:Test) RETURN properties(n) AS props ORDER BY n.myId", + VectorDbTestUtil::vectorEntityAssertions); + } + + public static void assertRelsCreated(GraphDatabaseService db) { + testResult( + db, + "MATCH (:Start)-[r:TEST]->(:End) RETURN properties(r) AS props ORDER BY r.myId", + VectorDbTestUtil::vectorEntityAssertions); + } + + public static void vectorEntityAssertions(Result r) { + ResourceIterator propsIterator = r.columnAs("props"); + assertBerlinProperties(propsIterator.next()); + assertLondonProperties(propsIterator.next()); + + assertFalse(propsIterator.hasNext()); + } + + private static void assertLondonProperties(Map props) { + assertEquals("London", props.get("city")); + assertEquals("two", props.get("myId")); + assertTrue(props.get("vect") instanceof float[]); + } + + private static void assertBerlinProperties(Map props) { + assertEquals("Berlin", props.get("city")); + assertEquals("one", props.get("myId")); + assertTrue(props.get("vect") instanceof float[]); + } + + public static Map getAuthHeader(String key) { + return map("Authorization", "Bearer " + key); + } +} diff --git a/test-utils/build.gradle b/test-utils/build.gradle index 909d4bbd37..dde9f628ec 100644 --- a/test-utils/build.gradle +++ 
b/test-utils/build.gradle @@ -42,6 +42,9 @@ dependencies { api group: 'org.testcontainers', name: 'postgresql', version: testContainersVersion api group: 'org.testcontainers', name: 'cassandra', version: testContainersVersion api group: 'org.testcontainers', name: 'localstack', version: testContainersVersion + api group: 'org.testcontainers', name: 'qdrant', version: '1.19.7' + api group: 'org.testcontainers', name: 'chromadb', version: '1.19.7' + api group: 'org.testcontainers', name: 'weaviate', version: '1.19.7' api group: 'org.apache.arrow', name: 'arrow-vector', version: '16.1.0' api group: 'org.apache.arrow', name: 'arrow-memory-netty', version: '16.1.0' }