diff --git a/packages/proxy/schema/index.ts b/packages/proxy/schema/index.ts
index 34163fa..3e8c8f5 100644
--- a/packages/proxy/schema/index.ts
+++ b/packages/proxy/schema/index.ts
@@ -133,7 +133,7 @@ export const modelProviderHasTools: {
 } = {
   openai: true,
   anthropic: true,
-  google: false,
+  google: true,
   js: false,
   window: false,
   converse: true,
diff --git a/packages/proxy/src/providers/google.ts b/packages/proxy/src/providers/google.ts
index 4603df1..9a7f095 100644
--- a/packages/proxy/src/providers/google.ts
+++ b/packages/proxy/src/providers/google.ts
@@ -54,27 +54,63 @@ export async function openAIMessagesToGoogleMessages(
 ): Promise<Content[]> {
   // First, do a basic mapping
   const content: Content[] = await Promise.all(
-    messages.map(
-      async (m: Message): Promise<Content> => ({
-        parts: await openAIContentToGoogleContent(m.content),
-        // TODO: Add tool call support
-        role: m.role === "assistant" ? "model" : m.role,
-      }),
-    ),
+    messages.map(async (m) => {
+      const contentParts =
+        m.role === "tool" ? [] : await openAIContentToGoogleContent(m.content);
+      const toolCallParts: Part[] =
+        m.role === "assistant"
+          ? m.tool_calls?.map((t) => ({
+              functionCall: {
+                name: t.id,
+                args: JSON.parse(t.function.arguments),
+              },
+            })) ?? []
+          : [];
+      const toolResponseParts: Part[] =
+        m.role === "tool"
+          ? [
+              {
+                functionResponse: {
+                  name: m.tool_call_id,
+                  response: {
+                    name: m.tool_call_id,
+                    content: JSON.parse(m.content),
+                  },
+                },
+              },
+            ]
+          : [];
+      return {
+        parts: [...contentParts, ...toolCallParts, ...toolResponseParts],
+        role:
+          m.role === "assistant"
+            ? "model"
+            : m.role === "tool"
+              ? "user"
+              : m.role,
+      };
+    }),
   );
 
-  // Then, flatten each content item into an individual message
-  const flattenedContent: Content[] = content.flatMap((c) =>
-    c.parts.map((p) => ({
-      role: c.role,
-      parts: [p],
-    })),
-  );
+  const flattenedContent: Content[] = [];
+  for (let i = 0; i < content.length; i++) {
+    if (
+      flattenedContent.length > 0 &&
+      flattenedContent[flattenedContent.length - 1].role === content[i].role
+    ) {
+      flattenedContent[flattenedContent.length - 1].parts = flattenedContent[
+        flattenedContent.length - 1
+      ].parts.concat(content[i].parts);
+    } else {
+      flattenedContent.push(content[i]);
+    }
+  }
 
   // Finally, sort the messages so that:
   // 1. All images are up front
   // 2. The system prompt.
   // 3. Then all user messages' text parts
+  // The EcmaScript spec requires the sort to be stable, so this is safe.
   const sortedContent: Content[] = flattenedContent.sort((a, b) => {
     if (a.parts[0].inlineData && !b.parts[0].inlineData) {
       return -1;
@@ -122,14 +158,34 @@ export function googleEventToOpenAIChatEvent(
     event: data.candidates
       ? {
           id: uuidv4(),
-          choices: (data.candidates || []).map((candidate) => ({
-            index: candidate.index,
-            delta: {
-              role: "assistant",
-              content: candidate.content.parts[0].text || "",
-            },
-            finish_reason: translateFinishReason(candidate.finishReason),
-          })),
+          choices: (data.candidates || []).map((candidate) => {
+            const firstText = candidate.content.parts.find(
+              (p) => p.text !== undefined,
+            );
+            const toolCalls = candidate.content.parts
+              .filter((p) => p.functionCall !== undefined)
+              .map((p, i) => ({
+                id: uuidv4(),
+                type: "function" as const,
+                function: {
+                  name: p.functionCall.name,
+                  arguments: JSON.stringify(p.functionCall.args),
+                },
+                index: i,
+              }));
+            return {
+              index: 0,
+              delta: {
+                role: "assistant",
+                content: firstText?.text ?? "",
+                tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
+              },
+              finish_reason:
+                toolCalls.length > 0
+                  ? "tool_calls"
+                  : translateFinishReason(candidate.finishReason),
+            };
+          }),
           created: getTimestampInSeconds(),
           model,
           object: "chat.completion.chunk",
@@ -143,7 +199,9 @@ export function googleEventToOpenAIChatEvent(
         }
       : null,
     finished:
-      false /* all of the events seem to have STOP as the finish reason */,
+      data.candidates?.every(
+        (candidate) => candidate.finishReason !== undefined,
+      ) ?? false,
   };
 }
 
@@ -153,16 +211,35 @@ export function googleCompletionToOpenAICompletion(
 ): ChatCompletion {
   return {
     id: uuidv4(),
-    choices: (data.candidates || []).map((candidate) => ({
-      logprobs: null,
-      index: candidate.index,
-      message: {
-        role: "assistant",
-        content: candidate.content.parts[0].text || "",
-        refusal: null,
-      },
-      finish_reason: translateFinishReason(candidate.finishReason) || "stop",
-    })),
+    choices: (data.candidates || []).map((candidate) => {
+      const firstText = candidate.content.parts.find(
+        (p) => p.text !== undefined,
+      );
+      const toolCalls = candidate.content.parts
+        .filter((p) => p.functionCall !== undefined)
+        .map((p) => ({
+          id: uuidv4(),
+          type: "function" as const,
+          function: {
+            name: p.functionCall.name,
+            arguments: JSON.stringify(p.functionCall.args),
+          },
+        }));
+      return {
+        logprobs: null,
+        index: candidate.index,
+        message: {
+          role: "assistant",
+          content: firstText?.text ?? "",
+          tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
+          refusal: null,
+        },
+        finish_reason:
+          toolCalls.length > 0
+            ? "tool_calls"
+            : translateFinishReason(candidate.finishReason) || "stop",
+      };
+    }),
     created: getTimestampInSeconds(),
     model,
     object: "chat.completion",
diff --git a/packages/proxy/src/proxy.ts b/packages/proxy/src/proxy.ts
index 96c0e67..5c5471b 100644
--- a/packages/proxy/src/proxy.ts
+++ b/packages/proxy/src/proxy.ts
@@ -66,7 +66,6 @@ import {
 } from "utils";
 import { openAIChatCompletionToChatEvent } from "./providers/openai";
 import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions";
-import { z } from "zod";
 
 type CachedData = {
   headers: Record<string, string>;
@@ -1522,6 +1521,111 @@ async function fetchAnthropic(
   };
 }
 
+function pruneJsonSchemaToGoogle(schema: any): any {
+  if (!schema || typeof schema !== "object") {
+    return schema;
+  }
+
+  const allowedFields = [
+    "type",
+    "format",
+    "description",
+    "nullable",
+    "items",
+    "enum",
+    "properties",
+    "required",
+    "example",
+  ];
+
+  const result: any = {};
+
+  for (const [key, value] of Object.entries(schema)) {
+    if (!allowedFields.includes(key)) {
+      continue;
+    }
+
+    if (key === "properties") {
+      result[key] = Object.fromEntries(
+        Object.entries(value as Record<string, any>).map(([k, v]) => [
+          k,
+          pruneJsonSchemaToGoogle(v),
+        ]),
+      );
+    } else if (key === "items") {
+      result[key] = pruneJsonSchemaToGoogle(value);
+    } else {
+      result[key] = value;
+    }
+  }
+
+  return result;
+}
+
+function openAIToolsToGoogleTools(params: ChatCompletionCreateParams) {
+  if (params.tools || params.functions) {
+    params.tools =
+      params.tools ||
+      (params.functions as Array<any>).map(
+        (f: any) => ({
+          type: "function",
+          function: f,
+        }),
+      );
+  }
+  let tool_config: any = undefined;
+  if (params.tool_choice) {
+    switch (params.tool_choice) {
+      case "required":
+        tool_config = {
+          function_calling_config: {
+            mode: "ANY",
+          },
+        };
+        break;
+      case "none":
+        tool_config = {
+          function_calling_config: {
+            mode: "NONE",
+          },
+        };
+        break;
+      case "auto":
+        tool_config = {
+          function_calling_config: {
+            mode: "AUTO",
+          },
+        };
+        break;
+      default:
+        tool_config = {
+          function_calling_config: {
+            mode: "ANY",
+            allowed_function_names: [params.tool_choice.function.name],
+          },
+        };
+        break;
+    }
+  }
+  let out = {
+    tools: params.tools
+      ? [
+          {
+            function_declarations: params.tools.map((t) => ({
+              name: t.function.name,
+              description: t.function.description,
+              parameters: pruneJsonSchemaToGoogle(t.function.parameters),
+            })),
+          },
+        ]
+      : undefined,
+    tool_config,
+  };
+  delete params.tools;
+  delete params.tool_choice;
+  return out;
+}
+
 async function fetchGoogle(
   method: "POST",
   url: string,
@@ -1577,13 +1681,27 @@ async function fetchGoogle(
   delete headers["authorization"];
   headers["content-type"] = "application/json";
 
+  if (
+    oaiParams.response_format?.type === "json_object" ||
+    oaiParams.response_format?.type === "json_schema"
+  ) {
+    params.response_mime_type = "application/json";
+  }
+  if (oaiParams.response_format?.type === "json_schema") {
+    params.response_schema = pruneJsonSchemaToGoogle(
+      oaiParams.response_format.json_schema.schema,
+    );
+  }
+  const body = JSON.stringify({
+    contents: content,
+    generationConfig: params,
+    ...openAIToolsToGoogleTools(params),
+  });
+
   const proxyResponse = await fetch(fullURL.toString(), {
     method,
     headers,
-    body: JSON.stringify({
-      contents: [content],
-      generationConfig: params,
-    }),
+    body,
     keepalive: true,
   });