From bc4228d01eda64cbfc50a3f8b0b61f3cd1c06711 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Mon, 16 Sep 2024 13:46:13 +0200 Subject: [PATCH 01/17] OSS LLM --- .../graphs/default_assistant_graph/helpers.ts | 112 ++++++++++-------- .../graphs/default_assistant_graph/index.ts | 50 ++++---- 2 files changed, 92 insertions(+), 70 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 93890f9dfb121..5485a02cd1fb2 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -24,6 +24,7 @@ interface StreamGraphParams { assistantGraph: DefaultAssistantGraph; inputs: GraphInputs; logger: Logger; + isOssLlm?: boolean; onLlmResponse?: OnLlmResponse; request: KibanaRequest; traceOptions?: TraceOptions; @@ -36,15 +37,18 @@ interface StreamGraphParams { * @param assistantGraph * @param inputs * @param logger + * @param isOssLlm * @param onLlmResponse * @param request * @param traceOptions */ +/* eslint complexity: ["error", 210]*/ export const streamGraph = async ({ apmTracer, assistantGraph, inputs, logger, + isOssLlm, onLlmResponse, request, traceOptions, @@ -138,61 +142,75 @@ export const streamGraph = async ({ if (done) return; const event = value; - // only process events that are part of the agent run - if ((event.tags || []).includes(AGENT_NODE_TAG)) { - if (event.name === 'ActionsClientChatOpenAI') { - if (event.event === 'on_llm_stream') { - const chunk = event.data?.chunk; - const msg = chunk.message; - if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) { - // I don't think we hit this anymore because of our check for AGENT_NODE_TAG - // however, no harm to keep it in - /* empty */ - } else if (!didEnd) { - push({ payload: msg.content, type: 'content' }); - finalMessage += msg.content; - } - } else if (event.event === 'on_llm_end' && !didEnd) { - const generations = event.data.output?.generations[0]; - if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { - handleStreamEnd(generations[0]?.text ?? finalMessage); - } + + const processOpenAIEvent = () => { + if (event.event === 'on_llm_stream') { + const chunk = event.data?.chunk; + const msg = chunk.message; + if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) { + // I don't think we hit this anymore because of our check for AGENT_NODE_TAG + // however, no harm to keep it in + /* empty */ + } else if (!didEnd) { + push({ payload: msg.content, type: 'content' }); + finalMessage += msg.content; + } + } else if (event.event === 'on_llm_end' && !didEnd) { + const generations = event.data.output?.generations[0]; + if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { + handleStreamEnd(generations[0]?.text ?? finalMessage); } } - if (event.name === 'ActionsClientSimpleChatModel') { - if (event.event === 'on_llm_stream') { - const chunk = event.data?.chunk; + }; - const msg = chunk.content; - if (finalOutputIndex === -1) { - currentOutput += msg; - // Remove whitespace to simplify parsing - const noWhitespaceOutput = currentOutput.replace(/\s/g, ''); - if (noWhitespaceOutput.includes(finalOutputStartToken)) { - const nonStrippedToken = '"action_input": "'; - finalOutputIndex = currentOutput.indexOf(nonStrippedToken); - const contentStartIndex = finalOutputIndex + nonStrippedToken.length; - extraOutput = currentOutput.substring(contentStartIndex); + const processSimpleChatModelEvent = () => { + // console.log(`[TEST][ActionsClientSimpleChatModel] currentOutput: ${currentOutput}`); + if (event.event === 'on_llm_stream') { + const chunk = event.data?.chunk; + + const msg = isOssLlm ? chunk.message.content : chunk.content; + if (finalOutputIndex === -1) { + currentOutput += msg; + // Remove whitespace to simplify parsing + const noWhitespaceOutput = currentOutput.replace(/\s/g, ''); + if (noWhitespaceOutput.includes(finalOutputStartToken)) { + const nonStrippedToken = '"action_input": "'; + finalOutputIndex = currentOutput.indexOf(nonStrippedToken); + const contentStartIndex = finalOutputIndex + nonStrippedToken.length; + extraOutput = currentOutput.substring(contentStartIndex); + push({ payload: extraOutput, type: 'content' }); + finalMessage += extraOutput; + } + } else if (!streamingFinished && !didEnd) { + const finalOutputEndIndex = msg.search(finalOutputStopRegex); + if (finalOutputEndIndex !== -1) { + extraOutput = msg.substring(0, finalOutputEndIndex); + streamingFinished = true; + if (extraOutput.length > 0) { push({ payload: extraOutput, type: 'content' }); finalMessage += extraOutput; } - } else if (!streamingFinished && !didEnd) { - const finalOutputEndIndex = msg.search(finalOutputStopRegex); - if (finalOutputEndIndex !== -1) { - extraOutput = msg.substring(0, finalOutputEndIndex); - streamingFinished = true; - if (extraOutput.length > 0) { - push({ payload: extraOutput, type: 'content' }); - finalMessage += extraOutput; - } - } else { - push({ payload: chunk.content, type: 'content' }); - finalMessage += chunk.content; - } + } else { + push({ payload: msg, type: 'content' }); + finalMessage += msg; } - } else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) { - handleStreamEnd(finalMessage); } + } else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) { + handleStreamEnd(finalMessage); + } + }; + + // only process events that are part of the agent run + if ((event.tags || []).includes(AGENT_NODE_TAG)) { + if (event.name === 'ActionsClientChatOpenAI') { + if (isOssLlm) { + processSimpleChatModelEvent(); + } else { + processOpenAIEvent(); + } + } + if (event.name === 'ActionsClientSimpleChatModel') { + processSimpleChatModelEvent(); } } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 9b421d2d93ebc..64ced423e58f4 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -53,7 +53,9 @@ export const callAssistantGraph: AgentExecutor = async ({ responseLanguage = 'English', }) => { const logger = parentLogger.get('defaultAssistantGraph'); + const model = request.body.model; const isOpenAI = llmType === 'openai'; + const isOssLlm = isOpenAI && !!model && !model.startsWith('gpt-'); const llmClass = getLlmClass(llmType, bedrockChatEnabled); /** @@ -134,29 +136,30 @@ export const callAssistantGraph: AgentExecutor = async ({ } } - const agentRunnable = isOpenAI - ? await createOpenAIFunctionsAgent({ - llm: createLlmInstance(), - tools, - prompt: formatPrompt(systemPrompts.openai, systemPrompt), - streamRunnable: isStream, - }) - : llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled - ? await createToolCallingAgent({ - llm: createLlmInstance(), - tools, - prompt: - llmType === 'bedrock' - ? formatPrompt(systemPrompts.bedrock, systemPrompt) - : formatPrompt(systemPrompts.gemini, systemPrompt), - streamRunnable: isStream, - }) - : await createStructuredChatAgent({ - llm: createLlmInstance(), - tools, - prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt), - streamRunnable: isStream, - }); + const agentRunnable = + isOpenAI && !isOssLlm + ? await createOpenAIFunctionsAgent({ + llm: createLlmInstance(), + tools, + prompt: formatPrompt(systemPrompts.openai, systemPrompt), + streamRunnable: isStream, + }) + : llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled + ? await createToolCallingAgent({ + llm: createLlmInstance(), + tools, + prompt: + llmType === 'bedrock' + ? formatPrompt(systemPrompts.bedrock, systemPrompt) + : formatPrompt(systemPrompts.gemini, systemPrompt), + streamRunnable: isStream, + }) + : await createStructuredChatAgent({ + llm: createLlmInstance(), + tools, + prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt), + streamRunnable: isStream, + }); const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger); @@ -184,6 +187,7 @@ export const callAssistantGraph: AgentExecutor = async ({ assistantGraph, inputs, logger, + isOssLlm, onLlmResponse, request, traceOptions, From 4d5c841ac7ecb2f63d430c69dbe6ee64d8da432e Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Mon, 16 Sep 2024 18:11:17 +0200 Subject: [PATCH 02/17] OSS LLMs streaming fixes 1. Make sure that we check last appearence of `'"action_input": "'` 2. Make sure that we do not cut streaming in case expected `\` before `*` character was already added to the final message in previous round --- .../graphs/default_assistant_graph/helpers.ts | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 5485a02cd1fb2..3e654bca08a97 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -164,24 +164,30 @@ export const streamGraph = async ({ }; const processSimpleChatModelEvent = () => { - // console.log(`[TEST][ActionsClientSimpleChatModel] currentOutput: ${currentOutput}`); + // console.log(`[TEST] currentOutput: ${currentOutput}`); + // console.log(`[TEST] finalMessage: ${finalMessage}`); if (event.event === 'on_llm_stream') { const chunk = event.data?.chunk; - const msg = isOssLlm ? chunk.message.content : chunk.content; + let msg = isOssLlm ? chunk.message.content : chunk.content; if (finalOutputIndex === -1) { currentOutput += msg; // Remove whitespace to simplify parsing const noWhitespaceOutput = currentOutput.replace(/\s/g, ''); if (noWhitespaceOutput.includes(finalOutputStartToken)) { const nonStrippedToken = '"action_input": "'; - finalOutputIndex = currentOutput.indexOf(nonStrippedToken); + finalOutputIndex = currentOutput.lastIndexOf(nonStrippedToken); const contentStartIndex = finalOutputIndex + nonStrippedToken.length; extraOutput = currentOutput.substring(contentStartIndex); push({ payload: extraOutput, type: 'content' }); finalMessage += extraOutput; } } else if (!streamingFinished && !didEnd) { + if (msg.startsWith('"') && finalMessage.endsWith('\\')) { + // console.log(`[TEST] finalMessage: ${finalMessage}, msg: ${msg}`); + finalMessage = finalMessage.slice(0, -1); + msg = `\\${msg}`; + } const finalOutputEndIndex = msg.search(finalOutputStopRegex); if (finalOutputEndIndex !== -1) { extraOutput = msg.substring(0, finalOutputEndIndex); From 45c8ac06fba4cda2aaa73fd170f01b8459570d84 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Tue, 17 Sep 2024 14:58:51 +0200 Subject: [PATCH 03/17] Use api URL to verify OSS llms vs OpenAI --- .../server/lib/langchain/executors/types.ts | 1 + .../graphs/default_assistant_graph/helpers.ts | 5 ++ .../graphs/default_assistant_graph/index.ts | 55 ++++++++++--------- .../server/routes/chat/chat_complete_route.ts | 7 ++- .../server/routes/evaluate/post_evaluate.ts | 7 ++- .../server/routes/helpers.ts | 3 + .../routes/post_actions_connector_execute.ts | 6 ++ 7 files changed, 55 insertions(+), 29 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 2395221ea14b3..1f8dd143ce9ed 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -41,6 +41,7 @@ export interface AgentExecutorParams { bedrockChatEnabled: boolean; assistantTools?: AssistantTool[]; connectorId: string; + connectorApiUrl?: string; conversationId?: string; dataClients?: AssistantDataClients; esClient: ElasticsearchClient; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 3e654bca08a97..1b5bdadbc5d84 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -202,6 +202,11 @@ export const streamGraph = async ({ } } } else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) { + // Sometimes llama returns extra escape backslash characters which breaks the markdown. + // One of the solutions that I've found is to use `JSON.parse` to remove those. + // console.log(`[TEST] finalMessage 1: ${finalMessage}`); + // finalMessage = JSON.parse(`{"content":"${finalMessage}"}`).content; + // console.log(`[TEST] finalMessage 2: ${finalMessage}`); handleStreamEnd(finalMessage); } }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 64ced423e58f4..b4111ac48b33e 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -14,6 +14,7 @@ import { createToolCallingAgent, } from 'langchain/agents'; import { APMTracer } from '@kbn/langchain/server/tracers/apm'; +import { OPENAI_CHAT_URL } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { getLlmClass } from '../../../../routes/utils'; import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types'; import { AssistantToolParams } from '../../../../types'; @@ -32,6 +33,7 @@ export const callAssistantGraph: AgentExecutor = async ({ actionsClient, alertsIndexPattern, assistantTools = [], + connectorApiUrl, bedrockChatEnabled, connectorId, conversationId, @@ -53,9 +55,9 @@ export const callAssistantGraph: AgentExecutor = async ({ responseLanguage = 'English', }) => { const logger = parentLogger.get('defaultAssistantGraph'); - const model = request.body.model; - const isOpenAI = llmType === 'openai'; - const isOssLlm = isOpenAI && !!model && !model.startsWith('gpt-'); + const isOpeAIType = llmType === 'openai'; + const isOpenAI = isOpeAIType && (!connectorApiUrl || connectorApiUrl === OPENAI_CHAT_URL); + const isOssLlm = isOpeAIType && !isOpenAI; const llmClass = getLlmClass(llmType, bedrockChatEnabled); /** @@ -136,30 +138,29 @@ export const callAssistantGraph: AgentExecutor = async ({ } } - const agentRunnable = - isOpenAI && !isOssLlm - ? await createOpenAIFunctionsAgent({ - llm: createLlmInstance(), - tools, - prompt: formatPrompt(systemPrompts.openai, systemPrompt), - streamRunnable: isStream, - }) - : llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled - ? await createToolCallingAgent({ - llm: createLlmInstance(), - tools, - prompt: - llmType === 'bedrock' - ? formatPrompt(systemPrompts.bedrock, systemPrompt) - : formatPrompt(systemPrompts.gemini, systemPrompt), - streamRunnable: isStream, - }) - : await createStructuredChatAgent({ - llm: createLlmInstance(), - tools, - prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt), - streamRunnable: isStream, - }); + const agentRunnable = isOpenAI + ? await createOpenAIFunctionsAgent({ + llm: createLlmInstance(), + tools, + prompt: formatPrompt(systemPrompts.openai, systemPrompt), + streamRunnable: isStream, + }) + : llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled + ? await createToolCallingAgent({ + llm: createLlmInstance(), + tools, + prompt: + llmType === 'bedrock' + ? formatPrompt(systemPrompts.bedrock, systemPrompt) + : formatPrompt(systemPrompts.gemini, systemPrompt), + streamRunnable: isStream, + }) + : await createStructuredChatAgent({ + llm: createLlmInstance(), + tools, + prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt), + streamRunnable: isStream, + }); const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger); diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index dd90241809015..7e7e0462f6f38 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -99,7 +99,11 @@ export const chatCompleteRoute = ( const actions = ctx.elasticAssistant.actions; const actionsClient = await actions.getActionsClientWithRequest(request); const connectors = await actionsClient.getBulk({ ids: [connectorId] }); - actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai'; + const connector = connectors.length > 0 ? connectors[0] : undefined; + actionTypeId = connector?.actionTypeId ?? '.gen-ai'; + const connectorApiUrl = connector?.config?.apiUrl + ? (connector.config.apiUrl as string) + : undefined; // replacements const anonymizationFieldsRes = @@ -192,6 +196,7 @@ export const chatCompleteRoute = ( actionsClient, actionTypeId, connectorId, + connectorApiUrl, conversationId: conversationId ?? newConversation?.id, context: ctx, getElser, diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index 090dfa2acf5f0..38e1b442f1c30 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -30,6 +30,7 @@ import { createToolCallingAgent, } from 'langchain/agents'; import { RetrievalQAChain } from 'langchain/chains'; +import { OPENAI_CHAT_URL } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { buildResponse } from '../../lib/build_response'; import { AssistantDataClients } from '../../lib/langchain/executors/types'; import { AssistantToolParams, ElasticAssistantRequestHandlerContext, GetElser } from '../../types'; @@ -195,7 +196,11 @@ export const postEvaluateRoute = ( }> = await Promise.all( connectors.map(async (connector) => { const llmType = getLlmType(connector.actionTypeId); - const isOpenAI = llmType === 'openai'; + const connectorApiUrl = connector?.config?.apiUrl + ? (connector.config.apiUrl as string) + : undefined; + const isOpenAI = + llmType === 'openai' && (!connectorApiUrl || connectorApiUrl === OPENAI_CHAT_URL); const llmClass = getLlmClass(llmType, true); const createLlmInstance = () => new llmClass({ diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 860c6882a6b27..b717250f062f7 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -323,6 +323,7 @@ export interface LangChainExecuteParams { actionTypeId: string; connectorId: string; inference: InferenceServerStart; + connectorApiUrl?: string; conversationId?: string; context: AwaitedProperties< Pick @@ -349,6 +350,7 @@ export const langChainExecute = async ({ telemetry, actionTypeId, connectorId, + connectorApiUrl, context, actionsClient, inference, @@ -420,6 +422,7 @@ export const langChainExecute = async ({ assistantTools, conversationId, connectorId, + connectorApiUrl, esClient, esStore, inference, diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 736d60ff666b0..e59a479f4031a 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -94,6 +94,11 @@ export const postActionsConnectorExecuteRoute = ( const actions = ctx.elasticAssistant.actions; const inference = ctx.elasticAssistant.inference; const actionsClient = await actions.getActionsClientWithRequest(request); + const connectors = await actionsClient.getBulk({ ids: [connectorId] }); + const connector = connectors.length > 0 ? connectors[0] : undefined; + const connectorApiUrl = connector?.config?.apiUrl + ? (connector.config.apiUrl as string) + : undefined; const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient(); @@ -129,6 +134,7 @@ export const postActionsConnectorExecuteRoute = ( actionsClient, actionTypeId, connectorId, + connectorApiUrl, conversationId, context: ctx, getElser, From 98abd535c605d912420b31ac65b3a7e4933099c9 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Thu, 19 Sep 2024 11:23:22 +0200 Subject: [PATCH 04/17] Prompting --- .../graphs/default_assistant_graph/prompts.ts | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts index eb52c227421fc..e0abcf5ca31c1 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts @@ -20,11 +20,38 @@ export const formatPrompt = (prompt: string, additionalPrompt?: string) => ['placeholder', '{agent_scratchpad}'], ]); +const PROMPT_1 = `ALWAYS use the provided tools, as they have access to the latest data and syntax. + +Always return value from ESQLKnowledgeBaseTool as is. Do not reflect on the quality of the returned search results in your response.`; +const PROMPT_2 = `${BEDROCK_SYSTEM_PROMPT}`; +const PROMPT_2_1 = `Use tools as often as possible, as they have access to the latest data and syntax. The result returned from ESQLKnowledgeBaseTool is a string which should not be modified and should ALWAYS be returned as is. Do not reflect on the quality of the returned search results in your response.`; +const PROMPT_3 = `${GEMINI_SYSTEM_PROMPT}`; +const PROMPT_4_0 = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Do not reflect on the quality of the returned search results in your response. Final ES|QL query should always be wrapped in tripple backticks and be put on a new line.`; +const PROMPT_4_1 = `ALWAYS use the provided tools, as they have access to the latest data and syntax. ALWAYS pass the whole user input to ESQLKnowledgeBaseTool. ALWAYS return value from ESQLKnowledgeBaseTool as is.`; +const PROMPT_5 = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is and use it as a final answer without modifying it. Do not reflect on the quality of the returned search results in your response.`; +const PROMPT_6 = ` +Use tools as often as possible, as they have access to the latest data and syntax. + +When using ESQLKnowledgeBaseTool pass the user's questions directly as input into the tool. + +Always return value from ESQLKnowledgeBaseTool as is. + +The ES|QL query should always be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks. + +It is important that ES|QL query is preceeded by a new line.`; + +// export const GEMINI_SYSTEM_PROMPT = +// `ALWAYS use the provided tools, as they have access to the latest data and syntax.` + +// "The final response is the only output the user sees and should be a complete answer to the user's question. Do not leave out important tool output. The final response should never be empty. Don't forget to use tools."; +// export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return tags in the response, but make sure to include tags content in the response. Do not reflect on the quality of the returned search results in your response.`; + export const systemPrompts = { openai: DEFAULT_SYSTEM_PROMPT, bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`, gemini: `${DEFAULT_SYSTEM_PROMPT} ${GEMINI_SYSTEM_PROMPT}`, - structuredChat: `Respond to the human as helpfully and accurately as possible. You have access to the following tools: + structuredChat: `${DEFAULT_SYSTEM_PROMPT} + +Respond to the human as helpfully and accurately as possible. You have access to the following tools: {tools} @@ -78,7 +105,9 @@ Action: "action_input": "Final response to human"}} -Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`, +Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation + +${PROMPT_6}`, }; export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai); From 7a7f7e20af66350163487a04b86af9e9ed9e52f8 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Thu, 19 Sep 2024 16:48:58 +0200 Subject: [PATCH 05/17] Use provider type to better identify OSS model - handles case with the Azure AI --- .../server/lib/langchain/executors/types.ts | 2 + .../graphs/default_assistant_graph/helpers.ts | 8 -- .../graphs/default_assistant_graph/index.ts | 16 +++- .../nodes/translations.ts | 67 ++++++++++++++ .../graphs/default_assistant_graph/prompts.ts | 90 ++----------------- .../server/routes/chat/chat_complete_route.ts | 5 ++ .../server/routes/evaluate/post_evaluate.ts | 18 +++- .../server/routes/helpers.ts | 4 + .../routes/post_actions_connector_execute.ts | 5 ++ 9 files changed, 117 insertions(+), 98 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 1f8dd143ce9ed..949f7927b1d68 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -15,6 +15,7 @@ import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { PublicMethodsOf } from '@kbn/utility-types'; import type { InferenceServerStart } from '@kbn/inference-plugin/server'; +import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { ResponseBody } from '../types'; import type { AssistantTool } from '../../../types'; import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; @@ -42,6 +43,7 @@ export interface AgentExecutorParams { assistantTools?: AssistantTool[]; connectorId: string; connectorApiUrl?: string; + connectorApiProvider?: OpenAiProviderType; conversationId?: string; dataClients?: AssistantDataClients; esClient: ElasticsearchClient; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 1b5bdadbc5d84..a012420b38c44 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -164,8 +164,6 @@ export const streamGraph = async ({ }; const processSimpleChatModelEvent = () => { - // console.log(`[TEST] currentOutput: ${currentOutput}`); - // console.log(`[TEST] finalMessage: ${finalMessage}`); if (event.event === 'on_llm_stream') { const chunk = event.data?.chunk; @@ -184,7 +182,6 @@ export const streamGraph = async ({ } } else if (!streamingFinished && !didEnd) { if (msg.startsWith('"') && finalMessage.endsWith('\\')) { - // console.log(`[TEST] finalMessage: ${finalMessage}, msg: ${msg}`); finalMessage = finalMessage.slice(0, -1); msg = `\\${msg}`; } @@ -202,11 +199,6 @@ export const streamGraph = async ({ } } } else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) { - // Sometimes llama returns extra escape backslash characters which breaks the markdown. - // One of the solutions that I've found is to use `JSON.parse` to remove those. - // console.log(`[TEST] finalMessage 1: ${finalMessage}`); - // finalMessage = JSON.parse(`{"content":"${finalMessage}"}`).content; - // console.log(`[TEST] finalMessage 2: ${finalMessage}`); handleStreamEnd(finalMessage); } }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index b4111ac48b33e..e96cbb4440730 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -14,7 +14,10 @@ import { createToolCallingAgent, } from 'langchain/agents'; import { APMTracer } from '@kbn/langchain/server/tracers/apm'; -import { OPENAI_CHAT_URL } from '@kbn/stack-connectors-plugin/common/openai/constants'; +import { + OPENAI_CHAT_URL, + OpenAiProviderType, +} from '@kbn/stack-connectors-plugin/common/openai/constants'; import { getLlmClass } from '../../../../routes/utils'; import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types'; import { AssistantToolParams } from '../../../../types'; @@ -34,6 +37,7 @@ export const callAssistantGraph: AgentExecutor = async ({ alertsIndexPattern, assistantTools = [], connectorApiUrl, + connectorApiProvider, bedrockChatEnabled, connectorId, conversationId, @@ -56,7 +60,11 @@ export const callAssistantGraph: AgentExecutor = async ({ }) => { const logger = parentLogger.get('defaultAssistantGraph'); const isOpeAIType = llmType === 'openai'; - const isOpenAI = isOpeAIType && (!connectorApiUrl || connectorApiUrl === OPENAI_CHAT_URL); + const isOpenAI = + isOpeAIType && + (!connectorApiUrl || + connectorApiUrl === OPENAI_CHAT_URL || + connectorApiProvider === OpenAiProviderType.AzureAi); const isOssLlm = isOpeAIType && !isOpenAI; const llmClass = getLlmClass(llmType, bedrockChatEnabled); @@ -158,7 +166,9 @@ export const callAssistantGraph: AgentExecutor = async ({ : await createStructuredChatAgent({ llm: createLlmInstance(), tools, - prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt), + prompt: isOssLlm + ? formatPromptStructured(systemPrompts.ossLlm, systemPrompt) + : formatPromptStructured(systemPrompts.structuredChat, systemPrompt), streamRunnable: isStream, }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts index ae8e3c18c2217..73479812ba064 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts @@ -15,3 +15,70 @@ export const GEMINI_SYSTEM_PROMPT = `ALWAYS use the provided tools, as they have access to the latest data and syntax.` + "The final response is the only output the user sees and should be a complete answer to the user's question. Do not leave out important tool output. The final response should never be empty. Don't forget to use tools."; export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return tags in the response, but make sure to include tags content in the response. Do not reflect on the quality of the returned search results in your response.`; + +export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. You have access to the following tools: + +{tools} + +The tool action_input should ALWAYS follow the tool JSON schema args. + +Valid "action" values: "Final Answer" or {tool_names} + +Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args). + +Provide only ONE action per $JSON_BLOB, as shown: + +\`\`\` + +{{ + + "action": $TOOL_NAME, + + "action_input": $TOOL_INPUT + +}} + +\`\`\` + +Follow this format: + +Question: input question to answer + +Thought: consider previous and subsequent steps + +Action: + +\`\`\` + +$JSON_BLOB + +\`\`\` + +Observation: action result + +... (repeat Thought/Action/Observation N times) + +Thought: I know what to respond + +Action: + +\`\`\` + +{{ + + "action": "Final Answer", + + "action_input": "Final response to human"}} + +Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`; + +export const OSS_SYSTEM_PROMPT = ` +Use tools as often as possible, as they have access to the latest data and syntax. + +When using ESQLKnowledgeBaseTool pass the user's questions directly as input into the tool. + +Always return value from ESQLKnowledgeBaseTool as is. + +The ES|QL query should always be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks. + +It is important that ES|QL query is preceeded by a new line.`; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts index e0abcf5ca31c1..19d27cb5e82de 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts @@ -10,6 +10,8 @@ import { BEDROCK_SYSTEM_PROMPT, DEFAULT_SYSTEM_PROMPT, GEMINI_SYSTEM_PROMPT, + OSS_SYSTEM_PROMPT, + STRUCTURED_SYSTEM_PROMPT, } from './nodes/translations'; export const formatPrompt = (prompt: string, additionalPrompt?: string) => @@ -20,94 +22,12 @@ export const formatPrompt = (prompt: string, additionalPrompt?: string) => ['placeholder', '{agent_scratchpad}'], ]); -const PROMPT_1 = `ALWAYS use the provided tools, as they have access to the latest data and syntax. - -Always return value from ESQLKnowledgeBaseTool as is. Do not reflect on the quality of the returned search results in your response.`; -const PROMPT_2 = `${BEDROCK_SYSTEM_PROMPT}`; -const PROMPT_2_1 = `Use tools as often as possible, as they have access to the latest data and syntax. The result returned from ESQLKnowledgeBaseTool is a string which should not be modified and should ALWAYS be returned as is. Do not reflect on the quality of the returned search results in your response.`; -const PROMPT_3 = `${GEMINI_SYSTEM_PROMPT}`; -const PROMPT_4_0 = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Do not reflect on the quality of the returned search results in your response. Final ES|QL query should always be wrapped in tripple backticks and be put on a new line.`; -const PROMPT_4_1 = `ALWAYS use the provided tools, as they have access to the latest data and syntax. ALWAYS pass the whole user input to ESQLKnowledgeBaseTool. ALWAYS return value from ESQLKnowledgeBaseTool as is.`; -const PROMPT_5 = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is and use it as a final answer without modifying it. Do not reflect on the quality of the returned search results in your response.`; -const PROMPT_6 = ` -Use tools as often as possible, as they have access to the latest data and syntax. - -When using ESQLKnowledgeBaseTool pass the user's questions directly as input into the tool. - -Always return value from ESQLKnowledgeBaseTool as is. - -The ES|QL query should always be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks. - -It is important that ES|QL query is preceeded by a new line.`; - -// export const GEMINI_SYSTEM_PROMPT = -// `ALWAYS use the provided tools, as they have access to the latest data and syntax.` + -// "The final response is the only output the user sees and should be a complete answer to the user's question. Do not leave out important tool output. The final response should never be empty. Don't forget to use tools."; -// export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return tags in the response, but make sure to include tags content in the response. Do not reflect on the quality of the returned search results in your response.`; - export const systemPrompts = { openai: DEFAULT_SYSTEM_PROMPT, bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`, gemini: `${DEFAULT_SYSTEM_PROMPT} ${GEMINI_SYSTEM_PROMPT}`, - structuredChat: `${DEFAULT_SYSTEM_PROMPT} - -Respond to the human as helpfully and accurately as possible. You have access to the following tools: - -{tools} - -The tool action_input should ALWAYS follow the tool JSON schema args. - -Valid "action" values: "Final Answer" or {tool_names} - -Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args). - -Provide only ONE action per $JSON_BLOB, as shown: - -\`\`\` - -{{ - - "action": $TOOL_NAME, - - "action_input": $TOOL_INPUT - -}} - -\`\`\` - -Follow this format: - -Question: input question to answer - -Thought: consider previous and subsequent steps - -Action: - -\`\`\` - -$JSON_BLOB - -\`\`\` - -Observation: action result - -... (repeat Thought/Action/Observation N times) - -Thought: I know what to respond - -Action: - -\`\`\` - -{{ - - "action": "Final Answer", - - "action_input": "Final response to human"}} - -Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation - -${PROMPT_6}`, + ossLlm: `${DEFAULT_SYSTEM_PROMPT} ${STRUCTURED_SYSTEM_PROMPT} ${OSS_SYSTEM_PROMPT}`, + structuredChat: STRUCTURED_SYSTEM_PROMPT, }; export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai); @@ -127,3 +47,5 @@ export const formatPromptStructured = (prompt: string, additionalPrompt?: string ]); export const structuredChatAgentPrompt = formatPromptStructured(systemPrompts.structuredChat); + +export const ossLlmStructuredChatAgentPrompt = formatPromptStructured(systemPrompts.ossLlm); diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index 7e7e0462f6f38..d9f04dbd9e6d7 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -19,6 +19,7 @@ import { } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; +import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { INVOKE_ASSISTANT_ERROR_EVENT } from '../../lib/telemetry/event_based_telemetry'; import { ElasticAssistantPluginRouter, GetElser } from '../../types'; import { buildResponse } from '../../lib/build_response'; @@ -104,6 +105,9 @@ export const chatCompleteRoute = ( const connectorApiUrl = connector?.config?.apiUrl ? (connector.config.apiUrl as string) : undefined; + const connectorApiProvider = connector?.config?.apiProvider + ? (connector?.config?.apiProvider as OpenAiProviderType) + : undefined; // replacements const anonymizationFieldsRes = @@ -197,6 +201,7 @@ export const chatCompleteRoute = ( actionTypeId, connectorId, connectorApiUrl, + connectorApiProvider, conversationId: conversationId ?? newConversation?.id, context: ctx, getElser, diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index 38e1b442f1c30..7a51fae63301b 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -30,7 +30,10 @@ import { createToolCallingAgent, } from 'langchain/agents'; import { RetrievalQAChain } from 'langchain/chains'; -import { OPENAI_CHAT_URL } from '@kbn/stack-connectors-plugin/common/openai/constants'; +import { + OPENAI_CHAT_URL, + OpenAiProviderType, +} from '@kbn/stack-connectors-plugin/common/openai/constants'; import { buildResponse } from '../../lib/build_response'; import { AssistantDataClients } from '../../lib/langchain/executors/types'; import { AssistantToolParams, ElasticAssistantRequestHandlerContext, GetElser } from '../../types'; @@ -48,6 +51,7 @@ import { bedrockToolCallingAgentPrompt, geminiToolCallingAgentPrompt, openAIFunctionAgentPrompt, + ossLlmStructuredChatAgentPrompt, structuredChatAgentPrompt, } from '../../lib/langchain/graphs/default_assistant_graph/prompts'; import { getLlmClass, getLlmType } from '../utils'; @@ -199,8 +203,16 @@ export const postEvaluateRoute = ( const connectorApiUrl = connector?.config?.apiUrl ? (connector.config.apiUrl as string) : undefined; + const connectorApiProvider = connector?.config?.apiProvider + ? (connector?.config?.apiProvider as OpenAiProviderType) + : undefined; + const isOpeAIType = llmType === 'openai'; const isOpenAI = - llmType === 'openai' && (!connectorApiUrl || connectorApiUrl === OPENAI_CHAT_URL); + isOpeAIType && + (!connectorApiUrl || + connectorApiUrl === OPENAI_CHAT_URL || + connectorApiProvider === OpenAiProviderType.AzureAi); + const isOssLlm = isOpeAIType && !isOpenAI; const llmClass = getLlmClass(llmType, true); const createLlmInstance = () => new llmClass({ @@ -294,7 +306,7 @@ export const postEvaluateRoute = ( : await createStructuredChatAgent({ llm, tools, - prompt: structuredChatAgentPrompt, + prompt: isOssLlm ? ossLlmStructuredChatAgentPrompt : structuredChatAgentPrompt, streamRunnable: false, }); diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index b717250f062f7..334b05e0bd115 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -29,6 +29,7 @@ import { ActionsClient } from '@kbn/actions-plugin/server'; import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; import type { InferenceServerStart } from '@kbn/inference-plugin/server'; +import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { AIAssistantKnowledgeBaseDataClient } from '../ai_assistant_data_clients/knowledge_base'; import { FindResponse } from '../ai_assistant_data_clients/find'; import { EsPromptsSchema } from '../ai_assistant_data_clients/prompts/types'; @@ -324,6 +325,7 @@ export interface LangChainExecuteParams { connectorId: string; inference: InferenceServerStart; connectorApiUrl?: string; + connectorApiProvider?: OpenAiProviderType; conversationId?: string; context: AwaitedProperties< Pick @@ -351,6 +353,7 @@ export const langChainExecute = async ({ actionTypeId, connectorId, connectorApiUrl, + connectorApiProvider, context, actionsClient, inference, @@ -423,6 +426,7 @@ export const langChainExecute = async ({ conversationId, connectorId, connectorApiUrl, + connectorApiProvider, esClient, esStore, inference, diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index e59a479f4031a..fc1beef2e45ad 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -17,6 +17,7 @@ import { Replacements, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; +import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { INVOKE_ASSISTANT_ERROR_EVENT } from '../lib/telemetry/event_based_telemetry'; import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants'; import { buildResponse } from '../lib/build_response'; @@ -99,6 +100,9 @@ export const postActionsConnectorExecuteRoute = ( const connectorApiUrl = connector?.config?.apiUrl ? (connector.config.apiUrl as string) : undefined; + const connectorApiProvider = connector?.config?.apiProvider + ? (connector?.config?.apiProvider as OpenAiProviderType) + : undefined; const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient(); @@ -135,6 +139,7 @@ export const postActionsConnectorExecuteRoute = ( actionTypeId, connectorId, connectorApiUrl, + connectorApiProvider, conversationId, context: ctx, getElser, From 69039092cce8d7915597a8be86af61be324a0497 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Tue, 24 Sep 2024 12:55:26 +0200 Subject: [PATCH 06/17] Enable `NaturalLanguageESQLTool` for OSS models like Llama --- .../graphs/default_assistant_graph/helpers.ts | 10 +++++----- .../langchain/graphs/default_assistant_graph/index.ts | 8 ++++---- .../server/routes/evaluate/post_evaluate.ts | 9 ++++++--- x-pack/plugins/elastic_assistant/server/types.ts | 1 + .../esql_language_knowledge_base/nl_to_esql_tool.ts | 3 ++- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index a012420b38c44..7a0a58e93502b 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -24,7 +24,7 @@ interface StreamGraphParams { assistantGraph: DefaultAssistantGraph; inputs: GraphInputs; logger: Logger; - isOssLlm?: boolean; + isOssModel?: boolean; onLlmResponse?: OnLlmResponse; request: KibanaRequest; traceOptions?: TraceOptions; @@ -37,7 +37,7 @@ interface StreamGraphParams { * @param assistantGraph * @param inputs * @param logger - * @param isOssLlm + * @param isOssModel * @param onLlmResponse * @param request * @param traceOptions @@ -48,7 +48,7 @@ export const streamGraph = async ({ assistantGraph, inputs, logger, - isOssLlm, + isOssModel, onLlmResponse, request, traceOptions, @@ -167,7 +167,7 @@ export const streamGraph = async ({ if (event.event === 'on_llm_stream') { const chunk = event.data?.chunk; - let msg = isOssLlm ? chunk.message.content : chunk.content; + let msg = isOssModel ? chunk.message.content : chunk.content; if (finalOutputIndex === -1) { currentOutput += msg; // Remove whitespace to simplify parsing @@ -206,7 +206,7 @@ export const streamGraph = async ({ // only process events that are part of the agent run if ((event.tags || []).includes(AGENT_NODE_TAG)) { if (event.name === 'ActionsClientChatOpenAI') { - if (isOssLlm) { + if (isOssModel) { processSimpleChatModelEvent(); } else { processOpenAIEvent(); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index e96cbb4440730..c434f264c8eb9 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -65,7 +65,7 @@ export const callAssistantGraph: AgentExecutor = async ({ (!connectorApiUrl || connectorApiUrl === OPENAI_CHAT_URL || connectorApiProvider === OpenAiProviderType.AzureAi); - const isOssLlm = isOpeAIType && !isOpenAI; + const isOssModel = isOpeAIType && !isOpenAI; const llmClass = getLlmClass(llmType, bedrockChatEnabled); /** @@ -132,7 +132,7 @@ export const callAssistantGraph: AgentExecutor = async ({ }; const tools: StructuredTool[] = assistantTools.flatMap( - (tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? [] + (tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance(), isOssModel }) ?? [] ); // If KB enabled, fetch for any KB IndexEntries and generate a tool for each @@ -166,7 +166,7 @@ export const callAssistantGraph: AgentExecutor = async ({ : await createStructuredChatAgent({ llm: createLlmInstance(), tools, - prompt: isOssLlm + prompt: isOssModel ? formatPromptStructured(systemPrompts.ossLlm, systemPrompt) : formatPromptStructured(systemPrompts.structuredChat, systemPrompt), streamRunnable: isStream, @@ -198,7 +198,7 @@ export const callAssistantGraph: AgentExecutor = async ({ assistantGraph, inputs, logger, - isOssLlm, + isOssModel, onLlmResponse, request, traceOptions, diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index 7a51fae63301b..36d9bfaae69be 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -212,7 +212,7 @@ export const postEvaluateRoute = ( (!connectorApiUrl || connectorApiUrl === OPENAI_CHAT_URL || connectorApiProvider === OpenAiProviderType.AzureAi); - const isOssLlm = isOpeAIType && !isOpenAI; + const isOssModel = isOpeAIType && !isOpenAI; const llmClass = getLlmClass(llmType, true); const createLlmInstance = () => new llmClass({ @@ -271,6 +271,7 @@ export const postEvaluateRoute = ( isEnabledKnowledgeBase, kbDataClient: dataClients?.kbDataClient, llm, + isOssModel, logger, modelExists: isEnabledKnowledgeBase, request: skeletonRequest, @@ -306,7 +307,9 @@ export const postEvaluateRoute = ( : await createStructuredChatAgent({ llm, tools, - prompt: isOssLlm ? ossLlmStructuredChatAgentPrompt : structuredChatAgentPrompt, + prompt: isOssModel + ? ossLlmStructuredChatAgentPrompt + : structuredChatAgentPrompt, streamRunnable: false, }); @@ -349,7 +352,7 @@ export const postEvaluateRoute = ( return output; }; - const evalOutput = await evaluate(predict, { + const evalOutput = evaluate(predict, { data: datasetName ?? '', evaluators: [], // Evals to be managed in LangSmith for now experimentPrefix: name, diff --git a/x-pack/plugins/elastic_assistant/server/types.ts b/x-pack/plugins/elastic_assistant/server/types.ts index e685c1d4e9358..8b4bdac2d475c 100755 --- a/x-pack/plugins/elastic_assistant/server/types.ts +++ b/x-pack/plugins/elastic_assistant/server/types.ts @@ -244,6 +244,7 @@ export interface AssistantToolParams { kbDataClient?: AIAssistantKnowledgeBaseDataClient; langChainTimeout?: number; llm?: ActionsClientLlm | AssistantToolLlm; + isOssModel?: boolean; logger: Logger; modelExists: boolean; onNewReplacements?: (newReplacements: Replacements) => void; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts index b5dc209043d5d..11fc466be1abb 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts @@ -42,7 +42,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = { getTool(params: ESQLToolParams) { if (!this.isSupported(params)) return null; - const { connectorId, inference, logger, request } = params as ESQLToolParams; + const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams; if (inference == null || connectorId == null) return null; const callNaturalLanguageToEsql = async (question: string) => { @@ -51,6 +51,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = { client: inference.getClient({ request }), connectorId, input: question, + ...(isOssModel ? { functionCalling: 'simulated' } : {}), logger: { debug: (source) => { logger.debug(typeof source === 'function' ? source() : source); From 7436f346a98a28c780e1b6471b25b2399dfa17a9 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Tue, 24 Sep 2024 13:29:49 +0200 Subject: [PATCH 07/17] Fix the issue with extra escape backslash characters which breaks the markdown --- .../graphs/default_assistant_graph/helpers.ts | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 7a0a58e93502b..9428ed40da5ec 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -167,7 +167,7 @@ export const streamGraph = async ({ if (event.event === 'on_llm_stream') { const chunk = event.data?.chunk; - let msg = isOssModel ? chunk.message.content : chunk.content; + const msg = isOssModel ? chunk.message.content : chunk.content; if (finalOutputIndex === -1) { currentOutput += msg; // Remove whitespace to simplify parsing @@ -182,8 +182,9 @@ export const streamGraph = async ({ } } else if (!streamingFinished && !didEnd) { if (msg.startsWith('"') && finalMessage.endsWith('\\')) { - finalMessage = finalMessage.slice(0, -1); - msg = `\\${msg}`; + push({ payload: msg, type: 'content' }); + finalMessage += msg; + return; } const finalOutputEndIndex = msg.search(finalOutputStopRegex); if (finalOutputEndIndex !== -1) { @@ -199,6 +200,11 @@ export const streamGraph = async ({ } } } else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) { + // Sometimes llama returns extra escape backslash characters which breaks the markdown. + // One of the solutions that I've found is to use `JSON.parse` to remove those. + // console.log(`[TEST] finalMessage 1: ${finalMessage}`); + finalMessage = JSON.parse(`{"finalMessage":"${finalMessage}"}`).finalMessage; + // console.log(`[TEST] finalMessage 2: ${finalMessage}`); handleStreamEnd(finalMessage); } }; From ae638e697039fdbeef44e252cf986df55bbc2bdd Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Tue, 24 Sep 2024 17:48:22 +0200 Subject: [PATCH 08/17] Revert streaming events parsing --- .../graphs/default_assistant_graph/helpers.ts | 117 +++++++----------- 1 file changed, 47 insertions(+), 70 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 9428ed40da5ec..7b35f35269e84 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -142,84 +142,61 @@ export const streamGraph = async ({ if (done) return; const event = value; - - const processOpenAIEvent = () => { - if (event.event === 'on_llm_stream') { - const chunk = event.data?.chunk; - const msg = chunk.message; - if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) { - // I don't think we hit this anymore because of our check for AGENT_NODE_TAG - // however, no harm to keep it in - /* empty */ - } else if (!didEnd) { - push({ payload: msg.content, type: 'content' }); - finalMessage += msg.content; - } - } else if (event.event === 'on_llm_end' && !didEnd) { - const generations = event.data.output?.generations[0]; - if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { - handleStreamEnd(generations[0]?.text ?? finalMessage); + // only process events that are part of the agent run + if ((event.tags || []).includes(AGENT_NODE_TAG)) { + if (event.name === 'ActionsClientChatOpenAI') { + if (event.event === 'on_llm_stream') { + const chunk = event.data?.chunk; + const msg = chunk.message; + if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) { + // I don't think we hit this anymore because of our check for AGENT_NODE_TAG + // however, no harm to keep it in + /* empty */ + } else if (!didEnd) { + push({ payload: msg.content, type: 'content' }); + finalMessage += msg.content; + } + } else if (event.event === 'on_llm_end' && !didEnd) { + const generations = event.data.output?.generations[0]; + if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { + handleStreamEnd(generations[0]?.text ?? finalMessage); + } } } - }; - - const processSimpleChatModelEvent = () => { - if (event.event === 'on_llm_stream') { - const chunk = event.data?.chunk; + if (event.name === 'ActionsClientSimpleChatModel') { + if (event.event === 'on_llm_stream') { + const chunk = event.data?.chunk; - const msg = isOssModel ? chunk.message.content : chunk.content; - if (finalOutputIndex === -1) { - currentOutput += msg; - // Remove whitespace to simplify parsing - const noWhitespaceOutput = currentOutput.replace(/\s/g, ''); - if (noWhitespaceOutput.includes(finalOutputStartToken)) { - const nonStrippedToken = '"action_input": "'; - finalOutputIndex = currentOutput.lastIndexOf(nonStrippedToken); - const contentStartIndex = finalOutputIndex + nonStrippedToken.length; - extraOutput = currentOutput.substring(contentStartIndex); - push({ payload: extraOutput, type: 'content' }); - finalMessage += extraOutput; - } - } else if (!streamingFinished && !didEnd) { - if (msg.startsWith('"') && finalMessage.endsWith('\\')) { - push({ payload: msg, type: 'content' }); - finalMessage += msg; - return; - } - const finalOutputEndIndex = msg.search(finalOutputStopRegex); - if (finalOutputEndIndex !== -1) { - extraOutput = msg.substring(0, finalOutputEndIndex); - streamingFinished = true; - if (extraOutput.length > 0) { + const msg = chunk.content; + if (finalOutputIndex === -1) { + currentOutput += msg; + // Remove whitespace to simplify parsing + const noWhitespaceOutput = currentOutput.replace(/\s/g, ''); + if (noWhitespaceOutput.includes(finalOutputStartToken)) { + const nonStrippedToken = '"action_input": "'; + finalOutputIndex = currentOutput.indexOf(nonStrippedToken); + const contentStartIndex = finalOutputIndex + nonStrippedToken.length; + extraOutput = currentOutput.substring(contentStartIndex); push({ payload: extraOutput, type: 'content' }); finalMessage += extraOutput; } - } else { - push({ payload: msg, type: 'content' }); - finalMessage += msg; + } else if (!streamingFinished && !didEnd) { + const finalOutputEndIndex = msg.search(finalOutputStopRegex); + if (finalOutputEndIndex !== -1) { + extraOutput = msg.substring(0, finalOutputEndIndex); + streamingFinished = true; + if (extraOutput.length > 0) { + push({ payload: extraOutput, type: 'content' }); + finalMessage += extraOutput; + } + } else { + push({ payload: chunk.content, type: 'content' }); + finalMessage += chunk.content; + } } + } else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) { + handleStreamEnd(finalMessage); } - } else if (event.event === 'on_llm_end' && streamingFinished && !didEnd) { - // Sometimes llama returns extra escape backslash characters which breaks the markdown. - // One of the solutions that I've found is to use `JSON.parse` to remove those. - // console.log(`[TEST] finalMessage 1: ${finalMessage}`); - finalMessage = JSON.parse(`{"finalMessage":"${finalMessage}"}`).finalMessage; - // console.log(`[TEST] finalMessage 2: ${finalMessage}`); - handleStreamEnd(finalMessage); - } - }; - - // only process events that are part of the agent run - if ((event.tags || []).includes(AGENT_NODE_TAG)) { - if (event.name === 'ActionsClientChatOpenAI') { - if (isOssModel) { - processSimpleChatModelEvent(); - } else { - processOpenAIEvent(); - } - } - if (event.name === 'ActionsClientSimpleChatModel') { - processSimpleChatModelEvent(); } } From a03e5ea3e4ad191d04cd7cf4f336ebff8d0fb3a2 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Tue, 24 Sep 2024 18:09:18 +0200 Subject: [PATCH 09/17] Simplified OSS model streaming --- .../lib/langchain/graphs/default_assistant_graph/graph.ts | 4 ++++ .../langchain/graphs/default_assistant_graph/helpers.ts | 8 +++++--- .../lib/langchain/graphs/default_assistant_graph/index.ts | 1 + .../graphs/default_assistant_graph/nodes/model_input.ts | 4 +++- .../lib/langchain/graphs/default_assistant_graph/types.ts | 2 ++ .../server/routes/evaluate/post_evaluate.ts | 5 ++++- 6 files changed, 19 insertions(+), 5 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 8395076ad62ee..8f2f713c170ed 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -94,6 +94,10 @@ export const getDefaultAssistantGraph = ({ value: (x: boolean, y?: boolean) => y ?? x, default: () => false, }, + isOssModel: { + value: (x: boolean, y?: boolean) => y ?? x, + default: () => false, + }, conversation: { value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) => y ?? x, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 7b35f35269e84..0545e7083abcb 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -84,8 +84,8 @@ export const streamGraph = async ({ }; if ( - (inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && - inputs?.bedrockChatEnabled + inputs.isOssModel || + ((inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && inputs?.bedrockChatEnabled) ) { const stream = await assistantGraph.streamEvents( inputs, @@ -96,7 +96,9 @@ export const streamGraph = async ({ version: 'v2', streamMode: 'values', }, - inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined + inputs.isOssModel || inputs?.llmType === 'bedrock' + ? { includeNames: ['Summarizer'] } + : undefined ); for await (const { event, data, tags } of stream) { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index c434f264c8eb9..f62d73f77b724 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -189,6 +189,7 @@ export const callAssistantGraph: AgentExecutor = async ({ conversationId, llmType, isStream, + isOssModel, input: latestMessage[0]?.content as string, }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/model_input.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/model_input.ts index f634d10f5cd4a..5f46e1ad2a741 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/model_input.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/model_input.ts @@ -22,7 +22,9 @@ interface ModelInputParams extends NodeParamsBase { export function modelInput({ logger, state }: ModelInputParams): Partial { logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`); - const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock'; + const hasRespondStep = + state.isStream && + (state.isOssModel || (state.bedrockChatEnabled && state.llmType === 'bedrock')); return { hasRespondStep, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts index 17d06b0f7042e..69632be2ffdcd 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts @@ -20,6 +20,7 @@ export interface GraphInputs { conversationId?: string; llmType?: string; isStream?: boolean; + isOssModel?: boolean; input: string; responseLanguage?: string; } @@ -31,6 +32,7 @@ export interface AgentState extends AgentStateBase { lastNode: string; hasRespondStep: boolean; isStream: boolean; + isOssModel: boolean; bedrockChatEnabled: boolean; llmType: string; responseLanguage: string; diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index 36d9bfaae69be..d97cd79401faa 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -197,6 +197,7 @@ export const postEvaluateRoute = ( name: string; graph: DefaultAssistantGraph; llmType: string | undefined; + isOssModel: boolean | undefined; }> = await Promise.all( connectors.map(async (connector) => { const llmType = getLlmType(connector.actionTypeId); @@ -316,6 +317,7 @@ export const postEvaluateRoute = ( return { name: `${runName} - ${connector.name}`, llmType, + isOssModel, graph: getDefaultAssistantGraph({ agentRunnable, dataClients, @@ -329,7 +331,7 @@ export const postEvaluateRoute = ( ); // Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector) - await asyncForEach(graphs, async ({ name, graph, llmType }) => { + await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel }) => { // Wrapper function for invoking the graph (to parse different input/output formats) const predict = async (input: { input: string }) => { logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`); @@ -342,6 +344,7 @@ export const postEvaluateRoute = ( llmType, bedrockChatEnabled: true, isStreaming: false, + isOssModel, }, // TODO: Update to use the correct input format per dataset type { runName, From 5d2e1f230c8c2a3d31e058c410f1c394b511703e Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Wed, 25 Sep 2024 17:36:33 +0200 Subject: [PATCH 10/17] Add OSS model specific prompt to the tool description --- .../graphs/default_assistant_graph/index.ts | 4 +--- .../default_assistant_graph/nodes/translations.ts | 11 ----------- .../graphs/default_assistant_graph/prompts.ts | 4 ---- .../server/routes/evaluate/post_evaluate.ts | 5 +---- .../tools/esql_language_knowledge_base/common.ts | 15 +++++++++++++++ .../esql_language_knowledge_base_tool.ts | 12 ++++++++---- .../nl_to_esql_tool.ts | 4 +++- 7 files changed, 28 insertions(+), 27 deletions(-) create mode 100644 x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index f62d73f77b724..e8896fa00b4b2 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -166,9 +166,7 @@ export const callAssistantGraph: AgentExecutor = async ({ : await createStructuredChatAgent({ llm: createLlmInstance(), tools, - prompt: isOssModel - ? formatPromptStructured(systemPrompts.ossLlm, systemPrompt) - : formatPromptStructured(systemPrompts.structuredChat, systemPrompt), + prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt), streamRunnable: isStream, }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts index 73479812ba064..7f0a9fb5d300f 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/translations.ts @@ -71,14 +71,3 @@ Action: "action_input": "Final response to human"}} Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`; - -export const OSS_SYSTEM_PROMPT = ` -Use tools as often as possible, as they have access to the latest data and syntax. - -When using ESQLKnowledgeBaseTool pass the user's questions directly as input into the tool. - -Always return value from ESQLKnowledgeBaseTool as is. - -The ES|QL query should always be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks. - -It is important that ES|QL query is preceeded by a new line.`; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts index 19d27cb5e82de..e2f8f23c40399 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts @@ -10,7 +10,6 @@ import { BEDROCK_SYSTEM_PROMPT, DEFAULT_SYSTEM_PROMPT, GEMINI_SYSTEM_PROMPT, - OSS_SYSTEM_PROMPT, STRUCTURED_SYSTEM_PROMPT, } from './nodes/translations'; @@ -26,7 +25,6 @@ export const systemPrompts = { openai: DEFAULT_SYSTEM_PROMPT, bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`, gemini: `${DEFAULT_SYSTEM_PROMPT} ${GEMINI_SYSTEM_PROMPT}`, - ossLlm: `${DEFAULT_SYSTEM_PROMPT} ${STRUCTURED_SYSTEM_PROMPT} ${OSS_SYSTEM_PROMPT}`, structuredChat: STRUCTURED_SYSTEM_PROMPT, }; @@ -47,5 +45,3 @@ export const formatPromptStructured = (prompt: string, additionalPrompt?: string ]); export const structuredChatAgentPrompt = formatPromptStructured(systemPrompts.structuredChat); - -export const ossLlmStructuredChatAgentPrompt = formatPromptStructured(systemPrompts.ossLlm); diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index d97cd79401faa..f25a27367942c 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -51,7 +51,6 @@ import { bedrockToolCallingAgentPrompt, geminiToolCallingAgentPrompt, openAIFunctionAgentPrompt, - ossLlmStructuredChatAgentPrompt, structuredChatAgentPrompt, } from '../../lib/langchain/graphs/default_assistant_graph/prompts'; import { getLlmClass, getLlmType } from '../utils'; @@ -308,9 +307,7 @@ export const postEvaluateRoute = ( : await createStructuredChatAgent({ llm, tools, - prompt: isOssModel - ? ossLlmStructuredChatAgentPrompt - : structuredChatAgentPrompt, + prompt: structuredChatAgentPrompt, streamRunnable: false, }); diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts new file mode 100644 index 0000000000000..f43dcc777a0d8 --- /dev/null +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export const getExtraPromptForOssModel = (toolName: string) => ` + When using ${toolName} tool ALWAYS pass the user's questions directly as input into the tool. + + Always return value from ${toolName} tool as is. + + The ES|QL query should ALWAYS be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks. + + It is important that ES|QL query is preceeded by a new line.`; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts index 22fc85108ab65..3aeadcbebe265 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts @@ -9,14 +9,17 @@ import { DynamicStructuredTool } from '@langchain/core/tools'; import { z } from '@kbn/zod'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; import { APP_UI_ID } from '../../../../common'; +import { getExtraPromptForOssModel } from './common'; export type EsqlKnowledgeBaseToolParams = AssistantToolParams; +const TOOL_NAME = 'ESQLKnowledgeBaseTool'; + const toolDetails = { + id: 'esql-knowledge-base-tool', + name: TOOL_NAME, description: 'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the query on a single line, with no other text. Your answer will be parsed as JSON, so never use quotes within the output and instead use backticks. Do not add any additional text to describe your output.', - id: 'esql-knowledge-base-tool', - name: 'ESQLKnowledgeBaseTool', }; export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = { ...toolDetails, @@ -28,12 +31,13 @@ export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = { getTool(params: AssistantToolParams) { if (!this.isSupported(params)) return null; - const { chain } = params as EsqlKnowledgeBaseToolParams; + const { chain, isOssModel } = params as EsqlKnowledgeBaseToolParams; if (chain == null) return null; return new DynamicStructuredTool({ name: toolDetails.name, - description: toolDetails.description, + description: + toolDetails.description + (isOssModel ? getExtraPromptForOssModel(TOOL_NAME) : ''), schema: z.object({ question: z.string().describe(`The user's exact question about ESQL`), }), diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts index 11fc466be1abb..ecb6885a29889 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts @@ -11,6 +11,7 @@ import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant- import { lastValueFrom } from 'rxjs'; import { naturalLanguageToEsql } from '@kbn/inference-plugin/server'; import { APP_UI_ID } from '../../../../common'; +import { getExtraPromptForOssModel } from './common'; export type ESQLToolParams = AssistantToolParams; @@ -63,7 +64,8 @@ export const NL_TO_ESQL_TOOL: AssistantTool = { return new DynamicStructuredTool({ name: toolDetails.name, - description: toolDetails.description, + description: + toolDetails.description + (isOssModel ? getExtraPromptForOssModel(TOOL_NAME) : ''), schema: z.object({ question: z.string().describe(`The user's exact question about ESQL`), }), From f9eb9d7087cfd2f6cfda25b9052a451bebb61bc6 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Thu, 26 Sep 2024 15:28:51 +0200 Subject: [PATCH 11/17] Brush up implementation and add some unit tests --- .../server/lib/langchain/executors/types.ts | 4 +- .../graphs/default_assistant_graph/index.ts | 15 +- .../server/routes/chat/chat_complete_route.ts | 12 +- .../server/routes/evaluate/post_evaluate.ts | 21 +-- .../server/routes/helpers.ts | 10 +- .../routes/post_actions_connector_execute.ts | 12 +- .../server/routes/utils.test.ts | 69 ++++++++ .../elastic_assistant/server/routes/utils.ts | 28 +++ .../esql_language_knowledge_base/common.ts | 2 +- .../esql_language_knowledge_base_tool.test.ts | 23 +++ .../esql_language_knowledge_base_tool.ts | 4 +- .../nl_to_esql_tool.test.ts | 162 ++++++++++++++++++ .../nl_to_esql_tool.ts | 4 +- 13 files changed, 302 insertions(+), 64 deletions(-) create mode 100644 x-pack/plugins/elastic_assistant/server/routes/utils.test.ts create mode 100644 x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.test.ts diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 949f7927b1d68..a5b7c9120faaf 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -15,7 +15,6 @@ import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { PublicMethodsOf } from '@kbn/utility-types'; import type { InferenceServerStart } from '@kbn/inference-plugin/server'; -import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { ResponseBody } from '../types'; import type { AssistantTool } from '../../../types'; import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; @@ -42,14 +41,13 @@ export interface AgentExecutorParams { bedrockChatEnabled: boolean; assistantTools?: AssistantTool[]; connectorId: string; - connectorApiUrl?: string; - connectorApiProvider?: OpenAiProviderType; conversationId?: string; dataClients?: AssistantDataClients; esClient: ElasticsearchClient; esStore: ElasticsearchStore; langChainMessages: BaseMessage[]; llmType?: string; + isOssModel?: boolean; logger: Logger; inference: InferenceServerStart; onNewReplacements?: (newReplacements: Replacements) => void; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index e8896fa00b4b2..eca58904f6d04 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -14,10 +14,6 @@ import { createToolCallingAgent, } from 'langchain/agents'; import { APMTracer } from '@kbn/langchain/server/tracers/apm'; -import { - OPENAI_CHAT_URL, - OpenAiProviderType, -} from '@kbn/stack-connectors-plugin/common/openai/constants'; import { getLlmClass } from '../../../../routes/utils'; import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types'; import { AssistantToolParams } from '../../../../types'; @@ -36,8 +32,6 @@ export const callAssistantGraph: AgentExecutor = async ({ actionsClient, alertsIndexPattern, assistantTools = [], - connectorApiUrl, - connectorApiProvider, bedrockChatEnabled, connectorId, conversationId, @@ -47,6 +41,7 @@ export const callAssistantGraph: AgentExecutor = async ({ inference, langChainMessages, llmType, + isOssModel, logger: parentLogger, isStream = false, onLlmResponse, @@ -59,13 +54,7 @@ export const callAssistantGraph: AgentExecutor = async ({ responseLanguage = 'English', }) => { const logger = parentLogger.get('defaultAssistantGraph'); - const isOpeAIType = llmType === 'openai'; - const isOpenAI = - isOpeAIType && - (!connectorApiUrl || - connectorApiUrl === OPENAI_CHAT_URL || - connectorApiProvider === OpenAiProviderType.AzureAi); - const isOssModel = isOpeAIType && !isOpenAI; + const isOpenAI = llmType === 'openai' && !isOssModel; const llmClass = getLlmClass(llmType, bedrockChatEnabled); /** diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index d9f04dbd9e6d7..47f6f1a486957 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -19,7 +19,6 @@ import { } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { INVOKE_ASSISTANT_ERROR_EVENT } from '../../lib/telemetry/event_based_telemetry'; import { ElasticAssistantPluginRouter, GetElser } from '../../types'; import { buildResponse } from '../../lib/build_response'; @@ -31,6 +30,7 @@ import { } from '../helpers'; import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers'; import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types'; +import { isOpenSourceModel } from '../utils'; export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => { return `CONTEXT:\n"""\n${context}\n"""`; @@ -102,12 +102,7 @@ export const chatCompleteRoute = ( const connectors = await actionsClient.getBulk({ ids: [connectorId] }); const connector = connectors.length > 0 ? connectors[0] : undefined; actionTypeId = connector?.actionTypeId ?? '.gen-ai'; - const connectorApiUrl = connector?.config?.apiUrl - ? (connector.config.apiUrl as string) - : undefined; - const connectorApiProvider = connector?.config?.apiProvider - ? (connector?.config?.apiProvider as OpenAiProviderType) - : undefined; + const isOssModel = isOpenSourceModel(connector); // replacements const anonymizationFieldsRes = @@ -200,8 +195,7 @@ export const chatCompleteRoute = ( actionsClient, actionTypeId, connectorId, - connectorApiUrl, - connectorApiProvider, + isOssModel, conversationId: conversationId ?? newConversation?.id, context: ctx, getElser, diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index f25a27367942c..a4cd8c319495b 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -30,10 +30,6 @@ import { createToolCallingAgent, } from 'langchain/agents'; import { RetrievalQAChain } from 'langchain/chains'; -import { - OPENAI_CHAT_URL, - OpenAiProviderType, -} from '@kbn/stack-connectors-plugin/common/openai/constants'; import { buildResponse } from '../../lib/build_response'; import { AssistantDataClients } from '../../lib/langchain/executors/types'; import { AssistantToolParams, ElasticAssistantRequestHandlerContext, GetElser } from '../../types'; @@ -53,7 +49,7 @@ import { openAIFunctionAgentPrompt, structuredChatAgentPrompt, } from '../../lib/langchain/graphs/default_assistant_graph/prompts'; -import { getLlmClass, getLlmType } from '../utils'; +import { getLlmClass, getLlmType, isOpenSourceModel } from '../utils'; const DEFAULT_SIZE = 20; const ROUTE_HANDLER_TIMEOUT = 10 * 60 * 1000; // 10 * 60 seconds = 10 minutes @@ -200,19 +196,8 @@ export const postEvaluateRoute = ( }> = await Promise.all( connectors.map(async (connector) => { const llmType = getLlmType(connector.actionTypeId); - const connectorApiUrl = connector?.config?.apiUrl - ? (connector.config.apiUrl as string) - : undefined; - const connectorApiProvider = connector?.config?.apiProvider - ? (connector?.config?.apiProvider as OpenAiProviderType) - : undefined; - const isOpeAIType = llmType === 'openai'; - const isOpenAI = - isOpeAIType && - (!connectorApiUrl || - connectorApiUrl === OPENAI_CHAT_URL || - connectorApiProvider === OpenAiProviderType.AzureAi); - const isOssModel = isOpeAIType && !isOpenAI; + const isOssModel = isOpenSourceModel(connector); + const isOpenAI = llmType === 'openai' && !isOssModel; const llmClass = getLlmClass(llmType, true); const createLlmInstance = () => new llmClass({ diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 334b05e0bd115..987bc7e53d5d9 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -29,7 +29,6 @@ import { ActionsClient } from '@kbn/actions-plugin/server'; import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; import type { InferenceServerStart } from '@kbn/inference-plugin/server'; -import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { AIAssistantKnowledgeBaseDataClient } from '../ai_assistant_data_clients/knowledge_base'; import { FindResponse } from '../ai_assistant_data_clients/find'; import { EsPromptsSchema } from '../ai_assistant_data_clients/prompts/types'; @@ -324,8 +323,7 @@ export interface LangChainExecuteParams { actionTypeId: string; connectorId: string; inference: InferenceServerStart; - connectorApiUrl?: string; - connectorApiProvider?: OpenAiProviderType; + isOssModel?: boolean; conversationId?: string; context: AwaitedProperties< Pick @@ -352,8 +350,7 @@ export const langChainExecute = async ({ telemetry, actionTypeId, connectorId, - connectorApiUrl, - connectorApiProvider, + isOssModel, context, actionsClient, inference, @@ -425,13 +422,12 @@ export const langChainExecute = async ({ assistantTools, conversationId, connectorId, - connectorApiUrl, - connectorApiProvider, esClient, esStore, inference, isStream, llmType: getLlmType(actionTypeId), + isOssModel, langChainMessages, logger, onNewReplacements, diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index fc1beef2e45ad..4b65b5bb3f1e5 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -17,7 +17,6 @@ import { Replacements, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; -import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; import { INVOKE_ASSISTANT_ERROR_EVENT } from '../lib/telemetry/event_based_telemetry'; import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants'; import { buildResponse } from '../lib/build_response'; @@ -30,6 +29,7 @@ import { getSystemPromptFromUserConversation, langChainExecute, } from './helpers'; +import { isOpenSourceModel } from './utils'; export const postActionsConnectorExecuteRoute = ( router: IRouter, @@ -97,12 +97,7 @@ export const postActionsConnectorExecuteRoute = ( const actionsClient = await actions.getActionsClientWithRequest(request); const connectors = await actionsClient.getBulk({ ids: [connectorId] }); const connector = connectors.length > 0 ? connectors[0] : undefined; - const connectorApiUrl = connector?.config?.apiUrl - ? (connector.config.apiUrl as string) - : undefined; - const connectorApiProvider = connector?.config?.apiProvider - ? (connector?.config?.apiProvider as OpenAiProviderType) - : undefined; + const isOssModel = isOpenSourceModel(connector); const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient(); @@ -138,8 +133,7 @@ export const postActionsConnectorExecuteRoute = ( actionsClient, actionTypeId, connectorId, - connectorApiUrl, - connectorApiProvider, + isOssModel, conversationId, context: ctx, getElser, diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts new file mode 100644 index 0000000000000..3ca1b8edb5036 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Connector } from '@kbn/actions-plugin/server/application/connector/types'; +import { isOpenSourceModel } from './utils'; +import { + OPENAI_CHAT_URL, + OpenAiProviderType, +} from '@kbn/stack-connectors-plugin/common/openai/constants'; + +describe('Utils', () => { + describe('isOpenSourceModel', () => { + it('should return `false` when connector is undefined', async () => { + const isOpenModel = isOpenSourceModel(); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is a Bedrock', async () => { + const connector = { actionTypeId: '.bedrock' } as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is a Gemini', async () => { + const connector = { actionTypeId: '.gemini' } as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is a OpenAI and API url is not specified', async () => { + const connector = { + actionTypeId: '.gen-ai', + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is a OpenAI and OpenAI API url is specified', async () => { + const connector = { + actionTypeId: '.gen-ai', + config: { apiUrl: OPENAI_CHAT_URL }, + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is a AzureOpenAI', async () => { + const connector = { + actionTypeId: '.gen-ai', + config: { apiProvider: OpenAiProviderType.AzureAi }, + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `true` when connector is a OpenAI and non-OpenAI API url is specified', async () => { + const connector = { + actionTypeId: '.gen-ai', + config: { apiUrl: 'https://elastic.llm.com/llama/chat/completions' }, + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(true); + }); + }); +}); diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index e163526d996ae..408c72cb73b52 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -19,6 +19,11 @@ import { ActionsClientSimpleChatModel, ActionsClientGeminiChatModel, } from '@kbn/langchain/server'; +import { Connector } from '@kbn/actions-plugin/server/application/connector/types'; +import { + OPENAI_CHAT_URL, + OpenAiProviderType, +} from '@kbn/stack-connectors-plugin/common/openai/constants'; import { CustomHttpRequestError } from './custom_http_request_error'; export interface OutputError { @@ -189,3 +194,26 @@ export const getLlmClass = (llmType?: string, bedrockChatEnabled?: boolean) => : llmType === 'gemini' && bedrockChatEnabled ? ActionsClientGeminiChatModel : ActionsClientSimpleChatModel; + +export const isOpenSourceModel = (connector?: Connector): boolean => { + if (connector == null) { + return false; + } + + const llmType = getLlmType(connector.actionTypeId); + const connectorApiUrl = connector.config?.apiUrl + ? (connector.config.apiUrl as string) + : undefined; + const connectorApiProvider = connector.config?.apiProvider + ? (connector.config?.apiProvider as OpenAiProviderType) + : undefined; + + const isOpeAiType = llmType === 'openai'; + const isOpenAI = + isOpeAiType && + (!connectorApiUrl || + connectorApiUrl === OPENAI_CHAT_URL || + connectorApiProvider === OpenAiProviderType.AzureAi); + + return isOpeAiType && !isOpenAI; +}; diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts index f43dcc777a0d8..ee2bee8fab806 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/common.ts @@ -5,7 +5,7 @@ * 2.0. */ -export const getExtraPromptForOssModel = (toolName: string) => ` +export const getPromptSuffixForOssModel = (toolName: string) => ` When using ${toolName} tool ALWAYS pass the user's questions directly as input into the tool. Always return value from ${toolName} tool as is. diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.test.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.test.ts index 29b10e9fb0275..6af8abf65a98a 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.test.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.test.ts @@ -12,6 +12,7 @@ import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; import type { KibanaRequest } from '@kbn/core-http-server'; import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen'; import { loggerMock } from '@kbn/logging-mocks'; +import { getPromptSuffixForOssModel } from './common'; describe('EsqlLanguageKnowledgeBaseTool', () => { const chain = {} as RetrievalQAChain; @@ -108,5 +109,27 @@ describe('EsqlLanguageKnowledgeBaseTool', () => { expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']); }); + + it('should return tool with the expected description for OSS model', () => { + const tool = ESQL_KNOWLEDGE_BASE_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + isOssModel: true, + ...rest, + }) as DynamicTool; + + expect(tool.description).toContain(getPromptSuffixForOssModel('ESQLKnowledgeBaseTool')); + }); + + it('should return tool with the expected description for non-OSS model', () => { + const tool = ESQL_KNOWLEDGE_BASE_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + isOssModel: false, + ...rest, + }) as DynamicTool; + + expect(tool.description).not.toContain(getPromptSuffixForOssModel('ESQLKnowledgeBaseTool')); + }); }); }); diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts index 3aeadcbebe265..4d65e0c9ccc99 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/esql_language_knowledge_base_tool.ts @@ -9,7 +9,7 @@ import { DynamicStructuredTool } from '@langchain/core/tools'; import { z } from '@kbn/zod'; import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server'; import { APP_UI_ID } from '../../../../common'; -import { getExtraPromptForOssModel } from './common'; +import { getPromptSuffixForOssModel } from './common'; export type EsqlKnowledgeBaseToolParams = AssistantToolParams; @@ -37,7 +37,7 @@ export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = { return new DynamicStructuredTool({ name: toolDetails.name, description: - toolDetails.description + (isOssModel ? getExtraPromptForOssModel(TOOL_NAME) : ''), + toolDetails.description + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''), schema: z.object({ question: z.string().describe(`The user's exact question about ESQL`), }), diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.test.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.test.ts new file mode 100644 index 0000000000000..f078bccb24a36 --- /dev/null +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.test.ts @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { RetrievalQAChain } from 'langchain/chains'; +import type { DynamicTool } from '@langchain/core/tools'; +import { NL_TO_ESQL_TOOL } from './nl_to_esql_tool'; +import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; +import type { KibanaRequest } from '@kbn/core-http-server'; +import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen'; +import { loggerMock } from '@kbn/logging-mocks'; +import { getPromptSuffixForOssModel } from './common'; +import type { InferenceServerStart } from '@kbn/inference-plugin/server'; + +describe('NaturalLanguageESQLTool', () => { + const chain = {} as RetrievalQAChain; + const esClient = { + search: jest.fn().mockResolvedValue({}), + } as unknown as ElasticsearchClient; + const request = { + body: { + isEnabledKnowledgeBase: false, + alertsIndexPattern: '.alerts-security.alerts-default', + allow: ['@timestamp', 'cloud.availability_zone', 'user.name'], + allowReplacement: ['user.name'], + replacements: { key: 'value' }, + size: 20, + }, + } as unknown as KibanaRequest; + const logger = loggerMock.create(); + const inference = {} as InferenceServerStart; + const connectorId = 'fake-connector'; + const rest = { + chain, + esClient, + logger, + request, + inference, + connectorId, + }; + + describe('isSupported', () => { + it('returns false if isEnabledKnowledgeBase is false', () => { + const params = { + isEnabledKnowledgeBase: false, + modelExists: true, + ...rest, + }; + + expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(false); + }); + + it('returns false if modelExists is false (the ELSER model is not installed)', () => { + const params = { + isEnabledKnowledgeBase: true, + modelExists: false, // <-- ELSER model is not installed + ...rest, + }; + + expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(false); + }); + + it('returns true if isEnabledKnowledgeBase and modelExists are true', () => { + const params = { + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + }; + + expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(true); + }); + }); + + describe('getTool', () => { + it('returns null if isEnabledKnowledgeBase is false', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: false, + modelExists: true, + ...rest, + }); + + expect(tool).toBeNull(); + }); + + it('returns null if modelExists is false (the ELSER model is not installed)', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: false, // <-- ELSER model is not installed + ...rest, + }); + + expect(tool).toBeNull(); + }); + + it('returns null if inference plugin is not provided', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + inference: undefined, + }); + + expect(tool).toBeNull(); + }); + + it('returns null if connectorId is not provided', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + connectorId: undefined, + }); + + expect(tool).toBeNull(); + }); + + it('should return a Tool instance if isEnabledKnowledgeBase and modelExists are true', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + }); + + expect(tool?.name).toEqual('NaturalLanguageESQLTool'); + }); + + it('should return a tool with the expected tags', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + ...rest, + }) as DynamicTool; + + expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']); + }); + + it('should return tool with the expected description for OSS model', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + isOssModel: true, + ...rest, + }) as DynamicTool; + + expect(tool.description).toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool')); + }); + + it('should return tool with the expected description for non-OSS model', () => { + const tool = NL_TO_ESQL_TOOL.getTool({ + isEnabledKnowledgeBase: true, + modelExists: true, + isOssModel: false, + ...rest, + }) as DynamicTool; + + expect(tool.description).not.toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool')); + }); + }); +}); diff --git a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts index ecb6885a29889..a617877615702 100644 --- a/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts +++ b/x-pack/plugins/security_solution/server/assistant/tools/esql_language_knowledge_base/nl_to_esql_tool.ts @@ -11,7 +11,7 @@ import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant- import { lastValueFrom } from 'rxjs'; import { naturalLanguageToEsql } from '@kbn/inference-plugin/server'; import { APP_UI_ID } from '../../../../common'; -import { getExtraPromptForOssModel } from './common'; +import { getPromptSuffixForOssModel } from './common'; export type ESQLToolParams = AssistantToolParams; @@ -65,7 +65,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = { return new DynamicStructuredTool({ name: toolDetails.name, description: - toolDetails.description + (isOssModel ? getExtraPromptForOssModel(TOOL_NAME) : ''), + toolDetails.description + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''), schema: z.object({ question: z.string().describe(`The user's exact question about ESQL`), }), From 351c2f7965a98f1384b6bf3d250082c6746d2792 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Fri, 27 Sep 2024 11:38:04 +0200 Subject: [PATCH 12/17] Remove redundant code --- .../lib/langchain/graphs/default_assistant_graph/helpers.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 0545e7083abcb..840b2a9ac8ce0 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -42,7 +42,6 @@ interface StreamGraphParams { * @param request * @param traceOptions */ -/* eslint complexity: ["error", 210]*/ export const streamGraph = async ({ apmTracer, assistantGraph, From 69543ce07ff906f26d96db84ab2c0215e823ba89 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Wed, 2 Oct 2024 17:39:58 +0200 Subject: [PATCH 13/17] Make sure we log evaluation results and errors --- .../server/routes/evaluate/post_evaluate.ts | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index a4cd8c319495b..069c7715a0f23 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -337,15 +337,20 @@ export const postEvaluateRoute = ( return output; }; - const evalOutput = evaluate(predict, { + evaluate(predict, { data: datasetName ?? '', evaluators: [], // Evals to be managed in LangSmith for now experimentPrefix: name, client: new Client({ apiKey: langSmithApiKey }), // prevent rate limiting and unexpected multiple experiment runs maxConcurrency: 5, - }); - logger.debug(`runResp:\n ${JSON.stringify(evalOutput, null, 2)}`); + }) + .then((output) => { + logger.debug(`runResp:\n ${JSON.stringify(output, null, 2)}`); + }) + .catch((err) => { + logger.error(`evaluation error:\n ${JSON.stringify(err, null, 2)}`); + }); }); return response.ok({ From 513d4bf5e4c35bc122af53f8b69e6adbe0c7dd77 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Mon, 7 Oct 2024 18:04:23 +0200 Subject: [PATCH 14/17] Review feedback: long time request issue --- .../impl/assistant/use_send_message/index.tsx | 15 +++++++++++++++ .../assistant/use_send_message/translations.ts | 15 +++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/translations.ts diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx index 93bd03607e71f..efe14103655e5 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx @@ -10,6 +10,16 @@ import { useCallback, useRef, useState } from 'react'; import { ApiConfig, Replacements } from '@kbn/elastic-assistant-common'; import { useAssistantContext } from '../../assistant_context'; import { fetchConnectorExecuteAction, FetchConnectorExecuteResponse } from '../api'; +import * as i18n from './translations'; + +/** + * TODO: This is a workaround to solve the issue with the long standing server tasks while cahtting with the assistant. + * Some models (like Llama 3.1 70B) can perform poorly and be slow which leads to a long time to handle the request. + * The `core-http-browser` has a timeout of two minutes after which it will re-try the request. In combination with the slow model it can lead to + * a situation where core http client will initiate same request again and again. + * To avoid this, we abort http request after timeout which is slightly below two minutes. + */ +const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // in milliseconds interface SendMessageProps { apiConfig: ApiConfig; @@ -38,6 +48,10 @@ export const useSendMessage = (): UseSendMessage => { async ({ apiConfig, http, message, conversationId, replacements }: SendMessageProps) => { setIsLoading(true); + const timeoutId = setTimeout(() => { + abortController.current.abort(i18n.FETCH_MESSAGE_TIMEOUT_ERROR); + }, EXECUTE_ACTION_TIMEOUT); + try { return await fetchConnectorExecuteAction({ conversationId, @@ -52,6 +66,7 @@ export const useSendMessage = (): UseSendMessage => { traceOptions, }); } finally { + clearTimeout(timeoutId); setIsLoading(false); } }, diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/translations.ts b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/translations.ts new file mode 100644 index 0000000000000..1185d8cfdbc65 --- /dev/null +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/translations.ts @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { i18n } from '@kbn/i18n'; + +export const FETCH_MESSAGE_TIMEOUT_ERROR = i18n.translate( + 'xpack.elasticAssistant.assistant.useSendMessage.fetchMessageTimeoutError', + { + defaultMessage: 'Assistant could not respond in time. Please try again later.', + } +); From 1d60365fa57efdb24c2d5f3fa8453e49e9ad6633 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Mon, 7 Oct 2024 18:11:00 +0200 Subject: [PATCH 15/17] Update x-pack/plugins/elastic_assistant/server/routes/utils.ts Co-authored-by: Steph Milovic --- x-pack/plugins/elastic_assistant/server/routes/utils.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index 2b8ff40ba7144..4192d10e9bc98 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -208,7 +208,7 @@ export const isOpenSourceModel = (connector?: Connector): boolean => { ? (connector.config?.apiProvider as OpenAiProviderType) : undefined; - const isOpeAiType = llmType === 'openai'; + const isOpenAiType = llmType === 'openai'; const isOpenAI = isOpeAiType && (!connectorApiUrl || From 0f96231a389978d094bab15d37321e865a930ba8 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Mon, 7 Oct 2024 18:11:22 +0200 Subject: [PATCH 16/17] Review feedback: naming --- x-pack/plugins/elastic_assistant/server/routes/utils.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index 4192d10e9bc98..5811109b94ede 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -210,10 +210,10 @@ export const isOpenSourceModel = (connector?: Connector): boolean => { const isOpenAiType = llmType === 'openai'; const isOpenAI = - isOpeAiType && + isOpenAiType && (!connectorApiUrl || connectorApiUrl === OPENAI_CHAT_URL || connectorApiProvider === OpenAiProviderType.AzureAi); - return isOpeAiType && !isOpenAI; + return isOpenAiType && !isOpenAI; }; From 59fb225d0590362a6a78893ceab437d496337bb5 Mon Sep 17 00:00:00 2001 From: Ievgen Sorokopud Date: Mon, 7 Oct 2024 21:48:16 +0200 Subject: [PATCH 17/17] Review feedback: re-instantiate AbortController after the `abort` --- .../impl/assistant/use_send_message/index.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx index efe14103655e5..438b2282371d9 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/use_send_message/index.tsx @@ -50,6 +50,7 @@ export const useSendMessage = (): UseSendMessage => { const timeoutId = setTimeout(() => { abortController.current.abort(i18n.FETCH_MESSAGE_TIMEOUT_ERROR); + abortController.current = new AbortController(); }, EXECUTE_ACTION_TIMEOUT); try {