From 5b329461e456a63d8152efef70f84286dda8fd85 Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Mon, 2 Dec 2024 12:24:47 +0100 Subject: [PATCH] [NOID] Fixes #4233: Improve Weaviate error handling (#4239) * Fixes #4233: Improve Weaviate error handling * fix tests --- .../apoc/full/it/vectordb/WeaviateTest.java | 19 +++++++++++++++++++ .../src/main/java/apoc/vectordb/Weaviate.java | 11 ++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java index 14b76f4ef9..45439b750d 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java @@ -10,6 +10,7 @@ import static apoc.vectordb.VectorDbTestUtil.EntityType.*; import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; +import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY; import static apoc.vectordb.VectorMappingConfig.*; import static org.assertj.core.api.Assertions.assertThat; @@ -167,6 +168,24 @@ public void writeOperationWithReadOnlyUser() { assertThat(e.getMessage()).contains("HTTP response code: 403"); } } + + @Test + public void queryWithWrongEmbeddingSize() { + Map conf = map(ALL_RESULTS_KEY, true, + FIELDS_KEY, FIELDS, + HEADERS_KEY, READONLY_AUTHORIZATION); + + + try { + testCall(db, "CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9], null, 5, $conf)", + map("host", HOST, "conf", conf), + r -> fail()); + } catch (Exception e) { + String message = e.getMessage(); + String expectedErrMsg = "distance between entrypoint and query node: vector lengths don't match: 4 vs 3"; + assertEquals(expectedErrMsg, message); + } + } @Test public void getVectorsWithoutVectorResult() { diff --git a/full/src/main/java/apoc/vectordb/Weaviate.java b/full/src/main/java/apoc/vectordb/Weaviate.java index c56eb2b982..08f9146b61 100644 --- a/full/src/main/java/apoc/vectordb/Weaviate.java +++ b/full/src/main/java/apoc/vectordb/Weaviate.java @@ -18,6 +18,9 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; + +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.Transaction; import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; @@ -248,7 +251,13 @@ private Stream queryCommon( VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); return getEmbeddingResultStream(conf, procedureCallContext, tx, v -> { - Object getValue = ((Map) v).get("data").get("Get"); + Map mapResult = (Map) v; + List errors = (List) mapResult.get("errors"); + if ( CollectionUtils.isNotEmpty(errors) ) { + String message = "An error occurred during Weaviate API response: \n" + StringUtils.join(errors, "\n"); + throw new RuntimeException(message); + } + Object getValue = mapResult.get("data").get("Get"); Object collectionValue = ((Map) getValue).get(collection); return ((List) collectionValue).stream().map(i -> { Map additional = (Map) i.remove("_additional");