From d0c931d954beafb2b78ff72c64a0ff658875573a Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 12 Feb 2024 15:15:17 -0800 Subject: [PATCH] clean code Signed-off-by: Yaliang Wu --- .../ml/common/connector/HttpConnector.java | 3 + .../engine/algorithms/agent/AgentUtils.java | 90 +++- .../algorithms/agent/MLAgentExecutor.java | 2 +- .../algorithms/agent/MLChatAgentRunner.java | 502 +++++------------- .../MLConversationalFlowAgentRunner.java | 62 +-- .../algorithms/agent/MLFlowAgentRunner.java | 3 +- .../algorithms/agent/PromptTemplate.java | 6 +- .../remote/AwsConnectorExecutor.java | 3 + .../ml/engine/memory/MLMemoryManager.java | 2 +- .../ml/engine/tools/CatIndexTool.java | 21 +- .../algorithms/agent/AgentUtilsTest.java | 41 +- 11 files changed, 295 insertions(+), 440 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ca37bc40dc..e19a960b16 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -290,6 +290,9 @@ public T createPredictPayload(Map parameters) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); + log.info("++++++++++++++++++++++++++++++++++++++++++++++++++"); + System.out.println(payload); + log.info("--------------------------------------------------"); if (!isJson(payload)) { throw new IllegalArgumentException("Invalid payload: " + payload); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 5268f4a559..3f0430d493 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -11,8 +11,6 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES; -import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX; -import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS; @@ -20,6 +18,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,6 +27,7 @@ import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -35,6 +35,12 @@ public class AgentUtils { + public static final String SELECTED_TOOLS = "selected_tools"; + public static final String PROMPT_PREFIX = "prompt.prefix"; + public static final String PROMPT_SUFFIX = "prompt.suffix"; + public static final String RESPONSE_FORMAT_INSTRUCTION = "prompt.format_instruction"; + public static final String TOOL_RESPONSE = "prompt.tool_response"; + public static String addExamplesToPrompt(Map parameters, String prompt) { Map examplesMap = new HashMap<>(); if (parameters.containsKey(EXAMPLES)) { @@ -150,13 +156,25 @@ public static String addContextToPrompt(Map parameters, String p return prompt; } + public static List MODEL_RESPONSE_PATTERNS = List.of( + "\\{\\s*\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?\\}", + "\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"action\"\\s*:\\s*\".*?\"\\s*,\\s*\"action_input\"\\s*:\\s*\".*?\"\\s*}", + "\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"final_answer\"\\s*:\\s*\".*?\"\\s*}" + ); public static String extractModelResponseJson(String text) { - Pattern pattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```"); - Matcher matcher = pattern.matcher(text); + Pattern pattern1 = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```"); + Matcher matcher1 = pattern1.matcher(text); - if (matcher.find()) { - return matcher.group(1); + if (matcher1.find()) { + return matcher1.group(1); } else { + for (String p : MODEL_RESPONSE_PATTERNS) { + Pattern pattern = Pattern.compile(p); + Matcher matcher = pattern.matcher(text); + if (matcher.find()) { + return matcher.group(); + } + } throw new IllegalArgumentException("Model output is invalid"); } } @@ -197,4 +215,64 @@ public static int getMessageHistoryLimit(Map params) { public static String getToolName(MLToolSpec toolSpec) { return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType(); } + + public static List getMlToolSpecs(MLAgent mlAgent, Map params) { + String selectedToolsStr = params.get(SELECTED_TOOLS); + List toolSpecs = mlAgent.getTools(); + if (selectedToolsStr != null) { + List selectedTools = gson.fromJson(selectedToolsStr, List.class); + Map toolNameSpecMap = new HashMap<>(); + for (MLToolSpec toolSpec : toolSpecs) { + toolNameSpecMap.put(getToolName(toolSpec), toolSpec); + } + List selectedToolSpecs = new ArrayList<>(); + for (String tool : selectedTools) { + if (toolNameSpecMap.containsKey(tool)) { + selectedToolSpecs.add(toolNameSpecMap.get(tool)); + } + } + toolSpecs = selectedToolSpecs; + } + return toolSpecs; + } + + public static void createTools(Map toolFactories, + Map params, + List toolSpecs, + Map tools, + Map toolSpecMap) { + for (MLToolSpec toolSpec : toolSpecs) { + Tool tool = createTool(toolFactories, params, toolSpec); + tools.put(tool.getName(), tool); + toolSpecMap.put(tool.getName(), toolSpec); + } + } + + public static Tool createTool(Map toolFactories, Map params, MLToolSpec toolSpec) { + if (!toolFactories.containsKey(toolSpec.getType())) { + throw new IllegalArgumentException("Tool not found: " + toolSpec.getType()); + } + Map executeParams = new HashMap<>(); + if (toolSpec.getParameters() != null) { + executeParams.putAll(toolSpec.getParameters()); + } + for (String key : params.keySet()) { + String toolNamePrefix = getToolName(toolSpec) + "."; + if (key.startsWith(toolNamePrefix)) { + executeParams.put(key.replace(toolNamePrefix, ""), params.get(key)); + } + } + Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams); + String toolName = getToolName(toolSpec); + tool.setName(toolName); + + if (toolSpec.getDescription() != null) { + tool.setDescription(toolSpec.getDescription()); + } + if (params.containsKey(toolName + ".description")) { + tool.setDescription(params.get(toolName + ".description")); + } + + return tool; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index cb779abf31..b9e90efce6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -63,7 +63,7 @@ public class MLAgentExecutor implements Executable { public static final String MEMORY_ID = "memory_id"; public static final String QUESTION = "question"; - public static final String PARENT_INTERACTION_ID = "parent_interaction_id"; + public static final String PARENT_INTERACTION_ID = "interaction_id"; public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id"; public static final String MESSAGE_HISTORY_LIMIT = "message_history_limit"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index f944aa90e9..c8b10f4f0f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -8,11 +8,20 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -25,14 +34,10 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; import org.opensearch.action.StepListener; -import org.opensearch.action.support.GroupedActionListener; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -59,7 +64,6 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; -import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.ml.repackage.com.google.common.collect.Lists; @@ -73,10 +77,8 @@ public class MLChatAgentRunner implements MLAgentRunner { public static final String SESSION_ID = "session_id"; - public static final String PROMPT_PREFIX = "prompt_prefix"; public static final String LLM_TOOL_PROMPT_PREFIX = "LanguageModelTool.prompt_prefix"; public static final String LLM_TOOL_PROMPT_SUFFIX = "LanguageModelTool.prompt_suffix"; - public static final String PROMPT_SUFFIX = "prompt_suffix"; public static final String TOOLS = "tools"; public static final String TOOL_DESCRIPTIONS = "tool_descriptions"; public static final String TOOL_NAMES = "tool_names"; @@ -121,6 +123,7 @@ public void run(MLAgent mlAgent, Map params, ActionListenerwrap(memory -> { + //TODO: call runAgent directly if messageHistoryLimit == 0 memory.getMessages(ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); for (Interaction next : r) { @@ -131,20 +134,12 @@ public void run(MLAgent mlAgent, Map params, ActionListener 0) { - chatHistoryBuilder.append("Below is Chat History between Human and AI which sorted by time with asc order:\n"); + chatHistoryBuilder.append("Human:CONVERSATION HISTORY WITH AI ASSISTANT\n----------------------------\nBelow is Chat History between Human and AI which sorted by time with asc order:\n"); for (Message message : messageList) { chatHistoryBuilder.append(message.toString()).append("\n"); } @@ -160,34 +155,10 @@ public void run(MLAgent mlAgent, Map params, ActionListener params, ActionListener listener, Memory memory, String sessionId) { - List toolSpecs = mlAgent.getTools(); + List toolSpecs = getMlToolSpecs(mlAgent, params); Map tools = new HashMap<>(); Map toolSpecMap = new HashMap<>(); - for (MLToolSpec toolSpec : toolSpecs) { - Map toolParams = new HashMap<>(); - Map executeParams = new HashMap<>(); - if (toolSpec.getParameters() != null) { - toolParams.putAll(toolSpec.getParameters()); - executeParams.putAll(toolSpec.getParameters()); - } - for (String key : params.keySet()) { - if (key.startsWith(toolSpec.getType() + ".")) { - executeParams.put(key.replace(toolSpec.getType() + ".", ""), params.get(key)); - } - } - log.info("Fetching tool for type: " + toolSpec.getType()); - Tool tool = toolFactories.get(toolSpec.getType()).create(executeParams); - if (toolSpec.getName() != null) { - tool.setName(toolSpec.getName()); - } - - if (toolSpec.getDescription() != null) { - tool.setDescription(toolSpec.getDescription()); - } - String toolName = Optional.ofNullable(tool.getName()).orElse(toolSpec.getType()); - tools.put(toolName, tool); - toolSpecMap.put(toolName, toolSpec); - } + createTools(toolFactories, params, toolSpecs, tools, toolSpecMap); runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, listener); } @@ -204,6 +175,7 @@ private void runReAct( String question = parameters.get(MLAgentExecutor.QUESTION); String parentInteractionId = parameters.get(MLAgentExecutor.PARENT_INTERACTION_ID); boolean verbose = parameters.containsKey("verbose") && Boolean.parseBoolean(parameters.get("verbose")); + boolean traceDisabled = parameters.containsKey("disable_trace") && Boolean.parseBoolean(parameters.get("disable_trace")); Map tmpParameters = new HashMap<>(); if (llm.getParameters() != null) { tmpParameters.putAll(llm.getParameters()); @@ -213,46 +185,28 @@ private void runReAct( tmpParameters.put("stop", gson.toJson(new String[] { "\nObservation:", "\n\tObservation:" })); } if (!tmpParameters.containsKey("stop_sequences")) { - tmpParameters - .put( - "stop_sequences", - gson - .toJson( - new String[] { - "\n\nHuman:", - "\nObservation:", - "\n\tObservation:", - "\nObservation", - "\n\tObservation", - "\n\nQuestion" } - ) - ); + tmpParameters.put("stop_sequences", gson.toJson(new String[] {"\n\nHuman:", "\nObservation:", "\n\tObservation:", "\nObservation", "\n\tObservation", "\n\nQuestion" })); } - String prompt = parameters.get(PROMPT); - if (prompt == null) { - prompt = PromptTemplate.PROMPT_TEMPLATE; - } - String promptPrefix = parameters.getOrDefault("prompt.prefix", PromptTemplate.PROMPT_TEMPLATE_PREFIX); - tmpParameters.put("prompt.prefix", promptPrefix); + String prompt = parameters.getOrDefault(PROMPT, PromptTemplate.PROMPT_TEMPLATE); + String promptPrefix = parameters.getOrDefault(PROMPT_PREFIX, PromptTemplate.PROMPT_TEMPLATE_PREFIX); + tmpParameters.put(PROMPT_PREFIX, promptPrefix); - String promptSuffix = parameters.getOrDefault("prompt.suffix", PromptTemplate.PROMPT_TEMPLATE_SUFFIX); - tmpParameters.put("prompt.suffix", promptSuffix); + String promptSuffix = parameters.getOrDefault(PROMPT_SUFFIX, PromptTemplate.PROMPT_TEMPLATE_SUFFIX); + tmpParameters.put(PROMPT_SUFFIX, promptSuffix); - String promptFormatInstruction = parameters.getOrDefault("prompt.format_instruction", PromptTemplate.PROMPT_FORMAT_INSTRUCTION); - tmpParameters.put("prompt.format_instruction", promptFormatInstruction); - if (!tmpParameters.containsKey("prompt.tool_response")) { - tmpParameters.put("prompt.tool_response", PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE); - } - String promptToolResponse = parameters.getOrDefault("prompt.tool_response", PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE); - tmpParameters.put("prompt.tool_response", promptToolResponse); + String promptFormatInstruction = parameters.getOrDefault(RESPONSE_FORMAT_INSTRUCTION, PromptTemplate.PROMPT_FORMAT_INSTRUCTION); + tmpParameters.put(RESPONSE_FORMAT_INSTRUCTION, promptFormatInstruction); + + String promptToolResponse = parameters.getOrDefault(TOOL_RESPONSE, PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE); + tmpParameters.put(TOOL_RESPONSE, promptToolResponse); StringSubstitutor promptSubstitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}"); prompt = promptSubstitutor.replace(prompt); final List inputTools = new ArrayList<>(); for (Map.Entry entry : tools.entrySet()) { - String toolName = Optional.ofNullable(entry.getValue().getName()).orElse(entry.getValue().getType()); + String toolName = entry.getValue().getName(); inputTools.add(toolName); } @@ -268,26 +222,10 @@ private void runReAct( List modelTensors = new ArrayList<>(); List cotModelTensors = new ArrayList<>(); - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build() - ) - ) - .build() - ); + cotModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build())).build()); StringBuilder scratchpadBuilder = new StringBuilder(); - StringSubstitutor tmpSubstitutor = new StringSubstitutor( - ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()), - "${parameters.", - "}" - ); + StringSubstitutor tmpSubstitutor = new StringSubstitutor(ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); @@ -305,6 +243,7 @@ private void runReAct( AtomicReference lastThought = new AtomicReference<>(); AtomicReference lastAction = new AtomicReference<>(); AtomicReference lastActionInput = new AtomicReference<>(); + AtomicReference lastToolSelectionResponse = new AtomicReference<>(); Map additionalInfo = new ConcurrentHashMap<>(); StepListener lastStepListener = null; @@ -325,151 +264,73 @@ private void runReAct( MLTaskResponse llmResponse = (MLTaskResponse) output; ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); Map dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); + String thoughtResponse = null; + String finalAnswer = null; if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) { - String response = (String) dataAsMap.get("response"); - String thoughtResponse = extractModelResponseJson(response); - dataAsMap = gson.fromJson(thoughtResponse, Map.class); + String llmReasoningResponse = (String) dataAsMap.get("response"); + try { + thoughtResponse = extractModelResponseJson(llmReasoningResponse); + } catch (IllegalArgumentException e) { + thoughtResponse = llmReasoningResponse; + finalAnswer = llmReasoningResponse; + } + if (isJson(thoughtResponse)) { + dataAsMap = gson.fromJson(thoughtResponse, Map.class); + } + } else { + try { + Map finalDataAsMap = dataAsMap; + thoughtResponse = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(finalDataAsMap)); + } catch (Exception e) { + e.printStackTrace(); + } } String thought = String.valueOf(dataAsMap.get("thought")); String action = String.valueOf(dataAsMap.get("action")); String actionInput = parseInputFromLLMReturn(dataAsMap); - String finalAnswer = (String) dataAsMap.get("final_answer"); - if (!dataAsMap.containsKey("thought")) { - String response = (String) dataAsMap.get("response"); - Pattern pattern = Pattern.compile("```json(.*?)```", Pattern.DOTALL); - Matcher matcher = pattern.matcher(response); - if (matcher.find()) { - String jsonBlock = matcher.group(1); - Map map = gson.fromJson(jsonBlock, Map.class); - thought = String.valueOf(map.get("thought")); - action = String.valueOf(map.get("action")); - actionInput = parseInputFromLLMReturn(map); - finalAnswer = (String) map.get("final_answer"); + if (dataAsMap.containsKey("final_answer")) { + finalAnswer = String.valueOf(dataAsMap.get("final_answer")); + } + + if (finalAnswer != null) { + finalAnswer = finalAnswer.trim(); + if (conversationIndexMemory != null) { + String copyOfFinalAnswer = finalAnswer; + ActionListener saveTraceListener = ActionListener.wrap(r->{ + conversationIndexMemory.getMemoryManager().updateInteraction(parentInteractionId, + Map.of(AI_RESPONSE_FIELD, copyOfFinalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo), + ActionListener.wrap(res ->{ + returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, getFinalAnswer, additionalInfo, copyOfFinalAnswer); + }, e-> { + listener.onFailure(e); + })); + }, e-> { + listener.onFailure(e); + }); + saveMessage(conversationIndexMemory, question, finalAnswer, sessionId, parentInteractionId, traceNumber, true, traceDisabled, saveTraceListener); } else { - finalAnswer = response; + returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, getFinalAnswer, additionalInfo, finalAnswer); } + return; } + lastToolSelectionResponse.set(thoughtResponse); + if (finalI == 0 && !thought.contains("Thought:")) { sessionMsgAnswerBuilder.append("Thought: "); } sessionMsgAnswerBuilder.append(thought); lastThought.set(thought); - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build() - ) - ) - .build() - ); + + cotModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build())).build()); // TODO: check if verbose modelTensors.addAll(tmpModelTensorOutput.getMlModelOutputs()); if (conversationIndexMemory != null) { - ConversationIndexMessage msgTemp = ConversationIndexMessage - .conversationIndexMessageBuilder() - .type(memory.getType()) - .question(question) - .response(thought) - .finalAnswer(false) - .sessionId(sessionId) - .build(); - conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null); - } - if (finalAnswer != null) { - finalAnswer = finalAnswer.trim(); - String finalAnswer2 = finalAnswer; - // Composite execution response and reply. - final ActionListener executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> { - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer2).build()) - ) - .build() - ); - - List finalModelTensors = new ArrayList<>(); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor - .builder() - .name(MLAgentExecutor.PARENT_INTERACTION_ID) - .result(parentInteractionId) - .build() - ) - ) - .build() - ); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor - .builder() - .name("response") - .dataAsMap( - ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo) - ) - .build() - ) - ) - .build() - ); - getFinalAnswer.set(true); - if (verbose) { - listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build()); - } else { - listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); - } - }, listener::onFailure)); - // Sending execution response by internalListener is after the trace and answer saving. - final GroupedActionListener groupedListener = createGroupedListener(2, executionListener); - if (conversationIndexMemory != null) { - String finalAnswer1 = finalAnswer; - // Create final trace message. - ConversationIndexMessage msgTemp = ConversationIndexMessage - .conversationIndexMessageBuilder() - .type(memory.getType()) - .question(question) - .response(finalAnswer1) - .finalAnswer(true) - .sessionId(sessionId) - .build(); - // Save last trace and update final answer in parallel. - conversationIndexMemory - .save( - msgTemp, - parentInteractionId, - traceNumber.addAndGet(1), - null, - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - conversationIndexMemory - .getMemoryManager() - .updateInteraction( - parentInteractionId, - ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo), - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); + ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(thoughtResponse).finalAnswer(false).sessionId(sessionId).build(); + if (!traceDisabled) { + conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM"); } - return; } lastAction.set(action); @@ -499,17 +360,10 @@ private void runReAct( try { String finalAction = action; ActionListener toolListener = ActionListener - .wrap(r -> { ((ActionListener) nextStepListener).onResponse(r); }, e -> { - ((ActionListener) nextStepListener) - .onResponse( - String - .format( - Locale.ROOT, - "Failed to run the tool %s with the error message %s.", - finalAction, - e.getMessage() - ) - ); + .wrap(r -> { + ((ActionListener) nextStepListener).onResponse(r); + }, e -> { + ((ActionListener) nextStepListener).onResponse(String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", finalAction, e.getMessage())); }); if (tools.get(action) instanceof MLModelTool) { Map llmToolTmpParameters = new HashMap<>(); @@ -522,30 +376,16 @@ private void runReAct( tools.get(action).run(toolParams, toolListener); // run tool } } catch (Exception e) { - ((ActionListener) nextStepListener) - .onResponse( - String - .format( - Locale.ROOT, - "Failed to run the tool %s with the error message %s.", - action, - e.getMessage() - ) - ); + ((ActionListener) nextStepListener).onResponse(String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", action, e.getMessage())); } } else { - String res = String - .format(Locale.ROOT, "Failed to run the tool %s due to wrong input %s.", action, actionInput); + String res = String.format(Locale.ROOT, "Failed to run the tool %s due to wrong input %s.", action, actionInput); ((ActionListener) nextStepListener).onResponse(res); } } else { String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action); ((ActionListener) nextStepListener).onResponse(res); - StringSubstitutor substitutor = new StringSubstitutor( - ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()), - "${parameters.", - "}" - ); + StringSubstitutor substitutor = new StringSubstitutor(ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); } @@ -553,7 +393,6 @@ private void runReAct( MLToolSpec toolSpec = toolSpecMap.get(lastAction.get()); if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) { String outputString = outputToOutputString(output); - String toolOutputKey = String.format("%s.output", toolSpec.getType()); if (additionalInfo.get(toolOutputKey) != null) { List list = (List) additionalInfo.get(toolOutputKey); @@ -561,123 +400,39 @@ private void runReAct( } else { additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); } - } - modelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor - .builder() - .dataAsMap( - ImmutableMap - .of( - "response", - lastThought.get() + "\nObservation: " + outputToOutputString(output) - ) - ) - .build() - ) - ) - .build() - ); - - String toolResponse = tmpParameters.get("prompt.tool_response"); - StringSubstitutor toolResponseSubstitutor = new StringSubstitutor( - ImmutableMap.of("observation", outputToOutputString(output)), - "${parameters.", - "}" - ); + modelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(ModelTensor.builder().dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + outputToOutputString(output))).build())).build()); + + String toolResponse = tmpParameters.get(TOOL_RESPONSE); + StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(Map.of("llm_tool_selection_response", /*"```json\\n" +*/ lastToolSelectionResponse.get() /*+ "\n```"*/, "tool_name", lastAction.get(),"tool_input", lastActionInput.get(),"observation", outputToOutputString(output)), "${parameters.", "}"); toolResponse = toolResponseSubstitutor.replace(toolResponse); scratchpadBuilder.append(toolResponse).append("\n\n"); if (conversationIndexMemory != null) { // String res = "Action: " + lastAction.get() + "\nAction Input: " + lastActionInput + "\nObservation: " + result; - ConversationIndexMessage msgTemp = ConversationIndexMessage - .conversationIndexMessageBuilder() - .type("ReAct") - .question(lastActionInput.get()) - .response(outputToOutputString(output)) - .finalAnswer(false) - .sessionId(sessionId) - .build(); - conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), lastAction.get()); - + ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type("ReAct").question(lastActionInput.get()).response(outputToOutputString(output)).finalAnswer(false).sessionId(sessionId).build(); + if (!traceDisabled) { + conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), lastAction.get()); + } } - StringSubstitutor substitutor = new StringSubstitutor( - ImmutableMap.of(SCRATCHPAD, scratchpadBuilder.toString()), - "${parameters.", - "}" - ); + StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder /*+ PROMPT_TEMPLATE_ASK_AGAIN*/), "${parameters.", "}"); newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); - sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output)); - cotModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build() - ) - ) - .build() - ); - - ActionRequest request = new MLPredictionTaskRequest( - llm.getModelId(), - RemoteInferenceMLInput - .builder() - .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()) - .build() - ); - client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener) nextStepListener); + sessionMsgAnswerBuilder.append(outputToOutputString(output)); + cotModelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build())).build()); + + //client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener) nextStepListener); if (finalI == maxIterations - 1) { if (verbose) { listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build()); } else { List finalModelTensors = new ArrayList<>(); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - List - .of( - ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), - ModelTensor - .builder() - .name(MLAgentExecutor.PARENT_INTERACTION_ID) - .result(parentInteractionId) - .build() - ) - ) - .build() - ); - finalModelTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor - .builder() - .name("response") - .dataAsMap(ImmutableMap.of("response", lastThought.get())) - .build() - ) - ) - .build() - ); + finalModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build())).build()); + finalModelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("response", lastThought.get())).build())).build()); listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); } } else { + ActionRequest request = new MLPredictionTaskRequest(llm.getModelId(), RemoteInferenceMLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()).build()); client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener) nextStepListener); } } @@ -690,34 +445,41 @@ private void runReAct( } } - ActionRequest request = new MLPredictionTaskRequest( - llm.getModelId(), - RemoteInferenceMLInput - .builder() - .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()) - .build() - ); + ActionRequest request = new MLPredictionTaskRequest(llm.getModelId(), RemoteInferenceMLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()).build()); client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); } - private GroupedActionListener createGroupedListener(final int size, final ActionListener listener) { - return new GroupedActionListener<>(new ActionListener>() { - @Override - public void onResponse(final Collection responses) { - CreateInteractionResponse createInteractionResponse = extractResponse(responses, CreateInteractionResponse.class); - log.info("saved message with interaction id: {}", createInteractionResponse.getId()); - UpdateResponse updateResponse = extractResponse(responses, UpdateResponse.class); - log.info("Updated final answer into interaction id: {}", updateResponse.getId()); - - listener.onResponse(true); - } + private static void returnFinalResponse(String sessionId, ActionListener listener, String parentInteractionId, boolean verbose, List cotModelTensors, AtomicBoolean getFinalAnswer, Map additionalInfo, String finalAnswer2) { + cotModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name("response").result(finalAnswer2).build())).build()); + + List finalModelTensors = new ArrayList<>(); + finalModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build())).build()); + finalModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)).build())).build()); + getFinalAnswer.set(true); + if (verbose) { + listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build()); + } else { + listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build()); + } + } - @Override - public void onFailure(final Exception e) { - listener.onFailure(e); - } - }, size); + private void saveMessage( + ConversationIndexMemory memory, + String question , + String finalAnswer , + String sessionId , + String parentInteractionId, + AtomicInteger traceNumber, + boolean isFinalAnswer, + boolean traceDisabled, + ActionListener listener + ) { + ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(finalAnswer).finalAnswer(isFinalAnswer).sessionId(sessionId).build(); + if (traceDisabled) { + listener.onResponse(true); + } else { + memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); + } } @SuppressWarnings("unchecked") diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index c2471e4f35..a5b73616a6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -11,7 +11,9 @@ import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName; import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; @@ -62,7 +64,7 @@ public class MLConversationalFlowAgentRunner implements MLAgentRunner { public static final String CHAT_HISTORY = "chat_history"; - public static final String SELECTED_TOOLS = "selected_tools"; + private Client client; private Settings settings; private ClusterService clusterService; @@ -156,8 +158,7 @@ private void runAgent( Map firstToolExecuteParams = null; StepListener previousStepListener = null; Map additionalInfo = new ConcurrentHashMap<>(); - String selectedToolsStr = params.get(SELECTED_TOOLS); - List toolSpecs = getMlToolSpecs(mlAgent, selectedToolsStr); + List toolSpecs = getMlToolSpecs(mlAgent, params); if (toolSpecs == null || toolSpecs.size() == 0) { listener.onFailure(new IllegalArgumentException("no tool configured")); @@ -173,7 +174,7 @@ private void runAgent( for (int i = 0; i <= toolSpecs.size(); i++) { if (i == 0) { MLToolSpec toolSpec = toolSpecs.get(i); - Tool tool = createTool(toolSpec); + Tool tool = createTool(toolFactories, params, toolSpec); firstStepListener = new StepListener<>(); previousStepListener = firstStepListener; firstTool = tool; @@ -231,25 +232,6 @@ private void runAgent( } } - private static List getMlToolSpecs(MLAgent mlAgent, String selectedToolsStr) { - List toolSpecs = mlAgent.getTools(); - if (selectedToolsStr != null) { - List selectedTools = gson.fromJson(selectedToolsStr, List.class); - Map toolNameSpecMap = new HashMap<>(); - for (MLToolSpec toolSpec : toolSpecs) { - toolNameSpecMap.put(getToolName(toolSpec), toolSpec); - } - List selectedToolSpecs = new ArrayList<>(); - for (String tool : selectedTools) { - if (toolNameSpecMap.containsKey(tool)) { - selectedToolSpecs.add(toolNameSpecMap.get(tool)); - } - } - toolSpecs = selectedToolSpecs; - } - return toolSpecs; - } - private void processOutput( Map params, ActionListener listener, @@ -271,6 +253,7 @@ private void processOutput( String outputKey = toolName + ".output"; String outputResponse = parseResponse(output); params.put(outputKey, escapeJson(outputResponse)); + boolean traceDisabled = params.containsKey("disable_trace") && Boolean.parseBoolean(params.get("disable_trace")); if (previousToolSpec.isIncludeOutputInAgentResponse() || finalI == toolSpecs.size()) { if (output instanceof ModelTensorOutput) { @@ -303,7 +286,7 @@ private void processOutput( updateMemoryWithListener(additionalInfo, memorySpec, memoryId, parentInteractionId, updateListener); } } else { - saveMessage(params, memory, outputResponse, memoryId, parentInteractionId, toolName, traceNumber, ActionListener.wrap(r -> { + saveMessage(params, memory, outputResponse, memoryId, parentInteractionId, toolName, traceNumber,traceDisabled, ActionListener.wrap(r -> { log.info("saved last trace for interaction " + parentInteractionId + " of flow agent"); Map updateContent = Map.of(AI_RESPONSE_FIELD, outputResponse, ADDITIONAL_INFO_FIELD, additionalInfo); memory.update(parentInteractionId, updateContent, updateListener); @@ -316,7 +299,7 @@ private void processOutput( if (memory == null) { runNextStep(params, toolSpecs, finalI, nextStepListener); } else { - saveMessage(params, memory, outputResponse, memoryId, parentInteractionId, toolName, traceNumber, ActionListener.wrap(r -> { + saveMessage(params, memory, outputResponse, memoryId, parentInteractionId, toolName, traceNumber,traceDisabled, ActionListener.wrap(r -> { runNextStep(params, toolSpecs, finalI, nextStepListener); }, e -> { log.error("Failed to update root interaction ", e); @@ -328,7 +311,7 @@ private void processOutput( private void runNextStep(Map params, List toolSpecs, int finalI, StepListener nextStepListener) { MLToolSpec toolSpec = toolSpecs.get(finalI); - Tool tool = createTool(toolSpec); + Tool tool = createTool(toolFactories, params, toolSpec); if (finalI < toolSpecs.size()) { tool.run(getToolExecuteParams(toolSpec, params), nextStepListener); } @@ -342,6 +325,7 @@ private void saveMessage( String parentInteractionId, String toolName, AtomicInteger traceNumber, + boolean traceDisabled, ActionListener listener ) { ConversationIndexMessage finalMessage = ConversationIndexMessage @@ -352,7 +336,11 @@ private void saveMessage( .finalAnswer(true) .sessionId(memoryId) .build(); - memory.save(finalMessage, parentInteractionId, traceNumber.addAndGet(1), toolName, listener); + if (traceDisabled) { + listener.onResponse(true); + } else { + memory.save(finalMessage, parentInteractionId, traceNumber.addAndGet(1), toolName, listener); + } } @VisibleForTesting @@ -397,26 +385,6 @@ String parseResponse(Object output) throws IOException { } } - @VisibleForTesting - Tool createTool(MLToolSpec toolSpec) { - Map toolParams = new HashMap<>(); - if (toolSpec.getParameters() != null) { - toolParams.putAll(toolSpec.getParameters()); - } - if (!toolFactories.containsKey(toolSpec.getType())) { - throw new IllegalArgumentException("Tool not found: " + toolSpec.getType()); - } - Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams); - if (toolSpec.getName() != null) { - tool.setName(toolSpec.getName()); - } - - if (toolSpec.getDescription() != null) { - tool.setDescription(toolSpec.getDescription()); - } - return tool; - } - @VisibleForTesting Map getToolExecuteParams(MLToolSpec toolSpec, Map params) { Map executeParams = new HashMap<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 674a1237c6..0f6bba2931 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.agent; import static org.apache.commons.text.StringEscapeUtils.escapeJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import java.io.IOException; import java.security.AccessController; @@ -73,7 +74,7 @@ public MLFlowAgentRunner( @Override public void run(MLAgent mlAgent, Map params, ActionListener listener) { - List toolSpecs = mlAgent.getTools(); + List toolSpecs = getMlToolSpecs(mlAgent, params); StepListener firstStepListener = null; Tool firstTool = null; List flowAgentOutput = new ArrayList<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index bbeee117be..d32d2b3edd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -7,8 +7,8 @@ public class PromptTemplate { public static final String PROMPT_FORMAT_INSTRUCTION = "Human:RESPONSE FORMAT INSTRUCTIONS\n----------------------------\nOutput a JSON markdown code snippet containing a valid JSON object in one of two formats:\n\n**Option 1:**\nUse this if you want the human to use a tool.\nMarkdown code snippet formatted in the following schema:\n\n```json\n{\n \"thought\": string, // think about what to do next: if you know the final answer just return \"Now I know the final answer\", otherwise suggest which tool to use.\n \"action\": string, // The action to take. Must be one of these tool names: [${parameters.tool_names}], do NOT use any other name for action except the tool names.\n \"action_input\": string // The input to the action. May be a stringified object.\n}\n```\n\n**Option #2:**\nUse this if you want to respond directly and conversationally to the human. Markdown code snippet formatted in the following schema:\n\n```json\n{\n \"thought\": \"Now I know the final answer\",\n \"final_answer\": string, // summarize and return the final answer in a sentence with details, don't just return a number or a word.\n}\n```"; public static final String PROMPT_TEMPLATE_SUFFIX = - "Human:TOOLS\n------\nAssistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are:\n\n${parameters.tool_descriptions}\n\n${parameters.prompt.format_instruction}\n\n${parameters.chat_history}\n\n\nHuman:USER'S INPUT\n--------------------\nHere is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):\n${parameters.question}\n\n${parameters.scratchpad}"; - public static final String PROMPT_TEMPLATE = "\n\nHuman:${parameters.prompt.prefix}\n\n${parameters.prompt.suffix}\n\nAssistant:"; + "Human:TOOLS\n------\nAssistant can ask Human to use tools to look up information that may be helpful in answering the users original question. The tool response will be listed in \"TOOL RESPONSE of {tool name}:\". If TOOL RESPONSE is enough to answer human's question, Assistant should avoid rerun the same tool. \nAssistant should NEVER suggest run a tool with same input if it's already in TOOL RESPONSE. \nThe tools the human can use are:\n\n${parameters.tool_descriptions}\n\n${parameters.chat_history}\n\n${parameters.prompt.format_instruction}\n\n\nHuman:USER'S INPUT\n--------------------\nHere is the user's input :\n${parameters.question}\n\n${parameters.scratchpad}"; + public static final String PROMPT_TEMPLATE = "\n\nHuman:${parameters.prompt.prefix}\n\n${parameters.prompt.suffix}\n\nHuman: follow RESPONSE FORMAT INSTRUCTIONS\n\nAssistant:"; public static final String PROMPT_TEMPLATE_TOOL_RESPONSE = - "TOOL RESPONSE: \n---------------------\n${parameters.observation}\n\nUSER'S INPUT\n--------------------\n\nOkay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else."; + "Assistant:\n---------------------\n${parameters.llm_tool_selection_response}\n\nHuman: TOOL RESPONSE of ${parameters.tool_name}: \n---------------------\nTool input:\n${parameters.tool_input}\n\nTool output:\n${parameters.observation}\n\n"; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 0e8169ac64..f221ec304b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -117,6 +117,9 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); } String modelResponse = responseBuilder.toString(); + log.info("############################################################ response"); + System.out.println(modelResponse); + log.info("############################################################ "); if (statusCode < 200 || statusCode >= 300) { throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index 6219a4b3a6..1dc76711ec 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -139,7 +139,7 @@ public void createInteraction( * @param actionListener get all the final interactions that are not traces */ public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener> actionListener) { - Preconditions.checkArgument(lastNInteraction > 0, "lastN must be at least 1."); + Preconditions.checkArgument(lastNInteraction > 0, "History message size must be at least 1."); log.debug("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index c26c650fe6..f6a37c7f60 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -14,6 +14,7 @@ import java.util.Locale; import java.util.Map; import java.util.Spliterators; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -113,10 +114,10 @@ public void run(Map parameters, ActionListener listener) StringBuilder sb = new StringBuilder( // Currently using c.value which is short header matching _cat/indices // May prefer to use c.attr.get("desc") for full description - table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining("\t", "", "\n")) + table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining(",", "", "\n")) ); for (List row : table.getRows()) { - sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining("\t", "", "\n"))); + sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining(",", "", "\n"))); } @SuppressWarnings("unchecked") T response = (T) sb.toString(); @@ -359,16 +360,17 @@ private Table getTableWithHeader() { table.startHeaders(); // First param is cell.value which is currently returned // Second param is cell.attr we may want to use attr.desc in the future + table.addCell("row", "alias:r;desc:row number"); table.addCell("health", "alias:h;desc:current health status"); table.addCell("status", "alias:s;desc:open/close status"); table.addCell("index", "alias:i,idx;desc:index name"); table.addCell("uuid", "alias:id,uuid;desc:index uuid"); - table.addCell("pri", "alias:p,shards.primary,shardsPrimary;text-align:right;desc:number of primary shards"); - table.addCell("rep", "alias:r,shards.replica,shardsReplica;text-align:right;desc:number of replica shards"); - table.addCell("docs.count", "alias:dc,docsCount;text-align:right;desc:available docs"); - table.addCell("docs.deleted", "alias:dd,docsDeleted;text-align:right;desc:deleted docs"); - table.addCell("store.size", "sibling:pri;alias:ss,storeSize;text-align:right;desc:store size of primaries & replicas"); - table.addCell("pri.store.size", "text-align:right;desc:store size of primaries"); + table.addCell("pri(number of primary shards)", "alias:p,shards.primary,shardsPrimary;text-align:right;desc:number of primary shards"); + table.addCell("rep(number of replica shards)", "alias:r,shards.replica,shardsReplica;text-align:right;desc:number of replica shards"); + table.addCell("docs.count(number of available documents)", "alias:dc,docsCount;text-align:right;desc:available docs"); + table.addCell("docs.deleted(number of deleted documents)", "alias:dd,docsDeleted;text-align:right;desc:deleted docs"); + table.addCell("store.size(store size of primary and replica shards)", "sibling:pri;alias:ss,storeSize;text-align:right;desc:store size of primaries & replicas"); + table.addCell("pri.store.size(store size of primary shards)", "text-align:right;desc:store size of primaries"); // Above includes all the default fields for cat indices. See RestIndicesAction for a lot more that could be included. table.endHeaders(); return table; @@ -381,7 +383,7 @@ private Table buildTable( final Map indicesMetadatas ) { final Table table = getTableWithHeader(); - + AtomicInteger rowNum = new AtomicInteger(0); indicesSettings.forEach((indexName, settings) -> { if (!indicesMetadatas.containsKey(indexName)) { // the index exists in the Get Indices response but is not present in the cluster state: @@ -414,6 +416,7 @@ private Table buildTable( totalStats = indexStats.getTotal(); } table.startRow(); + table.addCell(rowNum.addAndGet(1)); table.addCell(health); table.addCell(indexState.toString().toLowerCase(Locale.ROOT)); table.addCell(indexName); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index b348bcc228..fe4f4cd5a3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -8,18 +8,19 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.when; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES; -import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_PREFIX; -import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.PROMPT_SUFFIX; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -265,4 +266,40 @@ public void testExtractModelResponseJsonWithValidModelOutput() { String responseJson = AgentUtils.extractModelResponseJson(text); assertEquals("{\"thought\":\"use CatIndexTool to get index first\",\"action\":\"CatIndexTool\"}", responseJson); } + + @Test + public void test() { + String text = "---------------------\n{\n \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}"; + String result = AgentUtils.extractModelResponseJson(text); + String expectedResult = "{\n" + + " \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n" + + " \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n" + + "}"; + System.out.println(result); + Assert.assertEquals(expectedResult, result); + } + + @Test + public void test2() { + String text = "---------------------```json\n{\n \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n}\n```"; + String result = AgentUtils.extractModelResponseJson(text); + String expectedResult = "{\n" + + " \"thought\": \"Unfortunately the tools did not provide the weather forecast directly. Let me check online sources:\",\n" + + " \"final_answer\": \"After checking online weather forecasts, it looks like tomorrow will be sunny with a high of 25 degrees Celsius.\"\n" + + "}"; + System.out.println(result); + Assert.assertEquals(expectedResult, result); + } + @Test + public void test3() { + String text = "---------------------\n{\n \"thought\": \"Let me search our index to find population projections\", \n \"action\": \"VectorDBTool\",\n \"action_input\": \"Seattle population projection 2023\"\n}"; + String result = AgentUtils.extractModelResponseJson(text); + String expectedResult = "{\n" + + " \"thought\": \"Let me search our index to find population projections\", \n" + + " \"action\": \"VectorDBTool\",\n" + + " \"action_input\": \"Seattle population projection 2023\"\n" + + "}"; + System.out.println(result); + Assert.assertEquals(expectedResult, result); + } }