Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #4153: Handling OpenAI 429's gracefully #4284

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ If present, they take precedence over the analogous APOC configs.
| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response.
The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure.
| failOnError | If true (default), the procedure fails in case of empty, blank or null input
| enableBackOffRetries | If set to true, enables the backoff retry strategy for handling failures. (default: false)
| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5)
| exponentialBackoff | If set to true, applies an exponential progression to the wait time between retries. If set to false, the wait time increases linearly. (default: false)
|===


Expand Down
18 changes: 16 additions & 2 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import apoc.ApocConfig;
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.ExtendedUtil;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -36,6 +37,9 @@ public class OpenAI {
public static final String PATH_CONF_KEY = "path";
public static final String GPT_4O_MODEL = "gpt-4o";
public static final String FAIL_ON_ERROR_CONF = "failOnError";
public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries";

@Context
public ApocConfig apocConfig;
Expand All @@ -59,6 +63,9 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {

static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
boolean enableBackOffRetries = Util.toBoolean( configuration.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY) );
Integer backOffRetries = Util.toInteger(configuration.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5));
boolean exponentialBackoff = Util.toBoolean( configuration.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY) );
if (apiKey == null || apiKey.isBlank())
throw new IllegalArgumentException("API Key must not be empty");

Expand All @@ -78,7 +85,7 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
path = (String) configuration.getOrDefault(PATH_CONF_KEY, path);
OpenAIRequestHandler apiType = type.get();

jsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
String sJsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
headers.put("Content-Type", "application/json");
apiType.addApiKey(headers, apiKey);

Expand All @@ -88,7 +95,14 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
// therefore is better to join the not-empty path pieces
var url = apiType.getFullUrl(path, configuration, apocConfig);
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker);
return ExtendedUtil.withBackOffRetries(
() -> JsonUtil.loadJson(url, headers, payload, sJsonPath, true, List.of(), urlAccessChecker),
enableBackOffRetries, backOffRetries, exponentialBackoff,
exception -> {
if(!exception.getMessage().contains("429"))
throw new RuntimeException(exception);
}
);
}

private static void handleAPIProvider(OpenAIRequestHandler.Type type,
Expand Down
67 changes: 58 additions & 9 deletions extended/src/main/java/apoc/util/ExtendedUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,9 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.temporal.TemporalAccessor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -353,5 +348,59 @@ public static float[] listOfNumbersToFloatArray(List<? extends Number> embedding
}
return floats;
}


public static <T> T withBackOffRetries(
Supplier<T> func,
boolean retry,
int backoffRetry,
boolean exponential,
Consumer<Exception> exceptionHandler
) {
T result;
backoffRetry = backoffRetry < 1
? 5
: backoffRetry;
int countDown = backoffRetry;
exceptionHandler = Objects.requireNonNullElse(exceptionHandler, exe -> {});
while (true) {
try {
result = func.get();
break;
} catch (Exception e) {
if(!retry || countDown < 1) throw e;
exceptionHandler.accept(e);
countDown--;
long delay = getDelay(backoffRetry, countDown, exponential);
backoffSleep(delay);
}
}
return result;
}

private static void backoffSleep(long millis){
sleep(millis, "Operation interrupted during backoff");
}

public static void sleep(long millis, String interruptedMessage) {
try {
Thread.sleep(millis);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new RuntimeException(interruptedMessage, ie);
}
}

private static long getDelay(int backoffRetry, int countDown, boolean exponential) {
int backOffTime = backoffRetry - countDown;
long sleepMultiplier = exponential ?
(long) Math.pow(2, backOffTime) : // Exponential retry progression
backOffTime; // Linear retry progression
return Math.min(
Duration.ofSeconds(1)
.multipliedBy(sleepMultiplier)
.toMillis(),
Duration.ofSeconds(30).toMillis() // Max 30s
);
}

}
132 changes: 132 additions & 0 deletions extended/src/test/java/apoc/util/ExtendedUtilTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package apoc.util;

import org.junit.Test;

import static org.junit.Assert.*;
import static org.junit.Assert.assertTrue;

public class ExtendedUtilTest {

private static int i = 0;

@Test
public void testWithLinearBackOffRetriesWithSuccess() {
i = 0;
long start = System.currentTimeMillis();
int result = ExtendedUtil.withBackOffRetries(
this::testFunction,
true,
-1, // test backoffRetry default value -> 5
false,
runEx -> {
if(!runEx.getMessage().contains("Expected"))
throw new RuntimeException("Some Bad News...");
}
);
long time = System.currentTimeMillis() - start;

assertEquals(4, result);

// The method will attempt to execute the operation with a linear backoff strategy,
// sleeping for 1 second, 2 seconds, and 3 seconds between retries.
// This results in a total wait time of 6 seconds (1s + 2s + 3s + 4s) if the operation succeeds on the third attempt,
// leading to an approximate execution time of 6 seconds.
assertTrue("Current time is: " + time,
time > 9000 && time < 11000);
}

@Test
public void testWithExponentialBackOffRetriesWithSuccess() {
i = 0;
long start = System.currentTimeMillis();
int result = ExtendedUtil.withBackOffRetries(
this::testFunction,
true,
0, // test backoffRetry default value -> 5
true,
runEx -> {}
);
long time = System.currentTimeMillis() - start;

assertEquals(4, result);

// The method will attempt to execute the operation with an exponential backoff strategy,
// sleeping for 2 second, 4 seconds, and 8 seconds between retries.
// This results in a total wait time of 30 seconds (2s + 4s + 8s + 16s) if the operation succeeds on the third attempt,
// leading to an approximate execution time of 14 seconds.
assertTrue("Current time is: " + time,
time > 29000 && time < 31000);
}

@Test
public void testBackOffRetriesWithError() {
i = 0;
long start = System.currentTimeMillis();
assertThrows(
RuntimeException.class,
() -> ExtendedUtil.withBackOffRetries(
this::testFunction,
true,
2,
false,
runEx -> {}
)
);
long time = System.currentTimeMillis() - start;

// The method is configured to retry the operation twice.
// So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception.
// Resulting in an approximate execution time of 3 seconds.
assertTrue("Current time is: " + time,
time > 2000 && time < 4000);
}

@Test
public void testBackOffRetriesWithErrorAndExponential() {
i = 0;
long start = System.currentTimeMillis();
assertThrows(
RuntimeException.class,
() -> ExtendedUtil.withBackOffRetries(
this::testFunction,
true,
2,
true,
runEx -> {}
)
);
long time = System.currentTimeMillis() - start;

// The method is configured to retry the operation twice.
// So, it will make two extra-attempts, waiting for 2 second and 4 seconds before failing and throwing an exception.
// Resulting in an approximate execution time of 6 seconds.
assertTrue("Current time is: " + time,
time > 5000 && time < 7000);
}

@Test
public void testWithoutBackOffRetriesWithError() {
i = 0;
assertThrows(
RuntimeException.class,
() -> ExtendedUtil.withBackOffRetries(
this::testFunction,
false, 30,
false,
runEx -> {}
)
);

// Retry strategy is not active and the testFunction is executed only once by raising an exception.
assertEquals(1, i);
}

private int testFunction() {
if (i == 4) {
return i;
}
i++;
throw new RuntimeException("Expected i not equal to 4");
}

}
Loading