Skip to content

Commit

Permalink
Fixes #4153: Handling OpenAI 429's gracefully (#4284) (#4301)
Browse files Browse the repository at this point in the history
* Fixes #4153: Handling OpenAI 429's gracefully

* cleanup

* fix tests
  • Loading branch information
vga91 authored Dec 11, 2024
1 parent 6f0ecc6 commit bf0dd2d
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 11 deletions.
3 changes: 3 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ If present, they take precedence over the analogous APOC configs.
| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response.
The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure.
| failOnError | If true (default), the procedure fails in case of empty, blank or null input
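| enableBackOffRetries | If set to true, enables the backoff retry strategy: a request that fails with an HTTP 429 (Too Many Requests) error is retried after a wait, while any other error fails immediately. (default: false)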
| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5)
| exponentialBackoff | If set to true, the wait time between retries grows exponentially (2s, 4s, 8s, ...). If set to false, the wait time grows linearly (1s, 2s, 3s, ...). In both cases a single wait is capped at 30 seconds. (default: false)
|===


Expand Down
18 changes: 16 additions & 2 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import apoc.ApocConfig;
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.ExtendedUtil;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -36,6 +37,9 @@ public class OpenAI {
public static final String PATH_CONF_KEY = "path";
public static final String GPT_4O_MODEL = "gpt-4o";
public static final String FAIL_ON_ERROR_CONF = "failOnError";
public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries";

@Context
public ApocConfig apocConfig;
Expand All @@ -59,6 +63,9 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {

static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
boolean enableBackOffRetries = Util.toBoolean( configuration.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY) );
Integer backOffRetries = Util.toInteger(configuration.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5));
boolean exponentialBackoff = Util.toBoolean( configuration.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY) );
if (apiKey == null || apiKey.isBlank())
throw new IllegalArgumentException("API Key must not be empty");

Expand All @@ -78,7 +85,7 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
path = (String) configuration.getOrDefault(PATH_CONF_KEY, path);
OpenAIRequestHandler apiType = type.get();

jsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
String sJsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
headers.put("Content-Type", "application/json");
apiType.addApiKey(headers, apiKey);

Expand All @@ -88,7 +95,14 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
// therefore is better to join the not-empty path pieces
var url = apiType.getFullUrl(path, configuration, apocConfig);
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker);
return ExtendedUtil.withBackOffRetries(
() -> JsonUtil.loadJson(url, headers, payload, sJsonPath, true, List.of(), urlAccessChecker),
enableBackOffRetries, backOffRetries, exponentialBackoff,
exception -> {
if(!exception.getMessage().contains("429"))
throw new RuntimeException(exception);
}
);
}

private static void handleAPIProvider(OpenAIRequestHandler.Type type,
Expand Down
67 changes: 58 additions & 9 deletions extended/src/main/java/apoc/util/ExtendedUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,9 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.temporal.TemporalAccessor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -353,5 +348,59 @@ public static float[] listOfNumbersToFloatArray(List<? extends Number> embedding
}
return floats;
}


/**
 * Runs {@code func} and, when retrying is enabled, re-runs it after a growing pause
 * each time it throws.
 *
 * @param func             the operation to execute
 * @param retry            when {@code false} the first failure is rethrown immediately
 * @param backoffRetry     maximum number of retries; values below 1 fall back to 5
 * @param exponential      {@code true} for an exponential pause progression, {@code false} for a linear one
 * @param exceptionHandler invoked with every failure that is about to be retried; it may throw
 *                         to abort the retry loop. A {@code null} handler is treated as a no-op
 * @param <T>              the result type of {@code func}
 * @return the value produced by the first successful call of {@code func}
 */
public static <T> T withBackOffRetries(
        Supplier<T> func,
        boolean retry,
        int backoffRetry,
        boolean exponential,
        Consumer<Exception> exceptionHandler
) {
    final int maxRetries = backoffRetry >= 1 ? backoffRetry : 5;
    final Consumer<Exception> onFailure = exceptionHandler != null
            ? exceptionHandler
            : ignored -> {};

    int remaining = maxRetries;
    while (true) {
        try {
            return func.get();
        } catch (Exception failure) {
            // Give up when retrying is disabled or the retry budget is used up.
            if (!retry || remaining < 1) {
                throw failure;
            }
            // The handler decides whether this failure is retryable (it throws if not).
            onFailure.accept(failure);
            remaining--;
            backoffSleep(getDelay(maxRetries, remaining, exponential));
        }
    }
}

/**
 * Pauses the current thread between two retry attempts.
 *
 * @param millis how long to pause, in milliseconds
 */
private static void backoffSleep(long millis) {
    final String interruptedMessage = "Operation interrupted during backoff";
    sleep(millis, interruptedMessage);
}

/**
 * Sleeps for the given time, converting an interruption into an unchecked exception.
 *
 * @param millis             how long to sleep, in milliseconds
 * @param interruptedMessage message of the exception thrown when the sleep is interrupted
 * @throws RuntimeException if the thread is interrupted while sleeping; the interrupt flag
 *                          is set again before throwing and the original
 *                          {@link InterruptedException} is kept as the cause
 */
public static void sleep(long millis, String interruptedMessage) {
    final InterruptedException interruption;
    try {
        Thread.sleep(millis);
        return;
    } catch (InterruptedException caught) {
        interruption = caught;
    }
    // Re-assert the interrupt so code further up the stack can still observe it.
    Thread.currentThread().interrupt();
    throw new RuntimeException(interruptedMessage, interruption);
}

/**
 * Computes how long to wait before the next retry.
 *
 * <p>The attempt index is {@code backoffRetry - countDown}. The wait is that many seconds
 * (linear) or {@code 2^index} seconds (exponential), capped at 30 seconds.
 *
 * <p>The cap is applied to the multiplier <em>before</em> converting to milliseconds.
 * Capping afterwards made large exponential indexes (54 and above) overflow inside
 * {@link Duration#toMillis()} and throw an {@link ArithmeticException} instead of
 * returning the 30-second maximum.
 *
 * @param backoffRetry the configured maximum number of retries
 * @param countDown    the number of retries still available
 * @param exponential  {@code true} for exponential progression, {@code false} for linear
 * @return the delay in milliseconds, never more than 30000
 */
private static long getDelay(int backoffRetry, int countDown, boolean exponential) {
    final long maxDelaySeconds = 30L;
    int backOffTime = backoffRetry - countDown;
    long sleepMultiplier = exponential
            ? (long) Math.pow(2, backOffTime) // Exponential retry progression (saturates at Long.MAX_VALUE)
            : backOffTime;                    // Linear retry progression
    // Clamp first, so the conversion below can never overflow.
    long delaySeconds = Math.min(sleepMultiplier, maxDelaySeconds);
    return Duration.ofSeconds(delaySeconds).toMillis();
}

}
132 changes: 132 additions & 0 deletions extended/src/test/java/apoc/util/ExtendedUtilTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package apoc.util;

import org.junit.Test;

import static org.junit.Assert.*;
import static org.junit.Assert.assertTrue;

/**
 * Timing-based tests for {@link ExtendedUtil#withBackOffRetries}.
 *
 * Every test resets the shared counter {@code i} and uses {@link #testFunction()},
 * which throws on its first four calls and returns 4 on the fifth.
 */
public class ExtendedUtilTest {

// Number of failed calls made by testFunction(); reset at the start of every test.
private static int i = 0;

// Linear backoff with the default retry budget: succeeds on the fifth attempt.
@Test
public void testWithLinearBackOffRetriesWithSuccess() {
i = 0;
long start = System.currentTimeMillis();
int result = ExtendedUtil.withBackOffRetries(
this::testFunction,
true,
-1, // test backoffRetry default value -> 5
false,
runEx -> {
if(!runEx.getMessage().contains("Expected"))
throw new RuntimeException("Some Bad News...");
}
);
long time = System.currentTimeMillis() - start;

assertEquals(4, result);

// The operation is retried with a linear backoff strategy,
// sleeping for 1, 2, 3 and 4 seconds after the four failed attempts.
// It succeeds on the fifth attempt, so the total wait time is 10 seconds (1s + 2s + 3s + 4s),
// leading to an approximate execution time of 10 seconds.
assertTrue("Current time is: " + time,
time > 9000 && time < 11000);
}

// Exponential backoff with the default retry budget: succeeds on the fifth attempt.
@Test
public void testWithExponentialBackOffRetriesWithSuccess() {
i = 0;
long start = System.currentTimeMillis();
int result = ExtendedUtil.withBackOffRetries(
this::testFunction,
true,
0, // test backoffRetry default value -> 5
true,
runEx -> {}
);
long time = System.currentTimeMillis() - start;

assertEquals(4, result);

// The operation is retried with an exponential backoff strategy,
// sleeping for 2, 4, 8 and 16 seconds after the four failed attempts.
// It succeeds on the fifth attempt, so the total wait time is 30 seconds (2s + 4s + 8s + 16s),
// leading to an approximate execution time of 30 seconds.
assertTrue("Current time is: " + time,
time > 29000 && time < 31000);
}

// Linear backoff with only two retries: the budget runs out before testFunction() can succeed.
@Test
public void testBackOffRetriesWithError() {
i = 0;
long start = System.currentTimeMillis();
assertThrows(
RuntimeException.class,
() -> ExtendedUtil.withBackOffRetries(
this::testFunction,
true,
2,
false,
runEx -> {}
)
);
long time = System.currentTimeMillis() - start;

// The method is configured to retry the operation twice.
// So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception.
// Resulting in an approximate execution time of 3 seconds.
assertTrue("Current time is: " + time,
time > 2000 && time < 4000);
}

// Exponential backoff with only two retries: the budget runs out before testFunction() can succeed.
@Test
public void testBackOffRetriesWithErrorAndExponential() {
i = 0;
long start = System.currentTimeMillis();
assertThrows(
RuntimeException.class,
() -> ExtendedUtil.withBackOffRetries(
this::testFunction,
true,
2,
true,
runEx -> {}
)
);
long time = System.currentTimeMillis() - start;

// The method is configured to retry the operation twice.
// So, it will make two extra-attempts, waiting for 2 second and 4 seconds before failing and throwing an exception.
// Resulting in an approximate execution time of 6 seconds.
assertTrue("Current time is: " + time,
time > 5000 && time < 7000);
}

// Retries disabled: the configured retry count (30) must be ignored.
@Test
public void testWithoutBackOffRetriesWithError() {
i = 0;
assertThrows(
RuntimeException.class,
() -> ExtendedUtil.withBackOffRetries(
this::testFunction,
false, 30,
false,
runEx -> {}
)
);

// Retry strategy is not active and the testFunction is executed only once by raising an exception.
assertEquals(1, i);
}

// Fails (and increments i) until i reaches 4, then returns 4:
// four failing calls followed by a successful fifth one.
private int testFunction() {
if (i == 4) {
return i;
}
i++;
throw new RuntimeException("Expected i not equal to 4");
}

}

0 comments on commit bf0dd2d

Please sign in to comment.