core[minor]: Extract input vars from complex contents passed to prompt templates #6558

Closed · wants to merge 3 commits
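In brief: this change makes `ChatPromptTemplate` extract `{variable}` placeholders from complex (multi-part) message contents, including contents nested inside `BaseMessage` instances, instead of only from plain string contents. A minimal sketch of the intended behavior, assuming the standard `@langchain/core` entrypoints (the expected output is inferred from the tests added in this PR, and ordering is not guaranteed):

```ts
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { HumanMessage } from "@langchain/core/messages";

// A message whose content is an array of parts, not a plain string.
const prompt = ChatPromptTemplate.fromMessages([
  new HumanMessage({
    content: [
      { type: "text", text: "{input}" },
      { type: "image_url", image_url: { url: "{image_url}", detail: "high" } },
    ],
  }),
]);

// Before this PR, variables inside complex contents were not detected;
// with it, both placeholders should be reported.
console.log(prompt.inputVariables); // expected: ["input", "image_url"]
```

The tests in langchain-core/src/prompts/tests/chat.test.ts below assert exactly this, including de-duplication of repeated placeholders.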
2 changes: 1 addition & 1 deletion langchain-core/src/output_parsers/tests/xml.test.ts
@@ -87,7 +87,7 @@ test("Can parse streams", async () => {
const result = await streamingLlm.stream(XML_EXAMPLE);
let finalResult = {};
for await (const chunk of result) {
-   console.log(chunk);
+   // console.log(chunk);
finalResult = chunk;
}
expect(finalResult).toStrictEqual(expectedResult);
17 changes: 11 additions & 6 deletions langchain-core/src/prompts/chat.ts
@@ -10,8 +10,8 @@ import {
ChatMessage,
type BaseMessageLike,
coerceMessageLikeToMessage,
-   isBaseMessage,
    MessageContent,
+   isBaseMessage,
} from "../messages/index.js";
import {
type ChatPromptValueInterface,
@@ -729,12 +729,19 @@ function _coerceMessagePromptTemplateLike<
messagePromptTemplateLike: BaseMessagePromptTemplateLike,
extra?: Extra
): BaseMessagePromptTemplate | BaseMessage {
+   if (_isBaseMessagePromptTemplate(messagePromptTemplateLike)) {
+     return messagePromptTemplateLike;
+   }
+   const allowedCoercionMessageTypes = ["system", "ai", "human", "generic"];
+   // Do not coerce if it's an instance of `BaseMessage` AND it's not one of the allowed message types.
    if (
-     _isBaseMessagePromptTemplate(messagePromptTemplateLike) ||
-     isBaseMessage(messagePromptTemplateLike)
+     isBaseMessage(messagePromptTemplateLike) &&
+     !allowedCoercionMessageTypes.includes(messagePromptTemplateLike._getType())
    ) {
+     console.log("Returning", messagePromptTemplateLike._getType())
      return messagePromptTemplateLike;
    }
+
if (
Array.isArray(messagePromptTemplateLike) &&
messagePromptTemplateLike[0] === "placeholder"
@@ -1118,9 +1125,7 @@ export class ChatPromptTemplate<
// eslint-disable-next-line no-instanceof/no-instanceof
if (promptMessage instanceof BaseMessage) continue;
for (const inputVariable of promptMessage.inputVariables) {
-     if (inputVariable in flattenedPartialVariables) {
-       continue;
-     }
+     if (inputVariable in flattenedPartialVariables) continue;
inputVariables.add(inputVariable);
}
}
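The new guard changes which `BaseMessage` instances short-circuit coercion: previously every `BaseMessage` was returned untouched; now "system", "ai", "human", and "generic" messages fall through so their contents can be turned into message prompt templates and scanned for variables. A simplified, self-contained sketch of that decision logic (not the library's actual code; only `_getType()` is assumed from the real API):

```ts
// Message types that may be coerced into prompt templates after this PR.
const allowedCoercionMessageTypes = ["system", "ai", "human", "generic"];

interface MessageLike {
  _getType(): string; // e.g. "human", "ai", "tool"
}

// Returns true when a message should be passed through unchanged
// (i.e. NOT coerced), mirroring the new condition in chat.ts above.
function shouldReturnUncoerced(message: MessageLike): boolean {
  return !allowedCoercionMessageTypes.includes(message._getType());
}

// A tool message keeps the old pass-through behavior; a human message
// is now eligible for coercion and variable extraction.
console.log(shouldReturnUncoerced({ _getType: () => "tool" })); // true
console.log(shouldReturnUncoerced({ _getType: () => "human" })); // false
```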
93 changes: 93 additions & 0 deletions langchain-core/src/prompts/tests/chat.test.ts
@@ -622,3 +622,96 @@ test("Multi-modal, multi part chat prompt works with instances of BaseMessage",
});
expect(messages).toMatchSnapshot();
});

test("extract input variables from complex message contents", async () => {
const promptComplexContent = ChatPromptTemplate.fromMessages([
[
"human",
[
{
type: "text",
text: "{input}",
},
],
],
[
"human",
[
{
type: "image_url",
image_url: {
url: "{image_url}",
detail: "high",
},
},
],
],
[
"human",
[
{
type: "image_url",
image_url: "{image_url_2}",
},
{
type: "text",
text: "{input_2}",
},
{
type: "text",
text: "{input}", // Intentionally duplicated
},
],
],
]);

expect(promptComplexContent.inputVariables).toHaveLength(4);
expect(promptComplexContent.inputVariables.sort()).toEqual(
["input", "image_url", "image_url_2", "input_2"].sort()
);
});

test("extract input variables from complex message contents inside BaseMessages", async () => {
const promptComplexContent = ChatPromptTemplate.fromMessages([
new HumanMessage({
content: [
{
type: "text",
text: "{input}",
},
],
}),
new HumanMessage({
content: [
{
type: "image_url",
image_url: {
url: "{image_url}",
detail: "high",
},
},
],
}),
new HumanMessage({
content: [
{
type: "image_url",
image_url: "{image_url_2}",
},
{
type: "text",
text: "{input_2}",
},
{
type: "text",
text: "{input}", // Intentionally duplicated
},
],
}),
]);

expect(promptComplexContent.inputVariables).toHaveLength(4);
expect(promptComplexContent.inputVariables.sort()).toEqual(
["input", "image_url", "image_url_2", "input_2"].sort()
);
});
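What these tests exercise amounts to walking every string field of each content part and collecting f-string placeholders into a set. The extraction code itself is not visible in this excerpt, so the following is a hypothetical sketch of equivalent logic (the regex and the recursive walk are assumptions, not the library's implementation):

```ts
// Hypothetical sketch: collect "{name}" placeholders from a complex
// message content. Not the library's actual implementation.
type ContentPart = Record<string, unknown>;
type ComplexContent = string | ContentPart[];

function extractInputVariables(content: ComplexContent): string[] {
  const found = new Set<string>();
  const scan = (value: unknown): void => {
    if (typeof value === "string") {
      // Match "{name}" while skipping escaped "{{name}}" (f-string style).
      for (const m of value.matchAll(/(?<!\{)\{([^{}]+)\}(?!\})/g)) {
        found.add(m[1]);
      }
    } else if (Array.isArray(value)) {
      value.forEach(scan);
    } else if (value !== null && typeof value === "object") {
      Object.values(value).forEach(scan);
    }
  };
  scan(content);
  // The Set de-duplicates, matching the intentionally repeated "{input}".
  return [...found];
}

// Mirrors the test above: three unique variables across nested parts.
console.log(
  extractInputVariables([
    { type: "image_url", image_url: "{image_url_2}" },
    { type: "text", text: "{input_2}" },
    { type: "text", text: "{input}" },
    { type: "text", text: "{input}" },
  ])
); // ["image_url_2", "input_2", "input"]
```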
71 changes: 71 additions & 0 deletions langchain-core/src/prompts/tests/few_shot.test.ts
@@ -252,4 +252,75 @@ An example about bar
`
);
});

test("Can format messages with complex contents", async () => {
const examplePrompt = ChatPromptTemplate.fromMessages([
new AIMessage({
content: [
{
type: "text",
text: "{ai_input_var}",
},
],
}),
new HumanMessage({
content: [
{
type: "text",
text: "{human_input_var}",
},
],
}),
]);
const examples = [
{
ai_input_var: "ai-foo",
human_input_var: "human-bar",
},
{
ai_input_var: "ai-foo2",
human_input_var: "human-bar2",
},
];
const prompt = new FewShotChatMessagePromptTemplate({
examplePrompt,
inputVariables: ["ai_input_var", "human_input_var"],
examples,
});
const messages = await prompt.formatMessages({});
expect(messages).toEqual([
new AIMessage({
content: [
{
type: "text",
text: "ai-foo",
},
],
}),
new HumanMessage({
content: [
{
type: "text",
text: "human-bar",
},
],
}),
new AIMessage({
content: [
{
type: "text",
text: "ai-foo2",
},
],
}),
new HumanMessage({
content: [
{
type: "text",
text: "human-bar2",
},
],
}),
]);
});
});
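For context, a few-shot block like the one above is usually spliced into a larger chat prompt. A hedged composition sketch (the system text and `{question}` variable are invented for illustration; `prompt` refers to the FewShotChatMessagePromptTemplate built in the test above):

```ts
// Hypothetical composition: format the few-shot examples, then embed the
// resulting messages in a larger prompt.
async function buildFinalPrompt() {
  const fewShotMessages = await prompt.formatMessages({});
  return ChatPromptTemplate.fromMessages([
    ["system", "You are a helpful assistant."],
    ...fewShotMessages,
    ["human", "{question}"],
  ]);
}
```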
18 changes: 9 additions & 9 deletions langchain-core/src/runnables/tests/runnable.test.ts
@@ -70,7 +70,7 @@ test("Test chat model stream", async () => {
let done = false;
while (!done) {
const chunk = await reader.read();
-   console.log(chunk);
+   // console.log(chunk);
done = chunk.done;
}
});
@@ -80,7 +80,7 @@ test("Pipe from one runnable to the next", async () => {
const llm = new FakeLLM({});
const runnable = promptTemplate.pipe(llm);
const result = await runnable.invoke({ input: "Hello world!" });
-   console.log(result);
+   // console.log(result);
expect(result).toBe("Hello world!");
});

@@ -90,7 +90,7 @@ test("Stream the entire way through", async () => {
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
-   console.log(chunk);
+   // console.log(chunk);
}
expect(chunks.length).toEqual("Hi there!".length);
expect(chunks.join("")).toEqual("Hi there!");
@@ -118,7 +118,7 @@ test("Callback order with transform streaming", async () => {
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
-   console.log(chunk);
+   // console.log(chunk);
}
expect(order).toEqual([
"RunnableSequence",
@@ -139,7 +139,7 @@ test("Don't use intermediate streaming", async () => {
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
-   console.log(chunk);
+   // console.log(chunk);
}
expect(chunks.length).toEqual(1);
expect(chunks[0]).toEqual("Hi there!");
@@ -400,15 +400,15 @@ test("Create a runnable sequence and run it", async () => {
const text = `Jello world`;
const runnable = promptTemplate.pipe(llm).pipe(parser);
const result = await runnable.invoke({ input: text });
-   console.log(result);
+   // console.log(result);
expect(result).toEqual("Jello world");
});

test("Create a runnable sequence with a static method with invalid output and catch the error", async () => {
const promptTemplate = PromptTemplate.fromTemplate("{input}");
const llm = new FakeChatModel({});
const parser = (input: BaseMessage) => {
-   console.log(input);
+   // console.log(input);
try {
const parsedInput =
typeof input.content === "string"
@@ -428,8 +428,8 @@ test("Create a runnable sequence with a static method with invalid output and catch the error"
};
const runnable = RunnableSequence.from([promptTemplate, llm, parser]);
await expect(async () => {
-     const result = await runnable.invoke({ input: "Hello sequence!" });
-     console.log(result);
+     await runnable.invoke({ input: "Hello sequence!" });
+     // console.log(result);
}).rejects.toThrow(OutputParserException);
});
