Skip to content

Commit

Permalink
[NOID] Fixes #4233: Improve Weaviate error handling (#4239)
Browse files Browse the repository at this point in the history
* Fixes #4233: Improve Weaviate error handling

* fix tests
  • Loading branch information
vga91 committed Dec 5, 2024
1 parent ae7315d commit bce5312
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
19 changes: 19 additions & 0 deletions full-it/src/test/java/apoc/full/it/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import static apoc.vectordb.VectorDbTestUtil.EntityType.*;
import static apoc.vectordb.VectorDbUtil.ERROR_READONLY_MAPPING;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.MAPPING_KEY;
import static apoc.vectordb.VectorMappingConfig.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.neo4j.configuration.GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
import static org.neo4j.configuration.GraphDatabaseSettings.SYSTEM_DATABASE_NAME;
Expand Down Expand Up @@ -168,6 +170,23 @@ public void writeOperationWithReadOnlyUser() {
}
}

@Test
public void queryWithWrongEmbeddingSize() {
Map<String, Object> conf = map(ALL_RESULTS_KEY, true, FIELDS_KEY, FIELDS, HEADERS_KEY, READONLY_AUTHORIZATION);

try {
testCall(
db,
"CALL apoc.vectordb.weaviate.query($host, 'TestCollection', [0.2, 0.1, 0.9], null, 5, $conf)",
map("host", HOST, "conf", conf),
r -> fail());
} catch (Exception e) {
String message = e.getMessage();
String expectedErrMsg = "distance between entrypoint and query node: vector lengths don't match: 4 vs 3";
assertTrue(message.contains(expectedErrMsg));
}
}

@Test
public void getVectorsWithoutVectorResult() {
testResult(
Expand Down
10 changes: 9 additions & 1 deletion full/src/main/java/apoc/vectordb/Weaviate.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
Expand Down Expand Up @@ -248,7 +250,13 @@ private Stream<EmbeddingResult> queryCommon(
VectorEmbeddingConfig conf =
DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection);
return getEmbeddingResultStream(conf, procedureCallContext, tx, v -> {
Object getValue = ((Map<String, Map>) v).get("data").get("Get");
Map<String, Map> mapResult = (Map<String, Map>) v;
List<Map> errors = (List<Map>) mapResult.get("errors");
if (CollectionUtils.isNotEmpty(errors)) {
String message = "An error occurred during Weaviate API response: \n" + StringUtils.join(errors, "\n");
throw new RuntimeException(message);
}
Object getValue = mapResult.get("data").get("Get");
Object collectionValue = ((Map) getValue).get(collection);
return ((List<Map>) collectionValue).stream().map(i -> {
Map additional = (Map) i.remove("_additional");
Expand Down

0 comments on commit bce5312

Please sign in to comment.