From b7697f400b55c83e3319ce58a1767944e11f726c Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Wed, 4 Dec 2024 10:21:47 +0100 Subject: [PATCH] Fixes #4232: The apoc.vectordb.configure(WEAVIATE', ..) procedure should append /v1 to url (#4248) --- .../test/java/apoc/vectordb/WeaviateTest.java | 99 ++++++++++++------- .../src/main/java/apoc/vectordb/VectorDb.java | 6 +- .../main/java/apoc/vectordb/VectorDbUtil.java | 14 +++ 3 files changed, 83 insertions(+), 36 deletions(-) diff --git a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java index 5b3aed4d27..ea5ed24aca 100644 --- a/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java +++ b/extended-it/src/test/java/apoc/vectordb/WeaviateTest.java @@ -4,7 +4,6 @@ import apoc.util.MapUtil; import apoc.util.TestUtil; import org.junit.AfterClass; -import org.junit.Assume; import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -507,41 +506,21 @@ MAPPING_KEY, map(REL_TYPE, "TEST", public void queryVectorsWithSystemDbStorage() { String keyConfig = "weaviate-config-foo"; String baseUrl = "http://" + HOST + "/v1"; - Map mapping = map(EMBEDDING_KEY, "vect", - NODE_LABEL, "Test", - ENTITY_KEY, "myId", - METADATA_KEY, "foo"); - sysDb.executeTransactionally("CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)", - map("vectorName", WEAVIATE.toString(), - "keyConfig", keyConfig, - "databaseName", DEFAULT_DATABASE_NAME, - "conf", map( - "host", baseUrl, - "credentials", ADMIN_KEY, - "mapping", mapping - ) - ) - ); - - db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); - - testResult(db, "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)", - map("host", keyConfig, - "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true) - ), - r -> { - Map row = r.next(); - assertBerlinResult(row, ID_1, NODE); - assertNotNull(row.get("score")); - assertNotNull(row.get("vector")); + assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, false); + } - row = r.next(); - assertLondonResult(row, ID_2, NODE); - assertNotNull(row.get("score")); - assertNotNull(row.get("vector")); - }); + @Test + public void queryVectorsWithSystemDbStorageWithUrlWithoutVersion() { + String keyConfig = "weaviate-config-foo"; + String baseUrl = "http://" + HOST; + assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, false); + } - assertNodesCreated(db); + @Test + public void queryVectorsWithSystemDbStorageWithUrlV3Version() { + String keyConfig = "weaviate-config-foo"; + String baseUrl = "http://" + HOST + "/v3"; + assertQueryVectorsWithSystemDbStorage(keyConfig, baseUrl, true); } @Test @@ -575,4 +554,56 @@ WITH collect(node) as paths ), VectorDbTestUtil::assertRagWithVectors); } + + private static void assertQueryVectorsWithSystemDbStorage(String keyConfig, String baseUrl, boolean fails) { + Map mapping = map(EMBEDDING_KEY, "vect", + NODE_LABEL, "Test", + ENTITY_KEY, "myId", + METADATA_KEY, "foo"); + sysDb.executeTransactionally("CALL apoc.vectordb.configure($vectorName, $keyConfig, $databaseName, $conf)", + map("vectorName", WEAVIATE.toString(), + "keyConfig", keyConfig, + "databaseName", DEFAULT_DATABASE_NAME, + "conf", map( + "host", baseUrl, + "credentials", ADMIN_KEY, + "mapping", mapping + ) + ) + ); + + db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})"); + + String query = "CALL apoc.vectordb.weaviate.queryAndUpdate($host, 'TestCollection', [0.2, 0.1, 0.9, 0.7], null, 5, $conf)"; + Map params = map("host", keyConfig, + "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true) + ); + + if (fails) { + assertFails( + db, + query, + params, + "Caused by: java.io.FileNotFoundException: http://127.0.0.1:" + HOST.split(":")[1] + "/v3/graphql" + ); + return; + } + + + testResult(db, query, + params, + r -> { + Map row = r.next(); + assertBerlinResult(row, ID_1, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + + row = r.next(); + assertLondonResult(row, ID_2, NODE); + assertNotNull(row.get("score")); + assertNotNull(row.get("vector")); + }); + + assertNodesCreated(db); + } } diff --git a/extended/src/main/java/apoc/vectordb/VectorDb.java b/extended/src/main/java/apoc/vectordb/VectorDb.java index f136f4c6ee..61df7a210d 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDb.java +++ b/extended/src/main/java/apoc/vectordb/VectorDb.java @@ -38,7 +38,9 @@ import static apoc.util.ExtendedUtil.setProperties; import static apoc.util.JsonUtil.OBJECT_MAPPER; import static apoc.util.SystemDbUtil.withSystemDb; -import static apoc.vectordb.VectorDbUtil.*; +import static apoc.vectordb.VectorDbUtil.EmbeddingResult; +import static apoc.vectordb.VectorDbUtil.appendVersionUrlIfNeeded; +import static apoc.vectordb.VectorDbUtil.getEndpoint; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; @@ -259,7 +261,7 @@ public void vectordb( Node node = Util.mergeNode(transaction, label, null, Pair.of(SystemPropertyKeys.name.name(), configKey)); Map mapping = (Map) config.get("mapping"); - String host = (String) config.get("host"); + String host = appendVersionUrlIfNeeded(type, (String) config.get("host")); Object credentials = config.get("credentials"); if (host != null) { diff --git a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java index 18cd3f8db7..3a11fdfa19 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java +++ b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -135,4 +135,18 @@ public static void methodAndPayloadNull(Map config) { config.put(METHOD_KEY, null); config.put(BODY_KEY, null); } + + /** + * If the vectorDb is WEAVIATE and endpoint doesn't end with `/vN`, where N is a number, + * then add `/v1` to the endpoint + */ + public static String appendVersionUrlIfNeeded(VectorDbHandler.Type type, String host) { + if (VectorDbHandler.Type.WEAVIATE == type) { + String regex = ".*(/v\\d+)$"; + if (!host.matches(regex)) { + host = host + "/v1"; + } + } + return host; + } }