From c827c766c7ace2f88eb7c87c00b4532c8d0b9c36 Mon Sep 17 00:00:00 2001 From: vga91 Date: Tue, 26 Nov 2024 16:34:12 +0100 Subject: [PATCH 1/3] Fixes #4153: Handling OpenAI 429's gracefully --- .../modules/ROOT/pages/ml/openai.adoc | 3 + extended/src/main/java/apoc/ml/OpenAI.java | 18 ++- .../src/main/java/apoc/util/ExtendedUtil.java | 69 +++++++-- .../test/java/apoc/util/ExtendedUtilTest.java | 131 ++++++++++++++++++ 4 files changed, 210 insertions(+), 11 deletions(-) create mode 100644 extended/src/test/java/apoc/util/ExtendedUtilTest.java diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index b02fd7d4d7..4b54fc4be4 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -37,6 +37,9 @@ If present, they take precedence over the analogous APOC configs. | jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response. The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure. | failOnError | If true (default), the procedure fails in case of empty, blank or null input +| enableBackOffRetries | If set to true, enables the backoff retry strategy for handling failures. (default: false) +| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5) +| exponentialBackoff | If set to true, applies an exponential progression to the wait time between retries. If set to false, the wait time increases linearly. 
(default: false) |=== diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index 29a54a7f7d..4bc4814b33 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -3,6 +3,7 @@ import apoc.ApocConfig; import apoc.Extended; import apoc.result.MapResult; +import apoc.util.ExtendedUtil; import apoc.util.JsonUtil; import apoc.util.Util; import com.fasterxml.jackson.core.JsonProcessingException; @@ -36,6 +37,9 @@ public class OpenAI { public static final String PATH_CONF_KEY = "path"; public static final String GPT_4O_MODEL = "gpt-4o"; public static final String FAIL_ON_ERROR_CONF = "failOnError"; + public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries"; + public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff"; + public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries"; @Context public ApocConfig apocConfig; @@ -59,6 +63,9 @@ public EmbeddingResult(long index, String text, List embedding) { static Stream executeRequest(String apiKey, Map configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException { apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey)); + boolean enableBackOffRetries = Util.toBoolean( configuration.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY) ); + Integer backOffRetries = Util.toInteger(configuration.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5)); + boolean exponentialBackoff = Util.toBoolean( configuration.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY) ); if (apiKey == null || apiKey.isBlank()) throw new IllegalArgumentException("API Key must not be empty"); @@ -78,7 +85,7 @@ static Stream executeRequest(String apiKey, Map configur path = (String) configuration.getOrDefault(PATH_CONF_KEY, path); 
OpenAIRequestHandler apiType = type.get(); - jsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath); + String sJsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath); headers.put("Content-Type", "application/json"); apiType.addApiKey(headers, apiKey); @@ -88,7 +95,14 @@ static Stream executeRequest(String apiKey, Map configur // eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model // therefore is better to join the not-empty path pieces var url = apiType.getFullUrl(path, configuration, apocConfig); - return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker); + return ExtendedUtil.withBackOffRetries( + () -> JsonUtil.loadJson(url, headers, payload, sJsonPath, true, List.of(), urlAccessChecker), + enableBackOffRetries, backOffRetries, exponentialBackoff, + exception -> { + if(!exception.getMessage().contains("429")) + throw new RuntimeException(exception); + } + ); } private static void handleAPIProvider(OpenAIRequestHandler.Type type, diff --git a/extended/src/main/java/apoc/util/ExtendedUtil.java b/extended/src/main/java/apoc/util/ExtendedUtil.java index f7fa8a5c9d..deb4f44a73 100644 --- a/extended/src/main/java/apoc/util/ExtendedUtil.java +++ b/extended/src/main/java/apoc/util/ExtendedUtil.java @@ -5,12 +5,14 @@ import com.fasterxml.jackson.core.json.JsonWriteFeature; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.lang3.StringUtils; +import org.eclipse.collections.api.block.function.Function0; import org.neo4j.exceptions.Neo4jException; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.ExecutionPlanDescription; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.QueryExecutionType; import org.neo4j.graphdb.Result; +import org.neo4j.logging.Log; import org.neo4j.procedure.Mode; import org.neo4j.values.storable.DateTimeValue; import org.neo4j.values.storable.DateValue; @@ -30,14 +32,11 @@ import 
java.time.ZoneId; import java.time.ZonedDateTime; import java.time.temporal.TemporalAccessor; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.time.temporal.TemporalUnit; +import java.util.*; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.LongStream; import java.util.stream.Stream; @@ -353,5 +352,57 @@ public static float[] listOfNumbersToFloatArray(List embedding } return floats; } - + + public static T withBackOffRetries( + Supplier func, + boolean retry, + int backoffRetry, + boolean exponential, + Consumer exceptionHandler + ) { + T result = null; + backoffRetry = backoffRetry < 1 ? 5 : backoffRetry; + int countDown = backoffRetry; + exceptionHandler = Objects.requireNonNullElse(exceptionHandler, exe -> {}); + while (true) { + try { + result = func.get(); + break; + } catch (Exception e) { + if(!retry || countDown < 1) throw e; + exceptionHandler.accept(e); + countDown--; + backoffSleep( + getDelay(backoffRetry, countDown, exponential) + ); + } + } + return result; + } + + private static void backoffSleep(long millis){ + sleep(millis, "Operation interrupted during backoff"); + } + + public static void sleep(long millis, String interruptedMessage) { + try { + Thread.sleep(millis); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException(interruptedMessage, ie); + } + } + + private static long getDelay(Integer backoffRetry, Integer countDown, boolean exponential){ + long sleepMultiplier = exponential ? 
+ (long) Math.pow(2, backoffRetry - countDown) : // Exponential retry progression + backoffRetry - countDown; // Linear retry progression + return Math.min( + Duration.ofSeconds(1) + .multipliedBy(sleepMultiplier) + .toMillis(), + Duration.ofSeconds(30).toMillis() // Max 30s + ); + } + } diff --git a/extended/src/test/java/apoc/util/ExtendedUtilTest.java b/extended/src/test/java/apoc/util/ExtendedUtilTest.java new file mode 100644 index 0000000000..97f64c8272 --- /dev/null +++ b/extended/src/test/java/apoc/util/ExtendedUtilTest.java @@ -0,0 +1,131 @@ +package apoc.util; + +import org.junit.Test; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertTrue; + +public class ExtendedUtilTest { + + private static int i = 0; + + @Test + public void testWithLinearBackOffRetriesWithSuccess() { + i = 0; + long start = System.currentTimeMillis(); + int result = ExtendedUtil.withBackOffRetries( + this::testFunction, + true, + -1, // test backoffRetry default value -> 5 + false, + runEx -> { + if(!runEx.getMessage().contains("Expected")) + throw new RuntimeException("Some Bad News..."); + } + ); + long time = System.currentTimeMillis() - start; + + assertEquals(4, result); + + // The method will attempt to execute the operation with a linear backoff strategy, + // sleeping for 1 second, 2 seconds, and 3 seconds between retries. + // This results in a total wait time of 6 seconds (1s + 2s + 3s) if the operation succeeds on the third attempt, + // leading to an approximate execution time of 6 seconds. 
+ assertTrue(time > 5500); + assertTrue(time < 6500); + } + + @Test + public void testWithExponentialBackOffRetriesWithSuccess() { + i=0; + long start = System.currentTimeMillis(); + int result = ExtendedUtil.withBackOffRetries( + this::testFunction, + true, 0, // test backoffRetry default value -> 5 + false, + runEx -> { + if(!runEx.getMessage().contains("Expected")) + throw new RuntimeException("Some Bad News..."); + } + ); + long time = System.currentTimeMillis() - start; + + assertEquals(4, result); + + // The method will attempt to execute the operation with an exponential backoff strategy, + // sleeping for 2 second, 4 seconds, and 8 seconds between retries. + // This results in a total wait time of 14 seconds (2s + 4s + 8s) if the operation succeeds on the third attempt, + // leading to an approximate execution time of 14 seconds. + assertTrue(time > 13500); + assertTrue(time < 14500); + } + + @Test + public void testBackOffRetriesWithError() { + i=0; + long start = System.currentTimeMillis(); + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries( + this::testFunction, + true, 2, + false, + runEx -> {} + ) + ); + long time = System.currentTimeMillis() - start; + + // The method is configured to retry the operation twice. + // So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception. + // Resulting in an approximate execution time of 3 seconds. + assertTrue(time > 2500); + assertTrue(time < 3500); + } + + @Test + public void testBackOffRetriesWithErrorAndExponential() { + i=0; + long start = System.currentTimeMillis(); + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries( + this::testFunction, + true, 2, + true, + runEx -> {} + ) + ); + long time = System.currentTimeMillis() - start; + + // The method is configured to retry the operation twice. 
+ // So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception. + // Resulting in an approximate execution time of 3 seconds. + assertTrue(time > 2500); + assertTrue(time < 3500); + } + + @Test + public void testWithoutBackOffRetriesWithError() { + i=0; + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries( + this::testFunction, + false, 30, false, + runEx -> {} + ) + ); + + // Retry strategy is not active and the testFunction is executed only once by raising an exception. + assertEquals(1, i); + } + + private int testFunction() { + i++; + if (i == 4) { + throw new RuntimeException("Expected i not equal to 4"); + } + return i; + } + +} From 78b678ca2cb57fd3e8d70dd6904541512e313c05 Mon Sep 17 00:00:00 2001 From: vga91 Date: Tue, 10 Dec 2024 18:13:47 +0100 Subject: [PATCH 2/3] cleanup --- extended/src/main/java/apoc/util/ExtendedUtil.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/extended/src/main/java/apoc/util/ExtendedUtil.java b/extended/src/main/java/apoc/util/ExtendedUtil.java index deb4f44a73..cf89e4a1a4 100644 --- a/extended/src/main/java/apoc/util/ExtendedUtil.java +++ b/extended/src/main/java/apoc/util/ExtendedUtil.java @@ -5,14 +5,12 @@ import com.fasterxml.jackson.core.json.JsonWriteFeature; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.lang3.StringUtils; -import org.eclipse.collections.api.block.function.Function0; import org.neo4j.exceptions.Neo4jException; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.ExecutionPlanDescription; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.QueryExecutionType; import org.neo4j.graphdb.Result; -import org.neo4j.logging.Log; import org.neo4j.procedure.Mode; import org.neo4j.values.storable.DateTimeValue; import org.neo4j.values.storable.DateValue; @@ -32,10 +30,8 @@ import java.time.ZoneId; import java.time.ZonedDateTime; import 
java.time.temporal.TemporalAccessor; -import java.time.temporal.TemporalUnit; import java.util.*; import java.util.function.Consumer; -import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.LongStream; From ef131ed717808dd6d1481431fbbc0e90a97dfc93 Mon Sep 17 00:00:00 2001 From: vga91 Date: Wed, 11 Dec 2024 14:24:03 +0100 Subject: [PATCH 3/3] fix tests --- .../src/main/java/apoc/util/ExtendedUtil.java | 18 +++--- .../test/java/apoc/util/ExtendedUtilTest.java | 57 ++++++++++--------- 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/extended/src/main/java/apoc/util/ExtendedUtil.java b/extended/src/main/java/apoc/util/ExtendedUtil.java index cf89e4a1a4..596e086c75 100644 --- a/extended/src/main/java/apoc/util/ExtendedUtil.java +++ b/extended/src/main/java/apoc/util/ExtendedUtil.java @@ -356,8 +356,10 @@ public static T withBackOffRetries( boolean exponential, Consumer exceptionHandler ) { - T result = null; - backoffRetry = backoffRetry < 1 ? 5 : backoffRetry; + T result; + backoffRetry = backoffRetry < 1 + ? 5 + : backoffRetry; int countDown = backoffRetry; exceptionHandler = Objects.requireNonNullElse(exceptionHandler, exe -> {}); while (true) { @@ -368,9 +370,8 @@ public static T withBackOffRetries( if(!retry || countDown < 1) throw e; exceptionHandler.accept(e); countDown--; - backoffSleep( - getDelay(backoffRetry, countDown, exponential) - ); + long delay = getDelay(backoffRetry, countDown, exponential); + backoffSleep(delay); } } return result; @@ -389,10 +390,11 @@ public static void sleep(long millis, String interruptedMessage) { } } - private static long getDelay(Integer backoffRetry, Integer countDown, boolean exponential){ + private static long getDelay(int backoffRetry, int countDown, boolean exponential) { + int backOffTime = backoffRetry - countDown; long sleepMultiplier = exponential ? 
- (long) Math.pow(2, backoffRetry - countDown) : // Exponential retry progression - backoffRetry - countDown; // Linear retry progression + (long) Math.pow(2, backOffTime) : // Exponential retry progression + backOffTime; // Linear retry progression return Math.min( Duration.ofSeconds(1) .multipliedBy(sleepMultiplier) diff --git a/extended/src/test/java/apoc/util/ExtendedUtilTest.java b/extended/src/test/java/apoc/util/ExtendedUtilTest.java index 97f64c8272..05fe36f264 100644 --- a/extended/src/test/java/apoc/util/ExtendedUtilTest.java +++ b/extended/src/test/java/apoc/util/ExtendedUtilTest.java @@ -29,24 +29,22 @@ public void testWithLinearBackOffRetriesWithSuccess() { // The method will attempt to execute the operation with a linear backoff strategy, // sleeping for 1 second, 2 seconds, and 3 seconds between retries. - // This results in a total wait time of 6 seconds (1s + 2s + 3s) if the operation succeeds on the third attempt, + // This results in a total wait time of 10 seconds (1s + 2s + 3s + 4s) if the operation succeeds on the fifth attempt, // leading to an approximate execution time of 6 seconds. - assertTrue(time > 5500); - assertTrue(time < 6500); + assertTrue("Current time is: " + time, + time > 9000 && time < 11000); } @Test public void testWithExponentialBackOffRetriesWithSuccess() { - i=0; + i = 0; long start = System.currentTimeMillis(); int result = ExtendedUtil.withBackOffRetries( this::testFunction, - true, 0, // test backoffRetry default value -> 5 - false, - runEx -> { - if(!runEx.getMessage().contains("Expected")) - throw new RuntimeException("Some Bad News..."); - } + true, + 0, // test backoffRetry default value -> 5 + true, + runEx -> {} ); long time = System.currentTimeMillis() - start; @@ -54,21 +52,22 @@ public void testWithExponentialBackOffRetriesWithSuccess() { // The method will attempt to execute the operation with an exponential backoff strategy, // sleeping for 2 second, 4 seconds, and 8 seconds between retries. 
- // This results in a total wait time of 14 seconds (2s + 4s + 8s) if the operation succeeds on the third attempt, + // This results in a total wait time of 30 seconds (2s + 4s + 8s + 16s) if the operation succeeds on the fifth attempt, // leading to an approximate execution time of 14 seconds. - assertTrue(time > 13500); - assertTrue(time < 14500); + assertTrue("Current time is: " + time, + time > 29000 && time < 31000); } @Test public void testBackOffRetriesWithError() { - i=0; + i = 0; long start = System.currentTimeMillis(); assertThrows( RuntimeException.class, () -> ExtendedUtil.withBackOffRetries( this::testFunction, - true, 2, + true, + 2, false, runEx -> {} ) @@ -78,19 +77,20 @@ public void testBackOffRetriesWithError() { // The method is configured to retry the operation twice. // So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception. // Resulting in an approximate execution time of 3 seconds. - assertTrue(time > 2500); - assertTrue(time < 3500); + assertTrue("Current time is: " + time, + time > 2000 && time < 4000); } @Test public void testBackOffRetriesWithErrorAndExponential() { - i=0; + i = 0; long start = System.currentTimeMillis(); assertThrows( RuntimeException.class, () -> ExtendedUtil.withBackOffRetries( this::testFunction, - true, 2, + true, + 2, true, runEx -> {} ) @@ -98,20 +98,21 @@ public void testBackOffRetriesWithErrorAndExponential() { long time = System.currentTimeMillis() - start; // The method is configured to retry the operation twice. - // So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception. - // Resulting in an approximate execution time of 3 seconds. - assertTrue(time > 2500); - assertTrue(time < 3500); + // So, it will make two extra-attempts, waiting for 2 seconds and 4 seconds before failing and throwing an exception. + // Resulting in an approximate execution time of 6 seconds. 
+ assertTrue("Current time is: " + time, + time > 5000 && time < 7000); } @Test public void testWithoutBackOffRetriesWithError() { - i=0; + i = 0; assertThrows( RuntimeException.class, () -> ExtendedUtil.withBackOffRetries( this::testFunction, - false, 30, false, + false, 30, + false, runEx -> {} ) ); @@ -121,11 +122,11 @@ public void testWithoutBackOffRetriesWithError() { } private int testFunction() { - i++; if (i == 4) { - throw new RuntimeException("Expected i not equal to 4"); + return i; } - return i; + i++; + throw new RuntimeException("Expected i not equal to 4"); } }