Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

attach all b64 images in the proxy #163

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions apis/node/src/login.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import bsearch from "binary-search";
import { Env } from "./env";
import { APISecret } from "@braintrust/proxy/schema";
import { loginToState } from "braintrust";

export async function lookupApiSecret(
useCache: boolean,
Expand Down Expand Up @@ -53,6 +54,14 @@ export async function lookupApiSecret(
return secrets;
}

export function sdkLogin(authToken: string, orgName: string | undefined) {
return loginToState({
appUrl: Env.braintrustApiUrl,
apiKey: authToken,
orgName: orgName ?? Env.orgName,
});
}

function fixIndex(i: number) {
return i >= 0 ? i : -i - 1;
}
Expand Down
3 changes: 2 additions & 1 deletion apis/node/src/node-proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import type * as streamWeb from "node:stream/web";
import { proxyV1 } from "@braintrust/proxy";

import { getRedis } from "./cache";
import { lookupApiSecret } from "./login";
import { lookupApiSecret, sdkLogin } from "./login";

export async function nodeProxyV1({
method,
Expand Down Expand Up @@ -71,6 +71,7 @@ export async function nodeProxyV1({
digest: async (message: string) => {
return crypto.createHash("md5").update(message).digest("hex");
},
sdkLogin,
});

const res = getRes();
Expand Down
8 changes: 8 additions & 0 deletions packages/proxy/edge/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
EncryptedMessage,
encryptMessage,
} from "utils/encrypt";
import { loginToState } from "braintrust";

export { FlushingExporter } from "./exporter";

Expand Down Expand Up @@ -310,6 +311,13 @@ export function EdgeProxyV1(opts: ProxyOpts) {
cacheGet,
cachePut,
digest: digestMessage,
sdkLogin: async (authToken: string, orgName: string | undefined) => {
return loginToState({
appUrl: opts.braintrustApiUrl ?? DEFAULT_BRAINTRUST_APP_URL,
apiKey: authToken,
orgName,
});
},
meterProvider,
});
} catch (e) {
Expand Down
87 changes: 87 additions & 0 deletions packages/proxy/src/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {
OpenAIParamsToGoogleParams,
} from "./providers/google";
import {
chatCompletionMessageParamSchema,
Message,
MessageRole,
responseFormatSchema,
Expand Down Expand Up @@ -73,6 +74,7 @@ import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completion
import { importPKCS8, SignJWT } from "jose";
import { z } from "zod";
import $RefParser from "@apidevtools/json-schema-ref-parser";
import { Attachment, BraintrustState, ReadonlyAttachment } from "braintrust";

type CachedMetadata = {
cached_at: Date;
Expand Down Expand Up @@ -137,6 +139,7 @@ export async function proxyV1({
cacheKeyOptions = {},
decompressFetch = false,
spanLogger,
sdkLogin,
}: {
method: "GET" | "POST";
url: string;
Expand All @@ -159,6 +162,10 @@ export async function proxyV1({
ttl_seconds?: number,
) => Promise<void>;
digest: (message: string) => Promise<string>;
sdkLogin: (
authToken: string,
orgName: string | undefined,
) => Promise<BraintrustState>;
meterProvider?: MeterProvider;
cacheKeyOptions?: CacheKeyOptions;
decompressFetch?: boolean;
Expand Down Expand Up @@ -431,6 +438,8 @@ export async function proxyV1({
);
}

let sdkState: BraintrustState | undefined;

const {
modelResponse: { response: proxyResponse, stream: proxyStream },
secretName,
Expand Down Expand Up @@ -487,6 +496,12 @@ export async function proxyV1({
(st) => {
spanType = st;
},
async () => {
if (!sdkState) {
sdkState = await sdkLogin(authToken, orgName);
}
return sdkState;
},
);
stream = proxyStream;

Expand Down Expand Up @@ -806,6 +821,75 @@ const TRY_ANOTHER_ENDPOINT_ERROR_CODES = [
RATE_LIMIT_ERROR_CODE,
];

async function attachRawImages(
getSdkState: () => Promise<BraintrustState>,
bodyData: any,
) {
const parsed = z
.array(chatCompletionMessageParamSchema)
.safeParse(bodyData.messages);
if (!parsed.success) {
return bodyData;
}

const messages = await Promise.all(
parsed.data.map(async (msg) => {
if (msg.role !== "user" || !Array.isArray(msg.content)) {
return msg;
}

const content = await Promise.all(
msg.content.map(async (item) => {
if (item.type !== "image_url") {
return item;
}

const match = item.image_url.url.match(
/^data:image\/([a-zA-Z]+);base64,([a-zA-Z0-9+/=]+)$/,
);
if (!match) {
return item;
}

const [, type, b64] = match;
const buf = Buffer.from(b64, "base64");
if (!buf.length) {
console.warn("Empty image buffer after base64 decode");
return item;
}

try {
const state = await getSdkState();
const att = new Attachment({
data: buf,
contentType: `image/${type}`,
filename: "(embedded)",
state,
});

await att.upload();
return {
...item,
image_url: {
url: (
await new ReadonlyAttachment(att.reference, state).metadata()
).downloadUrl,
},
};
} catch (err) {
console.warn("Failed to process base64 image:", err);
return item;
}
}),
);

return { ...msg, content };
}),
);

return { ...bodyData, messages };
}

let loopIndex = 0;
async function fetchModelLoop(
meter: Meter,
Expand All @@ -817,6 +901,7 @@ async function fetchModelLoop(
getApiSecrets: (model: string | null) => Promise<APISecret[]>,
spanLogger: SpanLogger | undefined,
setSpanType: (spanType: SpanType) => void,
getSdkState: () => Promise<BraintrustState>,
): Promise<{ modelResponse: ModelResponse; secretName?: string | null }> {
const requestId = ++loopIndex;

Expand Down Expand Up @@ -846,6 +931,8 @@ async function fetchModelLoop(
model = bodyData.model;
}

bodyData = await attachRawImages(getSdkState, bodyData);

// TODO: Make this smarter. For now, just pick a random one.
const secrets = await getApiSecrets(model);
const initialIdx = getRandomInt(secrets.length);
Expand Down
Loading