diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index f7ff49dd5a..5592ca59f7 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -28,6 +28,9 @@ If present, they take precedence over the analogous APOC configs. | endpoint | analogous to `apoc.ml.openai.url` APOC config | apiVersion | analogous to `apoc.ml.azure.api.version` APOC config | failOnError | If true (default), the procedure fails in case of empty, blank or null input +| enableBackOffRetries | If set to true, enables the backoff retry strategy for handling failures. (default: false) +| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5) +| exponentialBackoff | If set to true, applies an exponential progression to the wait time between retries. If set to false, the wait time increases linearly. (default: false) |=== diff --git a/full/src/main/java/apoc/ml/OpenAI.java b/full/src/main/java/apoc/ml/OpenAI.java index 8a585cbddc..28f554e940 100644 --- a/full/src/main/java/apoc/ml/OpenAI.java +++ b/full/src/main/java/apoc/ml/OpenAI.java @@ -12,6 +12,7 @@ import apoc.ApocConfig; import apoc.Extended; import apoc.result.MapResult; +import apoc.util.ExtendedUtil; import apoc.util.JsonUtil; import apoc.util.Util; import com.fasterxml.jackson.core.JsonProcessingException; @@ -35,7 +36,11 @@ @Extended public class OpenAI { + public static final String JSON_PATH_CONF_KEY = "jsonPath"; public static final String FAIL_ON_ERROR_CONF = "failOnError"; + public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries"; + public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff"; + public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries"; @Context public ApocConfig apocConfig; @@ -63,6 +68,9 @@ static Stream executeRequest( ApocConfig apocConfig) throws JsonProcessingException, 
/**
 * Utility methods for running an operation with an optional back-off retry strategy.
 */
public class ExtendedUtil {

    /** Number of retries used when the caller passes a value lower than 1. */
    private static final int DEFAULT_BACKOFF_RETRIES = 5;

    /** Upper bound, in seconds, for a single wait between two attempts. */
    private static final long MAX_DELAY_SECONDS = 30;

    /**
     * Invokes {@code func} and, when {@code retry} is enabled, re-invokes it after a
     * growing pause each time it throws, up to {@code backoffRetry} additional attempts.
     *
     * @param func the operation to run; must not be {@code null}
     * @param retry when {@code false} the first exception is rethrown immediately
     * @param backoffRetry maximum number of retries; values lower than 1 fall back to 5
     * @param exponential {@code true} to wait 2, 4, 8, ... seconds between attempts,
     *     {@code false} to wait 1, 2, 3, ... seconds; each wait is capped at 30 seconds
     * @param exceptionHandler called with every exception that is about to be retried;
     *     it may throw to abort the retry loop. {@code null} means "no handler"
     * @param <T> the type produced by {@code func}
     * @return the value returned by the first successful invocation of {@code func}
     * @throws NullPointerException if {@code func} is {@code null}
     * @throws RuntimeException the last exception thrown by {@code func} once retries are
     *     disabled or exhausted, whatever {@code exceptionHandler} throws, or a wrapper
     *     around {@link InterruptedException} when the thread is interrupted while waiting
     */
    public static <T> T withBackOffRetries(
            Supplier<T> func,
            boolean retry,
            int backoffRetry,
            boolean exponential,
            Consumer<Exception> exceptionHandler) {
        // Fail fast: a null supplier is a programming error, not a transient failure.
        // Without this check the resulting NPE would be caught below and retried.
        Objects.requireNonNull(func, "func");

        final int maxRetries = backoffRetry < 1 ? DEFAULT_BACKOFF_RETRIES : backoffRetry;
        final Consumer<Exception> handler = Objects.requireNonNullElse(exceptionHandler, ignored -> {});

        int countDown = maxRetries;
        while (true) {
            try {
                return func.get();
            } catch (Exception e) {
                if (!retry || countDown < 1) {
                    throw e;
                }
                // The handler decides whether the failure is retryable: throwing aborts the loop.
                handler.accept(e);
                countDown--;
                backoffSleep(getDelay(maxRetries, countDown, exponential));
            }
        }
    }

    /** Sleeps for {@code millis} milliseconds using the back-off specific interruption message. */
    private static void backoffSleep(long millis) {
        sleep(millis, "Operation interrupted during backoff");
    }

    /**
     * Sleeps for {@code millis} milliseconds.
     *
     * @param millis how long to sleep, in milliseconds
     * @param interruptedMessage message of the exception raised on interruption
     * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is
     *     interrupted; the interrupt flag is set again before throwing
     */
    public static void sleep(long millis, String interruptedMessage) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(interruptedMessage, ie);
        }
    }

    /**
     * Computes the wait, in milliseconds, that precedes the next attempt.
     *
     * @param backoffRetry the effective maximum number of retries
     * @param countDown the number of retries still available after the current one
     * @param exponential {@code true} for a 2^n progression, {@code false} for a linear one
     * @return the delay in milliseconds, never more than 30 seconds
     */
    private static long getDelay(int backoffRetry, int countDown, boolean exponential) {
        int backOffTime = backoffRetry - countDown;
        long sleepMultiplier = exponential
                ? (long) Math.pow(2, backOffTime) // Exponential retry progression
                : backOffTime; // Linear retry progression
        // Clamp the multiplier BEFORE converting to milliseconds: Duration.toMillis() throws
        // ArithmeticException on overflow, which a large exponential multiplier would trigger.
        long seconds = Math.min(sleepMultiplier, MAX_DELAY_SECONDS);
        return Duration.ofSeconds(seconds).toMillis();
    }
}
+ assertTrue("Current time is: " + time, time > 9000 && time < 11000); + } + + @Test + public void testWithExponentialBackOffRetriesWithSuccess() { + i = 0; + long start = System.currentTimeMillis(); + int result = ExtendedUtil.withBackOffRetries( + this::testFunction, + true, + 0, // test backoffRetry default value -> 5 + true, + runEx -> {}); + long time = System.currentTimeMillis() - start; + + assertEquals(4, result); + + // The method will attempt to execute the operation with an exponential backoff strategy, + // sleeping for 2 seconds, 4 seconds, 8 seconds, and 16 seconds between retries. + // This results in a total wait time of 30 seconds (2s + 4s + 8s + 16s) when the operation succeeds on the fifth + // attempt, + // leading to an approximate execution time of 30 seconds. + assertTrue("Current time is: " + time, time > 29000 && time < 31000); + } + + @Test + public void testBackOffRetriesWithError() { + i = 0; + long start = System.currentTimeMillis(); + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries(this::testFunction, true, 2, false, runEx -> {})); + long time = System.currentTimeMillis() - start; + + // The method is configured to retry the operation twice. + // So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an + // exception. + // Resulting in an approximate execution time of 3 seconds. + assertTrue("Current time is: " + time, time > 2000 && time < 4000); + } + + @Test + public void testBackOffRetriesWithErrorAndExponential() { + i = 0; + long start = System.currentTimeMillis(); + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries(this::testFunction, true, 2, true, runEx -> {})); + long time = System.currentTimeMillis() - start; + + // The method is configured to retry the operation twice. + // So, it will make two extra-attempts, waiting for 2 seconds and 4 seconds before failing and throwing an + // exception. 
+ // Resulting in an approximate execution time of 6 seconds. + assertTrue("Current time is: " + time, time > 5000 && time < 7000); + } + + @Test + public void testWithoutBackOffRetriesWithError() { + i = 0; + assertThrows( + RuntimeException.class, + () -> ExtendedUtil.withBackOffRetries(this::testFunction, false, 30, false, runEx -> {})); + + // Retry strategy is not active and the testFunction is executed only once by raising an exception. + assertEquals(1, i); + } + + private int testFunction() { + if (i == 4) { + return i; + } + i++; + throw new RuntimeException("Expected i not equal to 4"); + } +}