Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support tool calls and structured outputs for gemini #149

Merged
merged 13 commits on Feb 9, 2025
2 changes: 1 addition & 1 deletion packages/proxy/schema/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ export const modelProviderHasTools: {
} = {
openai: true,
anthropic: true,
google: false,
google: true,
js: false,
window: false,
converse: true,
Expand Down
143 changes: 110 additions & 33 deletions packages/proxy/src/providers/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,63 @@ export async function openAIMessagesToGoogleMessages(
): Promise<Content[]> {
// First, do a basic mapping
const content: Content[] = await Promise.all(
messages.map(
async (m: Message): Promise<Content> => ({
parts: await openAIContentToGoogleContent(m.content),
// TODO: Add tool call support
role: m.role === "assistant" ? "model" : m.role,
}),
),
messages.map(async (m) => {
const contentParts =
m.role === "tool" ? [] : await openAIContentToGoogleContent(m.content);
const toolCallParts: Part[] =
m.role === "assistant"
? m.tool_calls?.map((t) => ({
functionCall: {
name: t.id,
args: JSON.parse(t.function.arguments),
},
})) ?? []
: [];
const toolResponseParts: Part[] =
m.role === "tool"
? [
{
functionResponse: {
name: m.tool_call_id,
response: {
name: m.tool_call_id,
content: JSON.parse(m.content),
},
},
},
]
: [];
return {
parts: [...contentParts, ...toolCallParts, ...toolResponseParts],
role:
m.role === "assistant"
? "model"
: m.role === "tool"
? "user"
: m.role,
};
}),
);

// Then, flatten each content item into an individual message
const flattenedContent: Content[] = content.flatMap((c) =>
c.parts.map((p) => ({
role: c.role,
parts: [p],
})),
);
const flattenedContent: Content[] = [];
for (let i = 0; i < content.length; i++) {
if (
flattenedContent.length > 0 &&
flattenedContent[flattenedContent.length - 1].role === content[i].role
) {
flattenedContent[flattenedContent.length - 1].parts = flattenedContent[
flattenedContent.length - 1
].parts.concat(content[i].parts);
} else {
flattenedContent.push(content[i]);
}
}

// Finally, sort the messages so that:
// 1. All images are up front
// 2. The system prompt.
// 3. Then all user messages' text parts
// The EcmaScript spec requires the sort to be stable, so this is safe.
const sortedContent: Content[] = flattenedContent.sort((a, b) => {
if (a.parts[0].inlineData && !b.parts[0].inlineData) {
return -1;
Expand Down Expand Up @@ -122,14 +158,34 @@ export function googleEventToOpenAIChatEvent(
event: data.candidates
? {
id: uuidv4(),
choices: (data.candidates || []).map((candidate) => ({
index: candidate.index,
delta: {
role: "assistant",
content: candidate.content.parts[0].text || "",
},
finish_reason: translateFinishReason(candidate.finishReason),
})),
choices: (data.candidates || []).map((candidate) => {
const firstText = candidate.content.parts.find(
(p) => p.text !== undefined,
);
const toolCalls = candidate.content.parts
.filter((p) => p.functionCall !== undefined)
.map((p, i) => ({
id: uuidv4(),
type: "function" as const,
function: {
name: p.functionCall.name,
arguments: JSON.stringify(p.functionCall.args),
},
index: i,
}));
return {
index: 0,
delta: {
role: "assistant",
content: firstText?.text ?? "",
tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
},
finish_reason:
toolCalls.length > 0
? "tool_calls"
: translateFinishReason(candidate.finishReason),
};
}),
created: getTimestampInSeconds(),
model,
object: "chat.completion.chunk",
Expand All @@ -143,7 +199,9 @@ export function googleEventToOpenAIChatEvent(
}
: null,
finished:
false /* all of the events seem to have STOP as the finish reason */,
data.candidates?.every(
(candidate) => candidate.finishReason !== undefined,
) ?? false,
};
}

Expand All @@ -153,16 +211,35 @@ export function googleCompletionToOpenAICompletion(
): ChatCompletion {
return {
id: uuidv4(),
choices: (data.candidates || []).map((candidate) => ({
logprobs: null,
index: candidate.index,
message: {
role: "assistant",
content: candidate.content.parts[0].text || "",
refusal: null,
},
finish_reason: translateFinishReason(candidate.finishReason) || "stop",
})),
choices: (data.candidates || []).map((candidate) => {
const firstText = candidate.content.parts.find(
(p) => p.text !== undefined,
);
const toolCalls = candidate.content.parts
.filter((p) => p.functionCall !== undefined)
.map((p) => ({
id: uuidv4(),
type: "function" as const,
function: {
name: p.functionCall.name,
arguments: JSON.stringify(p.functionCall.args),
},
}));
return {
logprobs: null,
index: candidate.index,
message: {
role: "assistant",
content: firstText?.text ?? "",
tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
refusal: null,
},
finish_reason:
toolCalls.length > 0
? "tool_calls"
: translateFinishReason(candidate.finishReason) || "stop",
};
}),
created: getTimestampInSeconds(),
model,
object: "chat.completion",
Expand Down
128 changes: 123 additions & 5 deletions packages/proxy/src/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ import {
} from "utils";
import { openAIChatCompletionToChatEvent } from "./providers/openai";
import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions";
import { z } from "zod";

type CachedData = {
headers: Record<string, string>;
Expand Down Expand Up @@ -1522,6 +1521,111 @@ async function fetchAnthropic(
};
}

/**
 * Recursively prunes a JSON Schema down to the subset of fields that Google's
 * API accepts (an OpenAPI 3.0-style Schema object, used for both function
 * declaration `parameters` and `response_schema`). Unsupported fields such as
 * `additionalProperties`, `minLength`, `$ref`, etc. are dropped.
 *
 * @param schema - A JSON Schema node (object, array, or leaf value).
 * @returns The pruned schema; non-object leaves are returned unchanged.
 */
function pruneJsonSchemaToGoogle(schema: any): any {
  // Leaf values (strings, numbers, booleans, null/undefined) are not schema
  // nodes — pass them through untouched.
  if (!schema || typeof schema !== "object") {
    return schema;
  }

  // Arrays (e.g. tuple-style `items: [ ... ]`) must be pruned element-wise.
  // Treating an array as a plain object would iterate numeric string keys,
  // none of which are allowed fields, collapsing the array to `{}`.
  if (Array.isArray(schema)) {
    return schema.map(pruneJsonSchemaToGoogle);
  }

  // The only schema fields Google's Schema format supports.
  const allowedFields = new Set([
    "type",
    "format",
    "description",
    "nullable",
    "items",
    "enum",
    "properties",
    "required",
    "example",
  ]);

  const result: any = {};

  for (const [key, value] of Object.entries(schema)) {
    if (!allowedFields.has(key)) {
      // Silently drop fields Google would reject.
      continue;
    }

    if (key === "properties") {
      // Recurse into each named property's sub-schema.
      result[key] = Object.fromEntries(
        Object.entries(value as Record<string, any>).map(([k, v]) => [
          k,
          pruneJsonSchemaToGoogle(v),
        ]),
      );
    } else if (key === "items") {
      // Recurse into the array-item schema (object or tuple form).
      result[key] = pruneJsonSchemaToGoogle(value);
    } else {
      result[key] = value;
    }
  }

  return result;
}

/**
 * Translates OpenAI tool/function-calling parameters into Google's
 * `tools` + `tool_config` request fields, then strips the OpenAI-specific
 * fields from `params` so they do not leak into `generationConfig` when the
 * caller spreads `params` into the Google request body.
 *
 * @param params - The OpenAI chat-completion params; mutated in place
 *   (`tools`, `tool_choice`, and `functions` are removed).
 * @returns `{ tools, tool_config }` in Google's request format; each field is
 *   `undefined` when the corresponding OpenAI input was absent.
 */
function openAIToolsToGoogleTools(params: ChatCompletionCreateParams) {
  // Normalize the deprecated `functions` field into `tools` form so both
  // OpenAI styles are handled uniformly below.
  if (params.tools || params.functions) {
    params.tools =
      params.tools ||
      (params.functions as Array<ChatCompletionCreateParams.Function>).map(
        (f: any) => ({
          type: "function",
          function: f,
        }),
      );
  }

  // Map OpenAI's tool_choice onto Google's function_calling_config modes.
  let tool_config: any = undefined;
  if (params.tool_choice) {
    switch (params.tool_choice) {
      case "required":
        tool_config = {
          function_calling_config: {
            mode: "ANY",
          },
        };
        break;
      case "none":
        tool_config = {
          function_calling_config: {
            mode: "NONE",
          },
        };
        break;
      case "auto":
        tool_config = {
          function_calling_config: {
            mode: "AUTO",
          },
        };
        break;
      default:
        // Object form ({ type: "function", function: { name } }): force the
        // model to call exactly the named function.
        tool_config = {
          function_calling_config: {
            mode: "ANY",
            allowed_function_names: [params.tool_choice.function.name],
          },
        };
        break;
    }
  }

  const out = {
    tools: params.tools
      ? [
          {
            function_declarations: params.tools.map((t) => ({
              name: t.function.name,
              description: t.function.description,
              // Google only accepts a subset of JSON Schema fields.
              parameters: pruneJsonSchemaToGoogle(t.function.parameters),
            })),
          },
        ]
      : undefined,
    tool_config,
  };

  // Remove all OpenAI tool-related fields. Previously `functions` was left
  // behind, leaking the deprecated field into generationConfig.
  delete params.tools;
  delete params.tool_choice;
  delete params.functions;
  return out;
}

async function fetchGoogle(
method: "POST",
url: string,
Expand Down Expand Up @@ -1577,13 +1681,27 @@ async function fetchGoogle(
delete headers["authorization"];
headers["content-type"] = "application/json";

if (
oaiParams.response_format?.type === "json_object" ||
oaiParams.response_format?.type === "json_schema"
) {
params.response_mime_type = "application/json";
}
if (oaiParams.response_format?.type === "json_schema") {
params.response_schema = pruneJsonSchemaToGoogle(
oaiParams.response_format.json_schema.schema,
);
}
const body = JSON.stringify({
contents: content,
generationConfig: params,
...openAIToolsToGoogleTools(params),
});

const proxyResponse = await fetch(fullURL.toString(), {
method,
headers,
body: JSON.stringify({
contents: [content],
generationConfig: params,
}),
body,
keepalive: true,
});

Expand Down