From a10ec9e1d70d1b80f88d83944bd73cdf9c809797 Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Fri, 6 Dec 2024 15:46:51 +0100 Subject: [PATCH] [NOID] Fixes #4156: Improves handling of empty or blank input for openai procedures (#4228) * Fixes #4156: Improves handling of empty or blank input for openai procedures * fix tests * changed boolean conditions --- .../modules/ROOT/pages/ml/openai.adoc | 1 + full/src/main/java/apoc/ml/MLUtil.java | 2 + full/src/main/java/apoc/ml/OpenAI.java | 44 +++++++++++++- full/src/test/java/apoc/ml/OpenAIIT.java | 59 +++++++++++++++++++ .../test/java/apoc/util/ExtendedTestUtil.java | 13 ++++ 5 files changed, 118 insertions(+), 1 deletion(-) diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index 41a482365a..f7ff49dd5a 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -27,6 +27,7 @@ If present, they take precedence over the analogous APOC configs. | apiType | analogous to `apoc.ml.openai.type` APOC config | endpoint | analogous to `apoc.ml.openai.url` APOC config | apiVersion | analogous to `apoc.ml.azure.api.version` APOC config +| failOnError | If true (default), the procedure fails in case of empty, blank or null input |=== diff --git a/full/src/main/java/apoc/ml/MLUtil.java b/full/src/main/java/apoc/ml/MLUtil.java index 1b9a770b29..47025adb4a 100644 --- a/full/src/main/java/apoc/ml/MLUtil.java +++ b/full/src/main/java/apoc/ml/MLUtil.java @@ -1,6 +1,8 @@ package apoc.ml; public class MLUtil { + public static final String ERROR_NULL_INPUT = "Null, blank or empty input provided. Please specify a valid input"; + public static final String ENDPOINT_CONF_KEY = "endpoint"; public static final String API_VERSION_CONF_KEY = "apiVersion"; public static final String MODEL_CONF_KEY = "model"; diff --git a/full/src/main/java/apoc/ml/OpenAI.java b/full/src/main/java/apoc/ml/OpenAI.java index 934b2b1b7b..8403411e19 100644 --- a/full/src/main/java/apoc/ml/OpenAI.java +++ b/full/src/main/java/apoc/ml/OpenAI.java @@ -6,20 +6,27 @@ import static apoc.ml.MLUtil.API_TYPE_CONF_KEY; import static apoc.ml.MLUtil.API_VERSION_CONF_KEY; import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY; +import static apoc.ml.MLUtil.ERROR_NULL_INPUT; import static apoc.ml.MLUtil.MODEL_CONF_KEY; import apoc.ApocConfig; import apoc.Extended; import apoc.result.MapResult; import apoc.util.JsonUtil; +import apoc.util.Util; import com.fasterxml.jackson.core.JsonProcessingException; import java.net.MalformedURLException; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.function.BiFunction; +import java.util.function.Supplier; import java.util.stream.Stream; +import org.apache.commons.collections4.MapUtils; +import org.apache.commons.lang3.StringUtils; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Name; @@ -27,6 +34,8 @@ @Extended public class OpenAI { + public static final String FAIL_ON_ERROR_CONF = "failOnError"; + @Context public ApocConfig apocConfig; @@ -106,7 +115,10 @@ public Stream getEmbedding( "model": "text-embedding-ada-002", "usage": { "prompt_tokens": 8, "total_tokens": 8 } } */ - + boolean failOnError = isFailOnError(configuration); + if (checkNullInput(texts, failOnError)) return Stream.empty(); + texts = texts.stream().filter(StringUtils::isNotBlank).toList(); + if (checkEmptyInput(texts, failOnError)) return Stream.empty(); return getEmbeddingResult(texts, apiKey, configuration, apocConfig, (map, text) -> { Long index = (Long) map.get("index"); return new EmbeddingResult(index, text, (List) map.get("embedding")); @@ -147,6 +159,8 @@ public Stream completion( "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 } } */ + boolean failOnError = isFailOnError(configuration); + if (checkBlankInput(prompt, failOnError)) return Stream.empty(); return executeRequest( apiKey, configuration, @@ -167,6 +181,10 @@ public Stream chatCompletion( @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + boolean failOnError = isFailOnError(configuration); + if (checkNullInput(messages, failOnError)) return Stream.empty(); + messages = messages.stream().filter(MapUtils::isNotEmpty).toList(); + if (checkEmptyInput(messages, failOnError)) return Stream.empty(); String model = (String) configuration.putIfAbsent("model", "gpt-4o"); return executeRequest(apiKey, configuration, "chat/completions", model, "messages", messages, "$", apocConfig) .map(v -> (Map) v) @@ -181,4 +199,28 @@ public Stream chatCompletion( } ] } */ } + + private static boolean isFailOnError(Map configuration) { + return Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true)); + } + + static boolean checkNullInput(Object input, boolean failOnError) { + return checkInput(failOnError, () -> Objects.isNull(input)); + } + + static boolean checkEmptyInput(Collection input, boolean failOnError) { + return checkInput(failOnError, input::isEmpty); + } + + static boolean checkBlankInput(String input, boolean failOnError) { + return checkInput(failOnError, () -> StringUtils.isBlank(input)); + } + + private static boolean checkInput(boolean failOnError, Supplier checkFunction) { + if (checkFunction.get()) { + if (failOnError) throw new RuntimeException(ERROR_NULL_INPUT); + return true; + } + return false; + } } diff --git a/full/src/test/java/apoc/ml/OpenAIIT.java b/full/src/test/java/apoc/ml/OpenAIIT.java index 012e6ba0b0..a816903967 100644 --- a/full/src/test/java/apoc/ml/OpenAIIT.java +++ b/full/src/test/java/apoc/ml/OpenAIIT.java @@ -1,10 +1,12 @@ package apoc.ml; +import static apoc.ml.MLUtil.ERROR_NULL_INPUT; import static apoc.ml.OpenAITestResultUtils.assertChatCompletion; import static apoc.util.TestUtil.testCall; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import apoc.util.ExtendedTestUtil; import apoc.util.TestUtil; import java.util.List; import java.util.Map; @@ -13,6 +15,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; @@ -138,4 +141,60 @@ public void chatCompletion() { } */ } + + @Test + public void embeddingsNull() { + assertNullInputFails( + db, + "CALL apoc.ml.openai.embedding(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap())); + } + + @Test + public void chatNull() { + assertNullInputFails( + db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", Map.of("apiKey", openaiKey, "conf", emptyMap())); + } + + @Test + public void chatReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty( + db, + "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))); + } + + @Test + public void embeddingsReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty( + db, + "CALL apoc.ml.openai.embedding(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))); + } + + @Test + public void chatWithEmptyFails() { + assertNullInputFails( + db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)", Map.of("apiKey", openaiKey, "conf", emptyMap())); + } + + @Test + public void embeddingsWithEmptyReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty( + db, + "CALL apoc.ml.openai.embedding([], $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))); + } + + @Test + public void completionReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty( + db, + "CALL apoc.ml.openai.completion(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))); + } + + public static void assertNullInputFails(GraphDatabaseService db, String query, Map params) { + ExtendedTestUtil.assertFails(db, query, params, ERROR_NULL_INPUT); + } } diff --git a/full/src/test/java/apoc/util/ExtendedTestUtil.java b/full/src/test/java/apoc/util/ExtendedTestUtil.java index dc3a8ba34c..0a2b4c7c31 100644 --- a/full/src/test/java/apoc/util/ExtendedTestUtil.java +++ b/full/src/test/java/apoc/util/ExtendedTestUtil.java @@ -1,6 +1,9 @@ package apoc.util; +import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testCallAssertions; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.neo4j.test.assertion.Assert.assertEventually; import java.util.Collections; @@ -67,4 +70,14 @@ public static void testResultEventually( timeout, TimeUnit.SECONDS); } + + public static void assertFails( + GraphDatabaseService db, String query, Map params, String expectedErrMsg) { + try { + testCall(db, query, params, r -> fail("Should fail due to " + expectedErrMsg)); + } catch (Exception e) { + String actualErrMsg = e.getMessage(); + assertTrue("Actual err. message is: " + actualErrMsg, actualErrMsg.contains(expectedErrMsg)); + } + } }