google-common/webauth/gauth [patch]: Fixes for error, safety, and image handling (langchain-ai#4583)

* Better handling when an API call returns an error, tailored to how Google responds.

* Handle finish reason or prompt block reason.

* Fixes for image types

* Support for GoogleBaseLLM.invoke() and .stream() so they can handle messages with image parts (see the multimodal sketch after the llms.ts diff below).

* formatting

* Refactor safety handler to be a class

* formatting

* Fix a bug when converting a chunk to a string while streaming an LLM response

* Refactor DefaultGeminiSafetyHandler.
Add MessageGeminiSafetyHandler, which masks safety exceptions with a custom message (see the usage sketch after this list).

* Adding tests for invoke()

* Explicitly type safety handler.
Code cleanup.

* Change from overriding stream() to overriding _streamIterator() as discussed in langchain-ai#4583

* Formatting

* Refactor LLM methods to support callbacks

* Fix mock

* Test everything

* Lint

* Resolve circular deps

* Fix

* Fix
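
For reference, a rough usage sketch of the new safety handler option (an illustration only, not part of this diff; it assumes the concrete ChatGoogle class from the gauth package and that MessageGeminiSafetyHandler takes its replacement text as a constructor option):

    import { ChatGoogle } from "@langchain/google-gauth";
    // Assumption: the handler classes are re-exported from the package root;
    // they may need to be imported from a utils entry point instead.
    import { MessageGeminiSafetyHandler } from "@langchain/google-common";

    const model = new ChatGoogle({
      model: "gemini-pro",
      // Instead of throwing when Gemini blocks a response, return this text.
      safetyHandler: new MessageGeminiSafetyHandler({
        msg: "Response blocked for safety reasons.",
      }),
    });

    const result = await model.invoke("Tell me a joke.");
    console.log(result.content);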

---------

Co-authored-by: jacoblee93 <[email protected]>
afirstenberg and jacoblee93 authored Mar 8, 2024
1 parent f0762cc commit 5030088
Showing 22 changed files with 1,905 additions and 119 deletions.
25 changes: 18 additions & 7 deletions libs/langchain-google-common/src/chat_models.ts
@@ -24,12 +24,18 @@ import {
import { AbstractGoogleLLMConnection } from "./connection.js";
import {
baseMessageToContent,
responseToChatGeneration,
responseToChatResult,
safeResponseToChatGeneration,
safeResponseToChatResult,
DefaultGeminiSafetyHandler,
} from "./utils/gemini.js";
import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js";
import { GoogleBaseLLMInput } from "./llms.js";
import { JsonStream } from "./utils/stream.js";
import { ensureParams } from "./utils/failed_handler.js";
import type {
GoogleBaseLLMInput,
GoogleAISafetyHandler,
GoogleAISafetyParams,
} from "./types.js";

class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<
BaseMessage[],
@@ -51,7 +57,8 @@ class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<
export interface ChatGoogleBaseInput<AuthOptions>
extends BaseChatModelParams,
GoogleConnectionParams<AuthOptions>,
GoogleAIModelParams {}
GoogleAIModelParams,
GoogleAISafetyParams {}

/**
* Integration with a chat model.
@@ -81,14 +88,18 @@ export abstract class ChatGoogleBase<AuthOptions>

safetySettings: GoogleAISafetySetting[] = [];

safetyHandler: GoogleAISafetyHandler;

protected connection: ChatConnection<AuthOptions>;

protected streamedConnection: ChatConnection<AuthOptions>;

constructor(fields?: ChatGoogleBaseInput<AuthOptions>) {
super(fields ?? {});
super(ensureParams(fields));

copyAndValidateModelParamsInto(fields, this);
this.safetyHandler =
fields?.safetyHandler ?? new DefaultGeminiSafetyHandler();

const client = this.buildClient(fields);
this.buildConnection(fields ?? {}, client);
@@ -156,7 +167,7 @@ export abstract class ChatGoogleBase<AuthOptions>
parameters,
options
);
const ret = responseToChatResult(response);
const ret = safeResponseToChatResult(response, this.safetyHandler);
return ret;
}

@@ -183,7 +194,7 @@ export abstract class ChatGoogleBase<AuthOptions>
const output = await stream.nextChunk();
const chunk =
output !== null
? responseToChatGeneration({ data: output })
? safeResponseToChatGeneration({ data: output }, this.safetyHandler)
: new ChatGenerationChunk({
text: "",
generationInfo: { finishReason: "stop" },
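
For context on the safetyHandler field typed above as GoogleAISafetyHandler, a hypothetical custom handler might look like the sketch below (hedged: the single handle() method shape and the GoogleLLMResponse export are assumptions, not confirmed by this diff):

    import type {
      GoogleAISafetyHandler,
      GoogleLLMResponse,
    } from "@langchain/google-common";

    // Assumption: the handler interface is a single handle() method that
    // inspects the raw response and returns it (possibly modified) or throws.
    class LoggingSafetyHandler implements GoogleAISafetyHandler {
      handle(response: GoogleLLMResponse): GoogleLLMResponse {
        console.log("Gemini raw response:", JSON.stringify(response.data));
        return response;
      }
    }

    // Passed wherever safetyHandler is accepted, e.g.
    // new ChatGoogle({ safetyHandler: new LoggingSafetyHandler() }).
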
180 changes: 131 additions & 49 deletions libs/langchain-google-common/src/llms.ts
@@ -1,9 +1,9 @@
import { CallbackManager, Callbacks } from "@langchain/core/callbacks/manager";
import { BaseLLM, LLM } from "@langchain/core/language_models/llms";
import {
CallbackManagerForLLMRun,
Callbacks,
} from "@langchain/core/callbacks/manager";
import { LLM } from "@langchain/core/language_models/llms";
import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
type BaseLanguageModelCallOptions,
BaseLanguageModelInput,
} from "@langchain/core/language_models/base";
import { BaseMessage, MessageContent } from "@langchain/core/messages";
import { GenerationChunk } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
@@ -21,13 +21,18 @@ import {
copyAndValidateModelParamsInto,
} from "./utils/common.js";
import {
chunkToString,
messageContentToParts,
responseToBaseMessage,
responseToGeneration,
responseToString,
safeResponseToBaseMessage,
safeResponseToString,
DefaultGeminiSafetyHandler,
} from "./utils/gemini.js";
import { JsonStream } from "./utils/stream.js";
import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js";
import { ensureParams } from "./utils/failed_handler.js";
import { ChatGoogleBase } from "./chat_models.js";
import type { GoogleBaseLLMInput, GoogleAISafetyHandler } from "./types.js";

export { GoogleBaseLLMInput };

class GoogleLLMConnection<AuthOptions> extends AbstractGoogleLLMConnection<
MessageContent,
@@ -48,11 +53,21 @@ class GoogleLLMConnection<AuthOptions> extends AbstractGoogleLLMConnection<
}
}

/**
* Input to LLM class.
*/
export interface GoogleBaseLLMInput<AuthOptions>
extends GoogleAIBaseLLMInput<AuthOptions> {}
type ProxyChatInput<AuthOptions> = GoogleAIBaseLLMInput<AuthOptions> & {
connection: GoogleLLMConnection<AuthOptions>;
};

class ProxyChatGoogle<AuthOptions> extends ChatGoogleBase<AuthOptions> {
constructor(fields: ProxyChatInput<AuthOptions>) {
super(fields);
}

buildAbstractedClient(
fields: ProxyChatInput<AuthOptions>
): GoogleAbstractedClient {
return fields.connection.client;
}
}

/**
* Integration with an LLM.
Expand All @@ -66,6 +81,8 @@ export abstract class GoogleBaseLLM<AuthOptions>
return "GoogleLLM";
}

originalFields?: GoogleBaseLLMInput<AuthOptions>;

lc_serializable = true;

model = "gemini-pro";
@@ -82,14 +99,19 @@

safetySettings: GoogleAISafetySetting[] = [];

safetyHandler: GoogleAISafetyHandler;

protected connection: GoogleLLMConnection<AuthOptions>;

protected streamedConnection: GoogleLLMConnection<AuthOptions>;

constructor(fields?: GoogleBaseLLMInput<AuthOptions>) {
super(fields ?? {});
super(ensureParams(fields));
this.originalFields = fields;

copyAndValidateModelParamsInto(fields, this);
this.safetyHandler =
fields?.safetyHandler ?? new DefaultGeminiSafetyHandler();

const client = this.buildClient(fields);
this.buildConnection(fields ?? {}, client);
@@ -152,48 +174,81 @@

/**
* For some given input string and options, return a string output.
*
* Despite the fact that `invoke` is overridden below, we still need this
* in order to handle public API calls to `generate()`.
*/
async _call(
_prompt: string,
_options: this["ParsedCallOptions"],
_runManager?: CallbackManagerForLLMRun
prompt: string,
options: this["ParsedCallOptions"]
): Promise<string> {
const parameters = copyAIModelParams(this);
const result = await this.connection.request(_prompt, parameters, _options);
const ret = responseToString(result);
const result = await this.connection.request(prompt, parameters, options);
const ret = safeResponseToString(result, this.safetyHandler);
return ret;
}

async *_streamResponseChunks(
_prompt: string,
_options: this["ParsedCallOptions"],
_runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
// Make the call as a streaming request
const parameters = copyAIModelParams(this);
const result = await this.streamedConnection.request(
_prompt,
parameters,
_options
// Normally, you should not override this method and instead should override
// _streamResponseChunks. We are doing so here to allow for multimodal inputs into
// the LLM.
async *_streamIterator(
input: BaseLanguageModelInput,
options?: BaseLanguageModelCallOptions
): AsyncGenerator<string> {
// TODO: Refactor callback setup and teardown code into core
const prompt = BaseLLM._convertInputToPromptValue(input);
const [runnableConfig, callOptions] =
this._separateRunnableConfigFromCallOptions(options);
const callbackManager_ = await CallbackManager.configure(
runnableConfig.callbacks,
this.callbacks,
runnableConfig.tags,
this.tags,
runnableConfig.metadata,
this.metadata,
{ verbose: this.verbose }
);

// Get the streaming parser of the response
const stream = result.data as JsonStream;

// Loop until the end of the stream
// During the loop, yield each time we get a chunk from the streaming parser
// that is either available or added to the queue
while (!stream.streamDone) {
const output = await stream.nextChunk();
const chunk =
output !== null
? new GenerationChunk(responseToGeneration({ data: output }))
: new GenerationChunk({
text: "",
generationInfo: { finishReason: "stop" },
});
yield chunk;
const extra = {
options: callOptions,
invocation_params: this?.invocationParams(callOptions),
batch_size: 1,
};
const runManagers = await callbackManager_?.handleLLMStart(
this.toJSON(),
[prompt.toString()],
undefined,
undefined,
extra,
undefined,
undefined,
runnableConfig.runName
);
let generation = new GenerationChunk({
text: "",
});
const proxyChat = this.createProxyChat();
try {
for await (const chunk of proxyChat._streamIterator(input, options)) {
const stringValue = chunkToString(chunk);
const generationChunk = new GenerationChunk({
text: stringValue,
});
generation = generation.concat(generationChunk);
yield stringValue;
}
} catch (err) {
await Promise.all(
(runManagers ?? []).map((runManager) => runManager?.handleLLMError(err))
);
throw err;
}
await Promise.all(
(runManagers ?? []).map((runManager) =>
runManager?.handleLLMEnd({
generations: [[generation]],
})
)
);
}

async predictMessages(
@@ -207,7 +262,34 @@ export abstract class GoogleBaseLLM<AuthOptions>
{},
options as BaseLanguageModelCallOptions
);
const ret = responseToBaseMessage(result);
const ret = safeResponseToBaseMessage(result, this.safetyHandler);
return ret;
}

/**
* Internal implementation detail to allow Google LLMs to support
* multimodal input by delegating to the chat model implementation.
*
* TODO: Replace with something less hacky.
*/
protected createProxyChat(): ChatGoogleBase<AuthOptions> {
return new ProxyChatGoogle<AuthOptions>({
...this.originalFields,
connection: this.connection,
});
}

// TODO: Remove the need to override this - we are doing it to
// allow the LLM to handle multimodal types of input.
async invoke(
input: BaseLanguageModelInput,
options?: BaseLanguageModelCallOptions
): Promise<string> {
const stream = await this._streamIterator(input, options);
let generatedOutput = "";
for await (const chunk of stream) {
generatedOutput += chunk;
}
return generatedOutput;
}
}
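
A hedged sketch of the multimodal LLM path these overrides enable (not part of the diff; the GoogleLLM class name, the gemini-pro-vision model id, and the image_url content shape are assumptions):

    import * as fs from "node:fs";
    import { HumanMessage } from "@langchain/core/messages";
    // Assumption: GoogleLLM is the concrete gauth subclass of GoogleBaseLLM.
    import { GoogleLLM } from "@langchain/google-gauth";

    const model = new GoogleLLM({ model: "gemini-pro-vision" });

    // Mix a text part and an inline base64 image part in a single message.
    const image = fs.readFileSync("example.png").toString("base64");
    const message = new HumanMessage({
      content: [
        { type: "text", text: "What is in this image?" },
        { type: "image_url", image_url: `data:image/png;base64,${image}` },
      ],
    });

    // invoke() now streams through an internal proxy chat model and
    // concatenates the chunks into a single string result.
    const result = await model.invoke([message]);
    console.log(result);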