From 4e70204f2322f147489e5887afad5f84adf6c688 Mon Sep 17 00:00:00 2001
From: bracesproul
Date: Fri, 16 Aug 2024 13:34:36 -0700
Subject: [PATCH 1/3] core[minor]: Extract input vars from complex contents
 passed to prompt templates

---
 langchain-core/src/prompts/chat.ts            | 10 +--
 langchain-core/src/prompts/tests/chat.test.ts | 80 +++++++++++++++++++
 .../src/prompts/tests/few_shot.test.ts        | 61 +++++++++++++-
 3 files changed, 142 insertions(+), 9 deletions(-)

diff --git a/langchain-core/src/prompts/chat.ts b/langchain-core/src/prompts/chat.ts
index 0fcb0330489b..f01d77725cc6 100644
--- a/langchain-core/src/prompts/chat.ts
+++ b/langchain-core/src/prompts/chat.ts
@@ -10,7 +10,6 @@ import {
   ChatMessage,
   type BaseMessageLike,
   coerceMessageLikeToMessage,
-  isBaseMessage,
   MessageContent,
 } from "../messages/index.js";
 import {
@@ -729,10 +728,7 @@ function _coerceMessagePromptTemplateLike<
   messagePromptTemplateLike: BaseMessagePromptTemplateLike,
   extra?: Extra
 ): BaseMessagePromptTemplate | BaseMessage {
-  if (
-    _isBaseMessagePromptTemplate(messagePromptTemplateLike) ||
-    isBaseMessage(messagePromptTemplateLike)
-  ) {
+  if (_isBaseMessagePromptTemplate(messagePromptTemplateLike)) {
     return messagePromptTemplateLike;
   }
   if (
@@ -1118,9 +1114,7 @@ export class ChatPromptTemplate<
       // eslint-disable-next-line no-instanceof/no-instanceof
       if (promptMessage instanceof BaseMessage) continue;
       for (const inputVariable of promptMessage.inputVariables) {
-        if (inputVariable in flattenedPartialVariables) {
-          continue;
-        }
+        if (inputVariable in flattenedPartialVariables) continue;
         inputVariables.add(inputVariable);
       }
     }
diff --git a/langchain-core/src/prompts/tests/chat.test.ts b/langchain-core/src/prompts/tests/chat.test.ts
index 3f5125861a73..031ef9288662 100644
--- a/langchain-core/src/prompts/tests/chat.test.ts
+++ b/langchain-core/src/prompts/tests/chat.test.ts
@@ -622,3 +622,83 @@ test("Multi-modal, multi part chat prompt works with instances of BaseMessage",
   });
   expect(messages).toMatchSnapshot();
 });
+
+test("extract input variables from complex message contents", async () => {
+  const promptComplexContent = ChatPromptTemplate.fromMessages([
+    ["human", [
+      {
+        type: "text",
+        text: "{input}",
+      },
+    ]],
+    ["human", [
+      {
+        type:"image_url",
+        image_url: {
+          url: "{image_url}",
+          detail: "high"
+        }
+      }
+    ]],
+    ["human", [
+      {
+        type: "image_url",
+        image_url: "{image_url_2}"
+      },
+      {
+        type: "text",
+        text: "{input_2}"
+      },
+      {
+        type: "text",
+        text: "{input}" // Intentionally duplicated
+      }
+    ]]
+  ])
+
+  expect(promptComplexContent.inputVariables).toHaveLength(4)
+  expect(promptComplexContent.inputVariables.sort()).toEqual(["input", "image_url", "image_url_2", "input_2"].sort());
+});
+
+test("extract input variables from complex message contents inside BaseMessages", async () => {
+  const promptComplexContent = ChatPromptTemplate.fromMessages([
+    new HumanMessage({
+      content: [
+        {
+          type: "text",
+          text: "{input}",
+        },
+      ]
+    }),
+    new HumanMessage({
+      content: [
+        {
+          type:"image_url",
+          image_url: {
+            url: "{image_url}",
+            detail: "high"
+          }
+        }
+      ]
+    }),
+    new HumanMessage({
+      content: [
+        {
+          type: "image_url",
+          image_url: "{image_url_2}"
+        },
+        {
+          type: "text",
+          text: "{input_2}"
+        },
+        {
+          type: "text",
+          text: "{input}" // Intentionally duplicated
+        }
+      ]
+    }),
+  ])
+
+  expect(promptComplexContent.inputVariables).toHaveLength(4)
+  expect(promptComplexContent.inputVariables.sort()).toEqual(["input", "image_url", "image_url_2", "input_2"].sort());
+});
diff --git a/langchain-core/src/prompts/tests/few_shot.test.ts b/langchain-core/src/prompts/tests/few_shot.test.ts
index a183ad68668c..0f896cd36716 100644
--- a/langchain-core/src/prompts/tests/few_shot.test.ts
+++ b/langchain-core/src/prompts/tests/few_shot.test.ts
@@ -125,7 +125,7 @@ An example about bar
   });
 });
 
-describe("FewShotChatMessagePromptTemplate", () => {
+describe.only("FewShotChatMessagePromptTemplate", () => {
   test("Format messages", async () => {
     const examplePrompt = ChatPromptTemplate.fromMessages([
       ["ai", "{ai_input_var}"],
@@ -252,4 +252,63 @@ An example about bar
     `
     );
   });
+
+  test("Can format messages with complex contents", async () => {
+    const examplePrompt = ChatPromptTemplate.fromMessages([
+      new AIMessage({
+        content: [{
+          type: "text",
+          text: "{ai_input_var}",
+        }]
+      }),
+      new HumanMessage({
+        content: [{
+          type: "text",
+          text: "{human_input_var}",
+        }]
+      }),
+    ]);
+    const examples = [
+      {
+        ai_input_var: "ai-foo",
+        human_input_var: "human-bar",
+      },
+      {
+        ai_input_var: "ai-foo2",
+        human_input_var: "human-bar2",
+      },
+    ];
+    const prompt = new FewShotChatMessagePromptTemplate({
+      examplePrompt,
+      inputVariables: ["ai_input_var", "human_input_var"],
+      examples,
+    });
+    const messages = await prompt.formatMessages({});
+    expect(messages).toEqual([
+      new AIMessage({
+        content: [{
+          type: "text",
+          text: "ai-foo",
+        }]
+      }),
+      new HumanMessage({
+        content: [{
+          type: "text",
+          text: "human-bar",
+        }]
+      }),
+      new AIMessage({
+        content: [{
+          type: "text",
+          text: "ai-foo2",
+        }]
+      }),
+      new HumanMessage({
+        content: [{
+          type: "text",
+          text: "human-bar2",
+        }]
+      }),
+    ]);
+  })
 });

From c635aca46d480abea86d6b2caaaebfc78bb59fdd Mon Sep 17 00:00:00 2001
From: bracesproul
Date: Fri, 16 Aug 2024 13:35:40 -0700
Subject: [PATCH 2/3] format n lint

---
 langchain-core/src/prompts/tests/chat.test.ts | 105 ++++++++++--------
 .../src/prompts/tests/few_shot.test.ts        |  64 ++++++-----
 2 files changed, 97 insertions(+), 72 deletions(-)

diff --git a/langchain-core/src/prompts/tests/chat.test.ts b/langchain-core/src/prompts/tests/chat.test.ts
index 031ef9288662..95d24744a2a9 100644
--- a/langchain-core/src/prompts/tests/chat.test.ts
+++ b/langchain-core/src/prompts/tests/chat.test.ts
@@ -625,39 +625,50 @@ test("Multi-modal, multi part chat prompt works with instances of BaseMessage",
 
 test("extract input variables from complex message contents", async () => {
   const promptComplexContent = ChatPromptTemplate.fromMessages([
-    ["human", [
-      {
-        type: "text",
-        text: "{input}",
-      },
-    ]],
-    ["human", [
-      {
-        type:"image_url",
-        image_url: {
-          url: "{image_url}",
-          detail: "high"
-        }
-      }
-    ]],
-    ["human", [
-      {
-        type: "image_url",
-        image_url: "{image_url_2}"
-      },
-      {
-        type: "text",
-        text: "{input_2}"
-      },
-      {
-        type: "text",
-        text: "{input}" // Intentionally duplicated
-      }
-    ]]
-  ])
+    [
+      "human",
+      [
+        {
+          type: "text",
+          text: "{input}",
+        },
+      ],
+    ],
+    [
+      "human",
+      [
+        {
+          type: "image_url",
+          image_url: {
+            url: "{image_url}",
+            detail: "high",
+          },
+        },
+      ],
+    ],
+    [
+      "human",
+      [
+        {
+          type: "image_url",
+          image_url: "{image_url_2}",
+        },
+        {
+          type: "text",
+          text: "{input_2}",
+        },
+        {
+          type: "text",
+          text: "{input}", // Intentionally duplicated
+        },
+      ],
+    ],
+  ]);
 
-  expect(promptComplexContent.inputVariables).toHaveLength(4)
-  expect(promptComplexContent.inputVariables.sort()).toEqual(["input", "image_url", "image_url_2", "input_2"].sort());
+  expect(promptComplexContent.inputVariables).toHaveLength(4);
+  expect(promptComplexContent.inputVariables.sort()).toEqual(
+    ["input", "image_url", "image_url_2", "input_2"].sort()
+  );
 });
@@ -668,37 +679,39 @@ test("extract input variables from complex message contents inside BaseMessages"
           type: "text",
           text: "{input}",
         },
-      ]
+      ],
     }),
     new HumanMessage({
       content: [
         {
-          type:"image_url",
+          type: "image_url",
           image_url: {
             url: "{image_url}",
-            detail: "high"
-          }
-        }
-      ]
+            detail: "high",
+          },
+        },
+      ],
     }),
     new HumanMessage({
       content: [
         {
           type: "image_url",
-          image_url: "{image_url_2}"
+          image_url: "{image_url_2}",
         },
         {
           type: "text",
-          text: "{input_2}"
+          text: "{input_2}",
         },
        {
           type: "text",
-          text: "{input}" // Intentionally duplicated
-        }
-      ]
+          text: "{input}", // Intentionally duplicated
+        },
+      ],
     }),
-  ])
+  ]);
 
-  expect(promptComplexContent.inputVariables).toHaveLength(4)
-  expect(promptComplexContent.inputVariables.sort()).toEqual(["input", "image_url", "image_url_2", "input_2"].sort());
+  expect(promptComplexContent.inputVariables).toHaveLength(4);
+  expect(promptComplexContent.inputVariables.sort()).toEqual(
+    ["input", "image_url", "image_url_2", "input_2"].sort()
+  );
 });
diff --git a/langchain-core/src/prompts/tests/few_shot.test.ts b/langchain-core/src/prompts/tests/few_shot.test.ts
index 0f896cd36716..c92ea8d93699 100644
--- a/langchain-core/src/prompts/tests/few_shot.test.ts
+++ b/langchain-core/src/prompts/tests/few_shot.test.ts
@@ -125,7 +125,7 @@ An example about bar
   });
 });
 
-describe.only("FewShotChatMessagePromptTemplate", () => {
+describe("FewShotChatMessagePromptTemplate", () => {
   test("Format messages", async () => {
     const examplePrompt = ChatPromptTemplate.fromMessages([
       ["ai", "{ai_input_var}"],
@@ -256,16 +256,20 @@ An example about bar
   test("Can format messages with complex contents", async () => {
     const examplePrompt = ChatPromptTemplate.fromMessages([
       new AIMessage({
-        content: [{
-          type: "text",
-          text: "{ai_input_var}",
-        }]
+        content: [
+          {
+            type: "text",
+            text: "{ai_input_var}",
+          },
+        ],
       }),
       new HumanMessage({
-        content: [{
-          type: "text",
-          text: "{human_input_var}",
-        }]
+        content: [
+          {
+            type: "text",
+            text: "{human_input_var}",
+          },
+        ],
       }),
     ]);
     const examples = [
@@ -286,29 +290,37 @@ An example about bar
     const messages = await prompt.formatMessages({});
     expect(messages).toEqual([
       new AIMessage({
-        content: [{
-          type: "text",
-          text: "ai-foo",
-        }]
+        content: [
+          {
+            type: "text",
+            text: "ai-foo",
+          },
+        ],
       }),
       new HumanMessage({
-        content: [{
-          type: "text",
-          text: "human-bar",
-        }]
+        content: [
+          {
+            type: "text",
+            text: "human-bar",
+          },
+        ],
       }),
       new AIMessage({
-        content: [{
-          type: "text",
-          text: "ai-foo2",
-        }]
+        content: [
+          {
+            type: "text",
+            text: "ai-foo2",
+          },
+        ],
       }),
       new HumanMessage({
-        content: [{
-          type: "text",
-          text: "human-bar2",
-        }]
+        content: [
+          {
+            type: "text",
+            text: "human-bar2",
+          },
+        ],
       }),
     ]);
-  })
+  });
 });

From da143977e0a3745474dc492f01f1a9a2d8e94d75 Mon Sep 17 00:00:00 2001
From: bracesproul
Date: Fri, 16 Aug 2024 13:59:42 -0700
Subject: [PATCH 3/3] cr

---
 .../src/output_parsers/tests/xml.test.ts      |  2 +-
 langchain-core/src/prompts/chat.ts            | 11 +++++++++++
 .../src/runnables/tests/runnable.test.ts      | 18 +++++++++---------
 3 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/langchain-core/src/output_parsers/tests/xml.test.ts b/langchain-core/src/output_parsers/tests/xml.test.ts
index bec47cfbc6f8..0a2bd5f4fd2b 100644
--- a/langchain-core/src/output_parsers/tests/xml.test.ts
+++ b/langchain-core/src/output_parsers/tests/xml.test.ts
@@ -87,7 +87,7 @@ test("Can parse streams", async () => {
   const result = await streamingLlm.stream(XML_EXAMPLE);
   let finalResult = {};
   for await (const chunk of result) {
-    console.log(chunk);
+    // console.log(chunk);
     finalResult = chunk;
   }
   expect(finalResult).toStrictEqual(expectedResult);
diff --git a/langchain-core/src/prompts/chat.ts b/langchain-core/src/prompts/chat.ts
index f01d77725cc6..15192e7df984 100644
--- a/langchain-core/src/prompts/chat.ts
+++ b/langchain-core/src/prompts/chat.ts
@@ -11,6 +11,7 @@ import {
   type BaseMessageLike,
   coerceMessageLikeToMessage,
   MessageContent,
+  isBaseMessage,
 } from "../messages/index.js";
 import {
   type ChatPromptValueInterface,
@@ -731,6 +732,16 @@ function _coerceMessagePromptTemplateLike<
   if (_isBaseMessagePromptTemplate(messagePromptTemplateLike)) {
     return messagePromptTemplateLike;
   }
+  const allowedCoercionMessageTypes = ["system", "ai", "human", "generic"];
+  // Do not coerce if it's an instance of `BaseMessage` AND it's not one of the allowed message types.
+  if (
+    isBaseMessage(messagePromptTemplateLike) &&
+    !allowedCoercionMessageTypes.includes(messagePromptTemplateLike._getType())
+  ) {
+    console.log("Returning", messagePromptTemplateLike._getType())
+    return messagePromptTemplateLike;
+  }
+
   if (
     Array.isArray(messagePromptTemplateLike) &&
     messagePromptTemplateLike[0] === "placeholder"
diff --git a/langchain-core/src/runnables/tests/runnable.test.ts b/langchain-core/src/runnables/tests/runnable.test.ts
index ca597ab35872..ee23d3202e3a 100644
--- a/langchain-core/src/runnables/tests/runnable.test.ts
+++ b/langchain-core/src/runnables/tests/runnable.test.ts
@@ -70,7 +70,7 @@ test("Test chat model stream", async () => {
   let done = false;
   while (!done) {
     const chunk = await reader.read();
-    console.log(chunk);
+    // console.log(chunk);
     done = chunk.done;
   }
 });
@@ -80,7 +80,7 @@ test("Pipe from one runnable to the next", async () => {
   const llm = new FakeLLM({});
   const runnable = promptTemplate.pipe(llm);
   const result = await runnable.invoke({ input: "Hello world!" });
-  console.log(result);
+  // console.log(result);
   expect(result).toBe("Hello world!");
 });
@@ -90,7 +90,7 @@ test("Stream the entire way through", async () => {
   const chunks = [];
   for await (const chunk of stream) {
     chunks.push(chunk);
-    console.log(chunk);
+    // console.log(chunk);
   }
   expect(chunks.length).toEqual("Hi there!".length);
   expect(chunks.join("")).toEqual("Hi there!");
@@ -118,7 +118,7 @@ test("Callback order with transform streaming", async () => {
   const chunks = [];
   for await (const chunk of stream) {
     chunks.push(chunk);
-    console.log(chunk);
+    // console.log(chunk);
   }
   expect(order).toEqual([
     "RunnableSequence",
@@ -139,7 +139,7 @@ test("Don't use intermediate streaming", async () => {
   const chunks = [];
   for await (const chunk of stream) {
     chunks.push(chunk);
-    console.log(chunk);
+    // console.log(chunk);
   }
   expect(chunks.length).toEqual(1);
   expect(chunks[0]).toEqual("Hi there!");
@@ -400,7 +400,7 @@ test("Create a runnable sequence and run it", async () => {
   const text = `Jello world`;
   const runnable = promptTemplate.pipe(llm).pipe(parser);
   const result = await runnable.invoke({ input: text });
-  console.log(result);
+  // console.log(result);
   expect(result).toEqual("Jello world");
 });
@@ -408,7 +408,7 @@ test("Create a runnable sequence with a static method with invalid output and ca
   const promptTemplate = PromptTemplate.fromTemplate("{input}");
   const llm = new FakeChatModel({});
   const parser = (input: BaseMessage) => {
-    console.log(input);
+    // console.log(input);
     try {
       const parsedInput =
         typeof input.content === "string"
@@ -428,8 +428,8 @@
   };
   const runnable = RunnableSequence.from([promptTemplate, llm, parser]);
   await expect(async () => {
-    const result = await runnable.invoke({ input: "Hello sequence!" });
-    console.log(result);
+    await runnable.invoke({ input: "Hello sequence!" });
+    // console.log(result);
   }).rejects.toThrow(OutputParserException);
 });
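
For context, a minimal usage sketch of what this three-patch series enables, assuming a build of `@langchain/core` with the patches applied. It is not part of the diffs above; the import entrypoints are the standard ones, but the variable names and URL are illustrative assumptions.

// Runs in an ESM module (top-level await).
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { HumanMessage } from "@langchain/core/messages";

// A BaseMessage instance with complex (multi-part) content. Before PATCH 1/3,
// _coerceMessagePromptTemplateLike returned BaseMessage instances as-is, so
// placeholders inside their content were never registered as input variables.
const prompt = ChatPromptTemplate.fromMessages([
  new HumanMessage({
    content: [
      { type: "text", text: "{question}" },
      { type: "image_url", image_url: { url: "{image_url}", detail: "high" } },
    ],
  }),
]);

// With the patches applied, both placeholders are discovered...
console.log(prompt.inputVariables); // ["question", "image_url"] (order may vary)

// ...and substituted when the prompt is formatted.
const value = await prompt.invoke({
  question: "What is in this image?",
  image_url: "https://example.com/photo.jpg",
});
console.log(value.messages);

Note that PATCH 3/3 limits this coercion to the "system", "ai", "human", and "generic" message types via the allowedCoercionMessageTypes guard; any other BaseMessage (a tool message, for example) is still returned untouched.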