diff --git a/DESCRIPTION b/DESCRIPTION index f2935cd4..d07f00f2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -54,6 +54,7 @@ Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 Collate: + 'BatchChat.R' 'utils-S7.R' 'types.R' 'content.R' diff --git a/NEWS.md b/NEWS.md index ed38b01d..e0145d02 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,10 @@ # ellmer (development version) +* New `$chat_batch()` and `$extract_data_batch()` make it possible to use the + "batch" API provided by Claude, OpenAI, and Gemini (#143). Batch requests are + typically 50% cheaper than regular requests but can take up to 24 hours to + complete. + * New `$chat_parallel()` and `$extract_data_parallel()` make it easier to perform multiple actions in parallel (#143). For Claude, note that the number of active connections is limited primarily by the output tokens per limit diff --git a/R/BatchChat.R b/R/BatchChat.R new file mode 100644 index 00000000..c00130a0 --- /dev/null +++ b/R/BatchChat.R @@ -0,0 +1,35 @@ +batch_wait <- function(provider, batch) { + info <- batch_info(provider, batch) + cli::cli_progress_bar( + format = paste0( + "{cli::pb_spin} Processing... 
{info$counts$processing} -> {cli::col_green({info$counts$succeeded})} / {cli::col_red({info$counts$failed})} ", + "[{cli::pb_elapsed}]" + ), + clear = FALSE + ) + tryCatch( + { + while (info$working) { + Sys.sleep(1) + cli::cli_progress_update() + batch <- batch_poll(provider, batch) + info <- batch_info(provider, batch) + } + }, + interrupt = function(cnd) {} + ) + + batch +} + + +check_has_batch_support <- function(provider, call = caller_env()) { + if (has_batch_support(provider)) { + return() + } + + cli::cli_abort( + "Batch requests are not currently supported by this provider.", + call = call + ) +} diff --git a/R/chat.R b/R/chat.R index 2db3089b..729378aa 100644 --- a/R/chat.R +++ b/R/chat.R @@ -173,12 +173,35 @@ Chat <- R6::R6Class("Chat", map2(json, turns[ok], function(json, user_turn) { chat <- self$clone() - turn <- value_turn(private$provider, json) - chat$add_turn(user_turn, turn) + ai_turn <- value_turn(private$provider, json) + chat$add_turn(user_turn, ai_turn) chat }) }, + #' @description Submit multiple prompts as a single job using the + #' provider's batch API and wait for it to finish. Returns a list of + #' [Chat] objects, one for each prompt whose request succeeded (failed + #' requests are dropped). + #' @param prompts A list of user prompts. + chat_batch = function(prompts) { + check_has_batch_support(private$provider) + + turns <- as_user_turns(prompts) + new_turns <- map(turns, function(new_turn) c(private$.turns, list(new_turn))) + + batch <- batch_submit(private$provider, new_turns) + batch <- batch_wait(private$provider, batch) + results <- batch_retrieve(private$provider, batch) + + ok <- map_lgl(results, function(x) batch_result_ok(private$provider, x)) + + map2(results[ok], turns[ok], function(result, user_turn) { + ai_turn <- batch_result_turn(private$provider, result) + self$clone()$add_turn(user_turn, ai_turn) + }) + }, + #' @description Extract structured data #' @param ... The input to send to the chatbot. 
Will typically include #' the phrase "extract structured data". @@ -254,6 +277,34 @@ Chat <- R6::R6Class("Chat", }) }, + #' @description Extract structured data from multiple prompts using the + #' provider's batch API. Returns a list with one element for each prompt + #' whose request succeeded. + #' @param prompts A list of user prompts. + #' @param type A type specification for the extracted data. + #' @param convert A single `TRUE` or `FALSE`, passed on to `extract_data()`. + extract_data_batch = function(prompts, type, convert = TRUE) { + check_has_batch_support(private$provider) + turns <- as_user_turns(prompts) + check_bool(convert) + + needs_wrapper <- S7_inherits(private$provider, ProviderOpenAI) + if (needs_wrapper) { + type <- type_object(wrapper = type) + } + + new_turns <- map(turns, function(new_turn) c(private$.turns, list(new_turn))) + batch <- batch_submit(private$provider, new_turns, type = type) + batch <- batch_wait(private$provider, batch) + results <- batch_retrieve(private$provider, batch) + + ok <- map_lgl(results, function(x) batch_result_ok(private$provider, x)) + map2(results[ok], turns[ok], function(result, user_turn) { + turn <- batch_result_turn(private$provider, result, has_type = TRUE) + extract_data(turn, type, convert = convert, needs_wrapper = needs_wrapper) + }) + }, + #' @description Extract structured data, asynchronously. Returns a promise #' that resolves to an object matching the type specification. #' @param ... The input to send to the chatbot. 
Will typically include diff --git a/R/provider-claude.R b/R/provider-claude.R index 59aa692c..e27f8395 100644 --- a/R/provider-claude.R +++ b/R/provider-claude.R @@ -72,15 +72,8 @@ anthropic_key_exists <- function() { key_exists("ANTHROPIC_API_KEY") } -method(chat_request, ProviderClaude) <- function(provider, - stream = TRUE, - turns = list(), - tools = list(), - type = NULL) { - +method(base_request, ProviderClaude) <- function(provider) { req <- request(provider@base_url) - # https://docs.anthropic.com/en/api/messages - req <- req_url_path_append(req, "/messages") # req <- req_headers(req, `anthropic-version` = "2023-06-01") # @@ -102,6 +95,37 @@ method(chat_request, ProviderClaude) <- function(provider, } }) + req +} + +# Chat ------------------------------------------------------------------------ + +method(chat_request, ProviderClaude) <- function(provider, + stream = TRUE, + turns = list(), + tools = list(), + type = NULL) { + req <- base_request(provider) + # https://docs.anthropic.com/en/api/messages + req <- req_url_path_append(req, "/messages") + + body <- chat_body( + provider, + stream = stream, + turns = turns, + tools = tools, + type = type + ) + req <- req_body_json(req, body) + req +} + +method(chat_body, ProviderClaude) <- function(provider, + stream = TRUE, + turns = list(), + tools = list(), + type = NULL) { + if (length(turns) >= 1 && is_system_prompt(turns[[1]])) { system <- turns[[1]]@text } else { @@ -134,10 +158,82 @@ method(chat_request, ProviderClaude) <- function(provider, tools = tools, tool_choice = tool_choice, )) - body <- modify_list(body, provider@extra_args) - req <- req_body_json(req, body) + modify_list(body, provider@extra_args) +} - req +# Batch chat ------------------------------------------------------------------- + +method(has_batch_support, ProviderClaude) <- function(provider) { + TRUE +} + +# https://docs.anthropic.com/en/api/creating-message-batches +method(batch_submit, ProviderClaude) <- function(provider, turns, 
type = NULL) { + req <- base_request(provider) + req <- req_url_path_append(req, "/messages/batches") + + requests <- map(seq_along(turns), function(i) { + params <- chat_body( + provider, + stream = FALSE, + turns = turns[[i]], + type = type + ) + list( + custom_id = paste0("chat-", i), + params = params + ) + }) + req <- req_body_json(req, list(requests = requests)) + + resp <- req_perform(req) + resp_body_json(resp) +} + +# https://docs.anthropic.com/en/api/retrieving-message-batches +method(batch_poll, ProviderClaude) <- function(provider, batch) { + req <- base_request(provider) + req <- req_url_path_append(req, "/messages/batches/", batch$id) + resp <- req_perform(req) + resp_body_json(resp) +} + +method(batch_info, ProviderClaude) <- function(provider, batch) { + counts <- batch$request_counts + + list( + working = batch$processing_status != "ended", + counts = list( + processing = counts$processing, + succeeded = counts$succeeded, + failed = counts$errored + counts$canceled + counts$expired + ) + ) +} + +# https://docs.anthropic.com/en/api/retrieving-message-batch-results +method(batch_retrieve, ProviderClaude) <- function(provider, batch) { + req <- base_request(provider) + req <- req_url(req, batch$results_url) + req <- req_progress(req, "down") + + path <- withr::local_tempfile() + req <- req_perform(req, path = path) + + lines <- readLines(path, warn = FALSE) + json <- lapply(lines, jsonlite::fromJSON, simplifyVector = FALSE) + + ids <- as.numeric(gsub("chat-", "", map_chr(json, "[[", "custom_id"))) + results <- lapply(json, "[[", "result") + results[order(ids)] +} + +method(batch_result_ok, ProviderClaude) <- function(provider, result) { + result$type == "succeeded" +} + +method(batch_result_turn, ProviderClaude) <- function(provider, result, has_type = FALSE) { + value_turn(provider, result$message, has_type = has_type) } # Claude -> ellmer -------------------------------------------------------------- diff --git a/R/provider-gemini-upload.R 
b/R/provider-gemini-upload.R index 1badfed1..20706e3d 100644 --- a/R/provider-gemini-upload.R +++ b/R/provider-gemini-upload.R @@ -100,12 +100,12 @@ gemini_upload_status <- function(uri, credentials) { } gemini_upload_wait <- function(status, credentials) { - cli::cli_progress_bar(format = "{cli::pb_spin} Processing [{cli::pb_elapsed}] ") + cli::cli_progress_bar(format = "{cli::pb_spin} Processing... [{cli::pb_elapsed}] ") while (status$state == "PROCESSING") { cli::cli_progress_update() status <- gemini_upload_status(status$uri, credentials) - Sys.sleep(0.5) + Sys.sleep(1) } if (status$state == "FAILED") { cli::cli_abort("Upload failed: {status$error$message}") diff --git a/R/provider-openai.R b/R/provider-openai.R index c703779d..cdd23446 100644 --- a/R/provider-openai.R +++ b/R/provider-openai.R @@ -95,15 +95,7 @@ openai_key <- function() { key_get("OPENAI_API_KEY") } -# https://platform.openai.com/docs/api-reference/chat/create -method(chat_request, ProviderOpenAI) <- function(provider, - stream = TRUE, - turns = list(), - tools = list(), - type = NULL) { - +method(base_request, ProviderOpenAI) <- function(provider) { req <- request(provider@base_url) - req <- req_url_path_append(req, "/chat/completions") req <- req_auth_bearer_token(req, provider@api_key) req <- req_retry(req, max_tries = 2) - req <- ellmer_req_timeout(req, stream) @@ -116,6 +108,37 @@ method(chat_request, ProviderOpenAI) <- function(provider, } }) + req +} + +# https://platform.openai.com/docs/api-reference/chat/create +method(chat_request, ProviderOpenAI) <- function(provider, + stream = TRUE, + turns = list(), + tools = list(), + type = NULL) { + # The timeout depends on `stream`, so set it here rather than in base_request() + req <- base_request(provider) + req <- req_url_path_append(req, "/chat/completions") + req <- ellmer_req_timeout(req, stream) + + body <- chat_body(provider, + stream = stream, + turns = turns, + tools = tools, + type = type + ) + req <- req_body_json(req, body) + + req +} + +method(chat_body, ProviderOpenAI) <- function(provider, + stream = TRUE, + turns = list(), + tools = list(), + 
type = NULL) { + messages <- compact(unlist(as_json(provider, turns), recursive = FALSE)) tools <- as_json(provider, unname(tools)) @@ -142,11 +165,108 @@ method(chat_request, ProviderOpenAI) <- function(provider, response_format = response_format )) body <- utils::modifyList(body, provider@extra_args) - req <- req_body_json(req, body) - req + body +} + +# Batched requests ------------------------------------------------------------- + +method(has_batch_support, ProviderOpenAI) <- function(provider) { + TRUE +} + +# https://platform.openai.com/docs/api-reference/batch +method(batch_submit, ProviderOpenAI) <- function(provider, turns, type = NULL) { + path <- withr::local_tempfile() + + # First put the requests in a file + # https://platform.openai.com/docs/api-reference/batch/request-input + requests <- map(seq_along(turns), function(i) { + body <- chat_body(provider, stream = FALSE, turns = turns[[i]], type = type) + + list( + custom_id = paste0("chat-", i), + method = "POST", + url = "/v1/chat/completions", + body = body + ) + }) + json <- map_chr(requests, jsonlite::toJSON, auto_unbox = TRUE) + writeLines(json, path) + # Then upload it + uploaded <- openai_upload(provider, path) + + # Now we can submit the batch + req <- base_request(provider) + req <- req_url_path_append(req, "/batches") + req <- req_body_json(req, list( + input_file_id = uploaded$id, + endpoint = "/v1/chat/completions", + completion_window = "24h" + )) + + resp <- req_perform(req) + resp_body_json(resp) +} + +# https://platform.openai.com/docs/api-reference/files/create +openai_upload <- function(provider, path, purpose = "batch") { + req <- base_request(provider) + req <- req_url_path_append(req, "/files") + req <- req_body_multipart(req, purpose = purpose, file = curl::form_file(path)) + req <- req_progress(req, "up") + + resp <- req_perform(req) + resp_body_json(resp) +} + +# https://platform.openai.com/docs/api-reference/batch/retrieve +method(batch_poll, ProviderOpenAI) <- function(provider, 
batch) { + req <- base_request(provider) + req <- req_url_path_append(req, "/batches/", batch$id) + + resp <- req_perform(req) + resp_body_json(resp) +} + +method(batch_info, ProviderOpenAI) <- function(provider, batch) { + counts <- batch$request_counts + # Stop polling on any terminal status, not only "completed" + list( + working = !(batch$status %in% c("completed", "failed", "expired", "cancelled")), + counts = list( + processing = counts$total - counts$completed - counts$failed, + succeeded = counts$completed, + failed = counts$failed + ) + ) +} + +# https://platform.openai.com/docs/api-reference/files/retrieve-contents +method(batch_retrieve, ProviderOpenAI) <- function(provider, batch) { + path <- withr::local_tempfile() + + req <- base_request(provider) + req <- req_url_path_append(req, "/files/", batch$output_file_id, "/content") + req <- req_progress(req, "down") + resp <- req_perform(req, path = path) + + lines <- readLines(path, warn = FALSE) + json <- lapply(lines, jsonlite::fromJSON, simplifyVector = FALSE) + + ids <- as.numeric(gsub("chat-", "", map_chr(json, "[[", "custom_id"))) + results <- lapply(json, "[[", "response") + results[order(ids)] } + +method(batch_result_ok, ProviderOpenAI) <- function(provider, result) { + result$status_code == 200 +} + +method(batch_result_turn, ProviderOpenAI) <- function(provider, result, has_type = FALSE) { + value_turn(provider, result$body, has_type = has_type) +} # OpenAI -> ellmer -------------------------------------------------------------- method(stream_parse, ProviderOpenAI) <- function(provider, event) { diff --git a/R/provider.R b/R/provider.R index cf53adba..55abfd80 100644 --- a/R/provider.R +++ b/R/provider.R @@ -28,12 +28,27 @@ Provider <- new_class( # Create a request------------------------------------ +base_request <- new_generic("base_request", "provider", + function(provider) { + S7_dispatch() + } +) + chat_request <- new_generic("chat_request", "provider", function(provider, stream = TRUE, turns = list(), tools = list(), type = NULL) { S7_dispatch() } ) +chat_body <- new_generic( + "chat_body", + "provider", + 
function(provider, stream = TRUE, turns = list(), tools = list(), type = NULL) { + S7_dispatch() + } +) + + chat_resp_stream <- new_generic("chat_resp_stream", "provider", function(provider, resp) { S7_dispatch() @@ -75,3 +90,63 @@ method(as_json, list(Provider, class_list)) <- function(provider, x) { method(as_json, list(Provider, ContentJson)) <- function(provider, x) { as_json(provider, ContentText("")) } + +# Batch API --------------------------------------------------------------- + +has_batch_support <- new_generic( + "has_batch_support", + "provider", + function(provider) { + S7_dispatch() + } +) +method(has_batch_support, class_any) <- function(provider) { + FALSE +} + +batch_submit <- new_generic( + "batch_submit", + "provider", + function(provider, turns, type = NULL) { + S7_dispatch() + } +) + +batch_poll <- new_generic( + "batch_poll", + "provider", + function(provider, batch) { + S7_dispatch() + } +) + +batch_retrieve <- new_generic( + "batch_retrieve", + "provider", + function(provider, batch) { + S7_dispatch() + } +) + +batch_info <- new_generic( + "batch_info", + "provider", + function(provider, batch) { + S7_dispatch() + } +) + +batch_result_ok <- new_generic( + "batch_result_ok", + "provider", + function(provider, result) { + S7_dispatch() + } +) +batch_result_turn <- new_generic( + "batch_result_turn", + "provider", + function(provider, result, has_type = FALSE) { + S7_dispatch() + } +) diff --git a/man/Chat.Rd b/man/Chat.Rd index 41c6097e..e304601f 100644 --- a/man/Chat.Rd +++ b/man/Chat.Rd @@ -37,6 +37,7 @@ chat$chat("Tell me a funny joke") \item \href{#method-Chat-last_turn}{\code{Chat$last_turn()}} \item \href{#method-Chat-chat}{\code{Chat$chat()}} \item \href{#method-Chat-chat_parallel}{\code{Chat$chat_parallel()}} +\item \href{#method-Chat-chat_batch}{\code{Chat$chat_batch()}} \item \href{#method-Chat-extract_data}{\code{Chat$extract_data()}} \item \href{#method-Chat-extract_data_parallel}{\code{Chat$extract_data_parallel()}} \item 
\href{#method-Chat-extract_data_async}{\code{Chat$extract_data_async()}} @@ -246,6 +247,28 @@ Submit multiple prompts in parallel. Returns a list of \item{\code{max_active}}{The maximum number of simultaenous requests to send.} +\item{\code{rpm}}{Maximum number of requests per minute.} +} +\if{html}{\out{}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Chat-chat_batch}{}}} +\subsection{Method \code{chat_batch()}}{ +Submit multiple prompts in parallel. Returns a list of +\link{Chat} objects, one for each prompt. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Chat$chat_batch(prompts)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{prompts}}{A list of user prompts.} + +\item{\code{max_active}}{The maximum number of simultaenous requests to send.} + \item{\code{rpm}}{Maximum number of requests per minute.} } \if{html}{\out{
}}