diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index b02fd7d4d7..4b54fc4be4 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -37,6 +37,9 @@ If present, they take precedence over the analogous APOC configs. | jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response. The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure. | failOnError | If true (default), the procedure fails in case of empty, blank or null input +| enableBackOffRetries | If set to true, enables the backoff retry strategy for handling failures. (default: false) +| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5) +| exponentialBackoff | If set to true, applies an exponential progression to the wait time between retries. If set to false, the wait time increases linearly. 
(default: false) |=== diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index 29a54a7f7d..4bc4814b33 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -3,6 +3,7 @@ import apoc.ApocConfig; import apoc.Extended; import apoc.result.MapResult; +import apoc.util.ExtendedUtil; import apoc.util.JsonUtil; import apoc.util.Util; import com.fasterxml.jackson.core.JsonProcessingException; @@ -36,6 +37,9 @@ public class OpenAI { public static final String PATH_CONF_KEY = "path"; public static final String GPT_4O_MODEL = "gpt-4o"; public static final String FAIL_ON_ERROR_CONF = "failOnError"; + public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries"; + public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff"; + public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries"; @Context public ApocConfig apocConfig; @@ -59,6 +63,9 @@ public EmbeddingResult(long index, String text, List embedding) { static Stream executeRequest(String apiKey, Map configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException { apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey)); + boolean enableBackOffRetries = Util.toBoolean( configuration.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY) ); + Integer backOffRetries = Util.toInteger(configuration.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5)); + boolean exponentialBackoff = Util.toBoolean( configuration.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY) ); if (apiKey == null || apiKey.isBlank()) throw new IllegalArgumentException("API Key must not be empty"); @@ -78,7 +85,7 @@ static Stream executeRequest(String apiKey, Map configur path = (String) configuration.getOrDefault(PATH_CONF_KEY, path); 
OpenAIRequestHandler apiType = type.get(); - jsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath); + String sJsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath); headers.put("Content-Type", "application/json"); apiType.addApiKey(headers, apiKey); @@ -88,7 +95,14 @@ static Stream executeRequest(String apiKey, Map configur // eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model // therefore is better to join the not-empty path pieces var url = apiType.getFullUrl(path, configuration, apocConfig); - return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker); + return ExtendedUtil.withBackOffRetries( + () -> JsonUtil.loadJson(url, headers, payload, sJsonPath, true, List.of(), urlAccessChecker), + enableBackOffRetries, backOffRetries, exponentialBackoff, + exception -> { + if(!exception.getMessage().contains("429")) + throw new RuntimeException(exception); + } + ); } private static void handleAPIProvider(OpenAIRequestHandler.Type type, diff --git a/extended/src/main/java/apoc/util/ExtendedUtil.java b/extended/src/main/java/apoc/util/ExtendedUtil.java index f7fa8a5c9d..596e086c75 100644 --- a/extended/src/main/java/apoc/util/ExtendedUtil.java +++ b/extended/src/main/java/apoc/util/ExtendedUtil.java @@ -30,14 +30,9 @@ import java.time.ZoneId; import java.time.ZonedDateTime; import java.time.temporal.TemporalAccessor; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; +import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.LongStream; import java.util.stream.Stream; @@ -353,5 +348,59 @@ public static float[] listOfNumbersToFloatArray(List embedding } return floats; } - + + public static T withBackOffRetries( + 
Supplier func, + boolean retry, + int backoffRetry, + boolean exponential, + Consumer exceptionHandler + ) { + T result; + backoffRetry = backoffRetry < 1 + ? 5 + : backoffRetry; + int countDown = backoffRetry; + exceptionHandler = Objects.requireNonNullElse(exceptionHandler, exe -> {}); + while (true) { + try { + result = func.get(); + break; + } catch (Exception e) { + if(!retry || countDown < 1) throw e; + exceptionHandler.accept(e); + countDown--; + long delay = getDelay(backoffRetry, countDown, exponential); + backoffSleep(delay); + } + } + return result; + } + + private static void backoffSleep(long millis){ + sleep(millis, "Operation interrupted during backoff"); + } + + public static void sleep(long millis, String interruptedMessage) { + try { + Thread.sleep(millis); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException(interruptedMessage, ie); + } + } + + private static long getDelay(int backoffRetry, int countDown, boolean exponential) { + int backOffTime = backoffRetry - countDown; + long sleepMultiplier = exponential ? 
+ (long) Math.pow(2, backOffTime) : // Exponential retry progression + backOffTime; // Linear retry progression + return Math.min( + Duration.ofSeconds(1) + .multipliedBy(sleepMultiplier) + .toMillis(), + Duration.ofSeconds(30).toMillis() // Max 30s + ); + } + } diff --git a/extended/src/test/java/apoc/util/ExtendedUtilTest.java b/extended/src/test/java/apoc/util/ExtendedUtilTest.java new file mode 100644 index 0000000000..05fe36f264 --- /dev/null +++ b/extended/src/test/java/apoc/util/ExtendedUtilTest.java @@ -0,0 +1,132 @@ +package apoc.util; + +import org.junit.Test; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertTrue; + +public class ExtendedUtilTest { + + private static int i = 0; + + @Test + public void testWithLinearBackOffRetriesWithSuccess() { + i = 0; + long start = System.currentTimeMillis(); + int result = ExtendedUtil.withBackOffRetries( + this::testFunction, + true, + -1, // test backoffRetry default value -> 5 + false, + runEx -> { + if(!runEx.getMessage().contains("Expected")) + throw new RuntimeException("Some Bad News..."); + } + ); + long time = System.currentTimeMillis() - start; + + assertEquals(4, result); + + // The method will attempt to execute the operation with a linear backoff strategy, + // sleeping for 1 second, 2 seconds, 3 seconds, and 4 seconds between retries. + // This results in a total wait time of 10 seconds (1s + 2s + 3s + 4s) if the operation succeeds on the fifth attempt, + // leading to an approximate execution time of 10 seconds. 
+ assertTrue("Current time is: " + time, + time > 9000 && time < 11000); + } + + @Test + public void testWithExponentialBackOffRetriesWithSuccess() { + i = 0; + long start = System.currentTimeMillis(); + int result = ExtendedUtil.withBackOffRetries( + this::testFunction, + true, + 0, // test backoffRetry default value -> 5 + true, + runEx -> {} + ); + long time = System.currentTimeMillis() - start; + + assertEquals(4, result); + + // The method will attempt to execute the operation with an exponential backoff strategy, + // sleeping for 2 seconds, 4 seconds, 8 seconds, and 16 seconds between retries. + // This results in a total wait time of 30 seconds (2s + 4s + 8s + 16s) if the operation succeeds on the fifth attempt, + // leading to an approximate execution time of 30 seconds. + assertTrue("Current time is: " + time, + time > 29000 && time < 31000); + } + + @Test + public void testBackOffRetriesWithError() { + i = 0; + long start = System.currentTimeMillis(); + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries( + this::testFunction, + true, + 2, + false, + runEx -> {} + ) + ); + long time = System.currentTimeMillis() - start; + + // The method is configured to retry the operation twice. + // So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception. + // Resulting in an approximate execution time of 3 seconds. + assertTrue("Current time is: " + time, + time > 2000 && time < 4000); + } + + @Test + public void testBackOffRetriesWithErrorAndExponential() { + i = 0; + long start = System.currentTimeMillis(); + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries( + this::testFunction, + true, + 2, + true, + runEx -> {} + ) + ); + long time = System.currentTimeMillis() - start; + + // The method is configured to retry the operation twice. + // So, it will make two extra-attempts, waiting for 2 seconds and 4 seconds before failing and throwing an exception. 
+ // Resulting in an approximate execution time of 6 seconds. + assertTrue("Current time is: " + time, + time > 5000 && time < 7000); + } + + @Test + public void testWithoutBackOffRetriesWithError() { + i = 0; + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries( + this::testFunction, + false, 30, + false, + runEx -> {} + ) + ); + + // Retry strategy is not active and the testFunction is executed only once by raising an exception. + assertEquals(1, i); + } + + private int testFunction() { + if (i == 4) { + return i; + } + i++; + throw new RuntimeException("Expected i not equal to 4"); + } + +}