From b1b33303c5755d0065aebfa67198b1d260442027 Mon Sep 17 00:00:00 2001 From: Sachin Padmanabhan Date: Fri, 21 Feb 2025 10:53:41 -0800 Subject: [PATCH] attach all b64 images in the proxy --- apis/node/src/login.ts | 9 ++++ apis/node/src/node-proxy.ts | 3 +- packages/proxy/edge/index.ts | 8 ++++ packages/proxy/src/proxy.ts | 87 ++++++++++++++++++++++++++++++++++++ 4 files changed, 106 insertions(+), 1 deletion(-) diff --git a/apis/node/src/login.ts b/apis/node/src/login.ts index 01309bf..b06df82 100644 --- a/apis/node/src/login.ts +++ b/apis/node/src/login.ts @@ -1,6 +1,7 @@ import bsearch from "binary-search"; import { Env } from "./env"; import { APISecret } from "@braintrust/proxy/schema"; +import { loginToState } from "braintrust"; export async function lookupApiSecret( useCache: boolean, @@ -53,6 +54,14 @@ export async function lookupApiSecret( return secrets; } +export function sdkLogin(authToken: string, orgName: string | undefined) { + return loginToState({ + appUrl: Env.braintrustApiUrl, + apiKey: authToken, + orgName: orgName ?? Env.orgName, + }); +} + function fixIndex(i: number) { return i >= 0 ? i : -i - 1; } diff --git a/apis/node/src/node-proxy.ts b/apis/node/src/node-proxy.ts index a284736..9b7668a 100644 --- a/apis/node/src/node-proxy.ts +++ b/apis/node/src/node-proxy.ts @@ -7,7 +7,7 @@ import type * as streamWeb from "node:stream/web"; import { proxyV1 } from "@braintrust/proxy"; import { getRedis } from "./cache"; -import { lookupApiSecret } from "./login"; +import { lookupApiSecret, sdkLogin } from "./login"; export async function nodeProxyV1({ method, @@ -71,6 +71,7 @@ export async function nodeProxyV1({ digest: async (message: string) => { return crypto.createHash("md5").update(message).digest("hex"); }, + sdkLogin, }); const res = getRes(); diff --git a/packages/proxy/edge/index.ts b/packages/proxy/edge/index.ts index 5a99230..24344d6 100644 --- a/packages/proxy/edge/index.ts +++ b/packages/proxy/edge/index.ts @@ -11,6 +11,7 @@ import { EncryptedMessage, encryptMessage, } from "utils/encrypt"; +import { loginToState } from "braintrust"; export { FlushingExporter } from "./exporter"; @@ -310,6 +311,13 @@ export function EdgeProxyV1(opts: ProxyOpts) { cacheGet, cachePut, digest: digestMessage, + sdkLogin: async (authToken: string, orgName: string | undefined) => { + return loginToState({ + appUrl: opts.braintrustApiUrl ?? DEFAULT_BRAINTRUST_APP_URL, + apiKey: authToken, + orgName, + }); + }, meterProvider, }); } catch (e) { diff --git a/packages/proxy/src/proxy.ts b/packages/proxy/src/proxy.ts index 6bc00ca..3c052d1 100644 --- a/packages/proxy/src/proxy.ts +++ b/packages/proxy/src/proxy.ts @@ -44,6 +44,7 @@ import { OpenAIParamsToGoogleParams, } from "./providers/google"; import { + chatCompletionMessageParamSchema, Message, MessageRole, responseFormatSchema, @@ -73,6 +74,7 @@ import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completion import { importPKCS8, SignJWT } from "jose"; import { z } from "zod"; import $RefParser from "@apidevtools/json-schema-ref-parser"; +import { Attachment, BraintrustState, ReadonlyAttachment } from "braintrust"; type CachedMetadata = { cached_at: Date; @@ -137,6 +139,7 @@ export async function proxyV1({ cacheKeyOptions = {}, decompressFetch = false, spanLogger, + sdkLogin, }: { method: "GET" | "POST"; url: string; @@ -159,6 +162,10 @@ export async function proxyV1({ ttl_seconds?: number, ) => Promise; digest: (message: string) => Promise; + sdkLogin: ( + authToken: string, + orgName: string | undefined, + ) => Promise; meterProvider?: MeterProvider; cacheKeyOptions?: CacheKeyOptions; decompressFetch?: boolean; @@ -431,6 +438,8 @@ export async function proxyV1({ ); } + let sdkState: BraintrustState | undefined; + const { modelResponse: { response: proxyResponse, stream: proxyStream }, secretName, @@ -487,6 +496,12 @@ export async function proxyV1({ (st) => { spanType = st; }, + async () => { + if (!sdkState) { + sdkState = await sdkLogin(authToken, orgName); + } + return sdkState; + }, ); stream = proxyStream; @@ -806,6 +821,75 @@ const TRY_ANOTHER_ENDPOINT_ERROR_CODES = [ RATE_LIMIT_ERROR_CODE, ]; +async function attachRawImages( + getSdkState: () => Promise, + bodyData: any, +) { + const parsed = z + .array(chatCompletionMessageParamSchema) + .safeParse(bodyData.messages); + if (!parsed.success) { + return bodyData; + } + + const messages = await Promise.all( + parsed.data.map(async (msg) => { + if (msg.role !== "user" || !Array.isArray(msg.content)) { + return msg; + } + + const content = await Promise.all( + msg.content.map(async (item) => { + if (item.type !== "image_url") { + return item; + } + + const match = item.image_url.url.match( + /^data:image\/([a-zA-Z]+);base64,([a-zA-Z0-9+/=]+)$/, + ); + if (!match) { + return item; + } + + const [, type, b64] = match; + const buf = Buffer.from(b64, "base64"); + if (!buf.length) { + console.warn("Empty image buffer after base64 decode"); + return item; + } + + try { + const state = await getSdkState(); + const att = new Attachment({ + data: buf, + contentType: `image/${type}`, + filename: "(embedded)", + state, + }); + + await att.upload(); + return { + ...item, + image_url: { + url: ( + await new ReadonlyAttachment(att.reference, state).metadata() + ).downloadUrl, + }, + }; + } catch (err) { + console.warn("Failed to process base64 image:", err); + return item; + } + }), + ); + + return { ...msg, content }; + }), + ); + + return { ...bodyData, messages }; +} + let loopIndex = 0; async function fetchModelLoop( meter: Meter, @@ -817,6 +901,7 @@ async function fetchModelLoop( getApiSecrets: (model: string | null) => Promise, spanLogger: SpanLogger | undefined, setSpanType: (spanType: SpanType) => void, + getSdkState: () => Promise, ): Promise<{ modelResponse: ModelResponse; secretName?: string | null }> { const requestId = ++loopIndex; @@ -846,6 +931,8 @@ async function fetchModelLoop( model = bodyData.model; } + bodyData = await attachRawImages(getSdkState, bodyData); + // TODO: Make this smarter. For now, just pick a random one. const secrets = await getApiSecrets(model); const initialIdx = getRandomInt(secrets.length);