diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 02d9e3d96d..df6bd3ab20 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -40,6 +40,7 @@ public class AgentUtils { public static final String PROMPT_SUFFIX = "prompt.suffix"; public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction"; public static final String TOOL_RESPONSE = "prompt.tool_response"; + public static final String PROMPT_CHAT_HISTORY_PREFIX = "prompt.chat_history_prefix"; public static final String DISABLE_TRACE = "disable_trace"; public static final String VERBOSE = "verbose"; @@ -182,8 +183,8 @@ public static String extractModelResponseJson(String text, List llmRespo } } - public static String findMatchedPart(String text, List llmResponsePatterns) { - for (String p : llmResponsePatterns) { + public static String findMatchedPart(String text, List patternList) { + for (String p : patternList) { Pattern pattern = Pattern.compile(p); Matcher matcher = pattern.matcher(text); if (matcher.find()) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 8b95a8f78c..621b9884ae 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.common.utils.StringUtils.isJson; import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION; @@ -23,6 +24,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolNames; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; import java.security.PrivilegedActionException; import java.util.ArrayList; @@ -145,7 +147,8 @@ public void run(MLAgent mlAgent, Map params, ActionListener 0) { - chatHistoryBuilder.append("Human:CONVERSATION HISTORY WITH AI ASSISTANT\n----------------------------\nBelow is Chat History between Human and AI which sorted by time with asc order:\n"); + String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX); + chatHistoryBuilder.append(chatHistoryPrefix); for (Message message : messageList) { chatHistoryBuilder.append(message.toString()).append("\n"); } @@ -203,7 +206,6 @@ private void runReAct( AtomicInteger traceNumber = new AtomicInteger(0); AtomicReference> lastLlmListener = new AtomicReference<>(); -// AtomicBoolean getFinalAnswer = new AtomicBoolean(false); AtomicReference lastThought = new AtomicReference<>(); AtomicReference lastAction = new AtomicReference<>(); AtomicReference lastActionInput = new AtomicReference<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index d32d2b3edd..23ff6f9781 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -11,4 +11,5 @@ public class PromptTemplate { public static final String PROMPT_TEMPLATE = "\n\nHuman:${parameters.prompt.prefix}\n\n${parameters.prompt.suffix}\n\nHuman: follow RESPONSE FORMAT INSTRUCTIONS\n\nAssistant:"; public static final String PROMPT_TEMPLATE_TOOL_RESPONSE = "Assistant:\n---------------------\n${parameters.llm_tool_selection_response}\n\nHuman: TOOL RESPONSE of ${parameters.tool_name}: \n---------------------\nTool input:\n${parameters.tool_input}\n\nTool output:\n${parameters.observation}\n\n"; + public static final String CHAT_HISTORY_PREFIX = "Human:CONVERSATION HISTORY WITH AI ASSISTANT\n----------------------------\nBelow is Chat History between Human and AI which sorted by time with asc order:\n"; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java index 2a084ee9b9..8436d98059 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMessage.java @@ -34,7 +34,7 @@ public ConversationIndexMessage(String type, String sessionId, String question, @Override public String toString() { - return "Human:" + question + "\nAI:" + response; + return "Human:" + question + "\nAssistant:" + response; } @Override