Skip to content

Commit

Permalink
enhance parsing model response function for more edge cases
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 16, 2024
1 parent c38854f commit 91e3943
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@

package org.opensearch.ml.engine.algorithms.agent;

import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MESSAGE_HISTORY_LIMIT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.ACTION_INPUT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CONTEXT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.EXAMPLES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.FINAL_ANSWER;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.OS_INDICES;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.THOUGHT_RESPONSE;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;
Expand All @@ -19,10 +27,13 @@
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -33,7 +44,11 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.utils.StringUtils;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class AgentUtils {

public static final String SELECTED_TOOLS = "selected_tools";
Expand Down Expand Up @@ -167,23 +182,161 @@ public static String extractModelResponseJson(String text) {
return extractModelResponseJson(text, null);
}

public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
Pattern jsonBlockPattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher jsonBlockMatcher = jsonBlockPattern.matcher(text);

if (jsonBlockMatcher.find()) {
return jsonBlockMatcher.group(1);
public static Map<String, String> parseLLMOutput(
ModelTensorOutput tmpModelTensorOutput,
List<String> llmResponsePatterns,
Set<String> inputTools
) {
Map<String, String> modelOutput = new HashMap<>();
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {
String llmReasoningResponse = (String) dataAsMap.get("response");
String thoughtResponse = null;
try {
thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns);
modelOutput.put(THOUGHT_RESPONSE, thoughtResponse);
} catch (IllegalArgumentException e) {
modelOutput.put(THOUGHT_RESPONSE, llmReasoningResponse);
thoughtResponse = llmReasoningResponse;
}
parseThoughtResponse(modelOutput, thoughtResponse);
} else {
String matchedPart = findMatchedPart(text, MODEL_RESPONSE_PATTERNS);
if (matchedPart == null && llmResponsePatterns != null) {
// If no match is found, try additional patterns if provided
matchedPart = findMatchedPart(text, llmResponsePatterns);
extractParams(modelOutput, dataAsMap, THOUGHT);
extractParams(modelOutput, dataAsMap, ACTION);
extractParams(modelOutput, dataAsMap, ACTION_INPUT);
extractParams(modelOutput, dataAsMap, FINAL_ANSWER);
try {
modelOutput.put(THOUGHT_RESPONSE, StringUtils.toJson(dataAsMap));
} catch (Exception e) {
log.warn("Failed to parse model response", e);
}
}
String action = modelOutput.get(ACTION);
if (action != null) {
modelOutput.put(ACTION, getMatchingTool(inputTools, action));
}
if (!modelOutput.containsKey(ACTION) && !modelOutput.containsKey(FINAL_ANSWER)) {
modelOutput.put(FINAL_ANSWER, modelOutput.get(THOUGHT_RESPONSE));
}
return modelOutput;
}

/**
 * Finds the first tool whose name occurs, ignoring case, inside the given action text.
 *
 * @param tools  candidate tool names, examined in the collection's iteration order
 * @param action the action text produced by the model
 * @return the first tool name contained in {@code action} (case-insensitive, using
 *         {@link Locale#ROOT}), or {@code null} when no tool name is contained in it
 */
public static String getMatchingTool(Collection<String> tools, String action) {
    return tools
        .stream()
        .filter(candidate -> action.toLowerCase(Locale.ROOT).contains(candidate.toLowerCase(Locale.ROOT)))
        .findFirst()
        .orElse(null);
}

/**
 * Copies one entry from the raw model output into {@code modelOutput}, serialized with {@code toJson}.
 * Nothing is written when {@code dataAsMap} has no entry for {@code paramName}.
 *
 * @param modelOutput destination map that receives the serialized value
 * @param dataAsMap   raw model output to read from
 * @param paramName   key to look up in {@code dataAsMap} and to store under in {@code modelOutput}
 */
public static void extractParams(Map<String, String> modelOutput, Map<String, ?> dataAsMap, String paramName) {
    if (!dataAsMap.containsKey(paramName)) {
        return;
    }
    Object rawValue = dataAsMap.get(paramName);
    modelOutput.put(paramName, toJson(rawValue));
}

/**
 * Extracts the structured part of an LLM response.
 * <p>
 * If the text contains a {@code ```json} fence, everything up to and including the first such
 * opening fence is dropped, and the text is then cut at the last remaining {@code ```} (if any).
 * The trimmed remainder is returned as-is when {@code isJson} accepts it. Otherwise the remainder
 * is matched against {@code llmResponsePatterns} when that list is non-null, or against
 * {@code MODEL_RESPONSE_PATTERNS} when it is null; the two lists are never combined.
 *
 * @param text                raw model response
 * @param llmResponsePatterns caller-supplied patterns; when non-null they replace the default patterns
 * @return the JSON text or the part matched by a pattern
 * @throws IllegalArgumentException if the text is not JSON and no pattern matches
 */
public static String extractModelResponseJson(String text, List<String> llmResponsePatterns) {
    final String openingFence = "```json";
    final String fence = "```";

    String candidate = text;
    int openingIdx = candidate.indexOf(openingFence);
    if (openingIdx >= 0) {
        candidate = candidate.substring(openingIdx + openingFence.length());
        int closingIdx = candidate.lastIndexOf(fence);
        if (closingIdx >= 0) {
            candidate = candidate.substring(0, closingIdx);
        }
    }
    candidate = candidate.trim();
    if (isJson(candidate)) {
        return candidate;
    }

    // Custom patterns, when supplied, are used instead of the defaults.
    List<String> patterns = llmResponsePatterns != null ? llmResponsePatterns : MODEL_RESPONSE_PATTERNS;
    String matched = findMatchedPart(candidate, patterns);
    if (matched == null) {
        throw new IllegalArgumentException("Model output is invalid");
    }
    return matched;
}

/**
 * Populates {@code modelOutput} from the extracted thought response.
 * <p>
 * A {@code null} response is ignored. When {@code isJson} accepts the response, it is parsed with
 * {@code gson} and the result of {@code getParameterMap} is merged into {@code modelOutput}.
 * Otherwise the thought, action, action input and final answer are pulled out with the regex-based
 * {@code extract*} helpers, and each non-null result is stored under its key.
 *
 * @param modelOutput     map to populate
 * @param thoughtResponse extracted model response, possibly invalid JSON or {@code null}
 */
public static void parseThoughtResponse(Map<String, String> modelOutput, String thoughtResponse) {
    if (thoughtResponse == null) {
        return;
    }
    if (isJson(thoughtResponse)) {
        modelOutput.putAll(getParameterMap(gson.fromJson(thoughtResponse, Map.class)));
        return;
    }
    // Sometimes the LLM returns an invalid JSON response; fall back to field-by-field extraction.
    String[][] extractedFields = {
        { THOUGHT, extractThought(thoughtResponse) },
        { ACTION, extractAction(thoughtResponse) },
        { ACTION_INPUT, extractActionInput(thoughtResponse) },
        { FINAL_ANSWER, extractFinalAnswer(thoughtResponse) } };
    for (String[] field : extractedFields) {
        if (field[1] != null) {
            modelOutput.put(field[0], field[1]);
        }
    }
}

/**
 * Extracts the value that follows a {@code "final_answer"} key in a possibly malformed JSON text.
 * <p>
 * The capture starts right after the opening quote of the value and runs to the end of the text,
 * so any trailing characters (such as a closing quote or brace) are part of the result.
 *
 * @param text model response text
 * @return the captured remainder, or {@code null} if the key is absent or the pattern does not match
 */
public static String extractFinalAnswer(String text) {
    if (!text.contains("\"final_answer\"")) {
        return null;
    }
    Matcher answerMatcher = Pattern.compile("\"final_answer\"\\s*:\\s*\"(.*?)$", Pattern.DOTALL).matcher(text);
    return answerMatcher.find() ? answerMatcher.group(1) : null;
}

/**
 * Extracts the value of the {@code "thought"} key from a possibly malformed JSON text.
 * <p>
 * The value is taken to end at the first closing quote that is followed by a comma and one of the
 * keys {@code "final_answer"}, {@code "action_input"} or {@code "action"}. Requiring a full key
 * (rather than a single character, as a bracketed character class would) keeps thoughts that
 * themselves contain a quote followed by a comma from being cut short.
 *
 * @param text model response text
 * @return the thought text, or {@code null} if the key is absent or is not followed by one of the
 *         keys listed above
 */
public static String extractThought(String text) {
    if (!text.contains("\"thought\"")) {
        return null;
    }
    // Alternation over whole key names; "[...]" here would match just one character of the set.
    String pattern = "\"thought\"\\s*:\\s*\"(.*?)\"\\s*,\\s*\"(?:final_answer|action_input|action)\"";
    Matcher thoughtMatcher = Pattern.compile(pattern, Pattern.DOTALL).matcher(text);
    return thoughtMatcher.find() ? thoughtMatcher.group(1) : null;
}

/**
 * Extracts the text that follows an {@code "action"} key in a possibly malformed JSON text.
 * <p>
 * The capture starts after the opening quote of the value and stops at the next
 * {@code "action_input"} key or, failing that, at the end of the text. Characters between the
 * action value and that stop point (closing quote, comma, whitespace, brace) are included.
 *
 * @param text model response text
 * @return the captured text, or {@code null} if the key is absent or the pattern does not match
 */
public static String extractAction(String text) {
    if (!text.contains("\"action\"")) {
        return null;
    }
    Matcher actionMatcher = Pattern.compile("\"action\"\\s*:\\s*\"(.*?)(?:\"action_input\"|$)", Pattern.DOTALL).matcher(text);
    return actionMatcher.find() ? actionMatcher.group(1) : null;
}

/**
 * Extracts the value of the {@code "action_input"} key from a possibly malformed JSON text and
 * replaces every backslash-escaped quote in it with a plain quote.
 *
 * @param text model response text
 * @return the unescaped action input, or {@code null} if the key is absent or the pattern does not match
 */
public static String extractActionInput(String text) {
    if (!text.contains("\"action_input\"")) {
        return null;
    }
    // NOTE(review): in this regex "\"" is just a literal quote, so the group accepts any character and
    // the capture runs to the last double quote in the text — confirm that this is intended.
    String pattern = "\"action_input\"\\s*:\\s*\"((?:[^\\\"]|\\\")*)\"";
    Matcher inputMatcher = Pattern.compile(pattern, Pattern.DOTALL).matcher(text);
    if (!inputMatcher.find()) {
        return null;
    }
    return inputMatcher.group(1).replace("\\\"", "\"");
}

public static String findMatchedPart(String text, List<String> patternList) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@

import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
Expand All @@ -19,11 +16,11 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolNames;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX;

import java.security.PrivilegedActionException;
Expand Down Expand Up @@ -62,7 +59,6 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.engine.tools.MLModelTool;
Expand Down Expand Up @@ -196,7 +192,7 @@ private void runReAct(
boolean traceDisabled = parameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(parameters.get(DISABLE_TRACE));

Map<String, String> tmpParameters = constructLLMParams(llm, parameters);
String prompt = constructLLMPrompt(tools, parameters, inputTools, tmpParameters);
String prompt = constructLLMPrompt(tools, inputTools, tmpParameters);
tmpParameters.put(PROMPT, prompt);

List<ModelTensors> traceTensors = createModelTensors(sessionId, parentInteractionId);
Expand Down Expand Up @@ -235,7 +231,7 @@ private void runReAct(
MLTaskResponse llmResponse = (MLTaskResponse) output;
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
List<String> llmResponsePatterns = gson.fromJson(parameters.get("llm_response_pattern"), List.class);
Map<String, String> modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns);
Map<String, String> modelOutput = parseLLMOutput(tmpModelTensorOutput, llmResponsePatterns, tools.keySet());

String thought = String.valueOf(modelOutput.get(THOUGHT));
String action = String.valueOf(modelOutput.get(ACTION));
Expand Down Expand Up @@ -287,8 +283,6 @@ private void runReAct(
"LLM"
);

action = getMatchingTool(tools, action);

if (tools.containsKey(action) && inputTools.contains(action)) {
Map<String, String> toolParams = constructToolParams(
tools,
Expand Down Expand Up @@ -402,42 +396,6 @@ private void runReAct(
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
}

/**
 * Converts the first tensor of the LLM output into a flat map of reasoning fields.
 * <p>
 * When the tensor's data map holds only a {@code "response"} entry, that string is passed through
 * {@code extractModelResponseJson}; the result is stored as {@code THOUGHT_RESPONSE} and, if
 * {@code isJson} accepts it, its parsed fields are merged into the output. If extraction throws
 * {@link IllegalArgumentException}, the raw response is stored as both {@code THOUGHT_RESPONSE}
 * and {@code FINAL_ANSWER}. For any other data map, the thought, action, action input and final
 * answer entries are copied individually and the whole map is serialized as {@code THOUGHT_RESPONSE}.
 *
 * @param tmpModelTensorOutput LLM output; its first output and first tensor are read
 * @param llmResponsePatterns  patterns forwarded to {@code extractModelResponseJson}
 * @return map of parsed fields
 */
private static Map<String, String> parseLLMOutput(ModelTensorOutput tmpModelTensorOutput, List<String> llmResponsePatterns) {
    Map<String, String> parsed = new HashMap<>();
    Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();

    boolean responseOnly = dataAsMap.size() == 1 && dataAsMap.containsKey("response");
    if (!responseOnly) {
        for (String field : new String[] { THOUGHT, ACTION, ACTION_INPUT, FINAL_ANSWER }) {
            extractParams(parsed, dataAsMap, field);
        }
        try {
            parsed.put(THOUGHT_RESPONSE, StringUtils.toJson(dataAsMap));
        } catch (Exception e) {
            log.warn("Failed to parse model response", e);
        }
        return parsed;
    }

    String rawResponse = (String) dataAsMap.get("response");
    // Stays null when extraction fails; it is still handed to isJson below.
    String extracted = null;
    try {
        extracted = extractModelResponseJson(rawResponse, llmResponsePatterns);
        parsed.put(THOUGHT_RESPONSE, extracted);
    } catch (IllegalArgumentException e) {
        parsed.put(THOUGHT_RESPONSE, rawResponse);
        parsed.put(FINAL_ANSWER, rawResponse);
    }
    if (isJson(extracted)) {
        parsed.putAll(getParameterMap(gson.fromJson(extracted, Map.class)));
    }
    return parsed;
}

/**
 * Copies one entry from the raw model output into {@code modelOutput}, serialized with {@code toJson}.
 * Nothing is written when {@code dataAsMap} has no entry for {@code paramName}.
 *
 * @param modelOutput destination map that receives the serialized value
 * @param dataAsMap   raw model output to read from
 * @param paramName   key to look up in {@code dataAsMap} and to store under in {@code modelOutput}
 */
private static void extractParams(Map<String, String> modelOutput, Map<String, ?> dataAsMap, String paramName) {
    if (!dataAsMap.containsKey(paramName)) {
        return;
    }
    Object rawValue = dataAsMap.get(paramName);
    modelOutput.put(paramName, toJson(rawValue));
}

private static List<ModelTensors> createFinalAnswerTensors(List<ModelTensors> sessionId, List<ModelTensor> lastThought) {
List<ModelTensors> finalModelTensors = sessionId;
finalModelTensors.add(ModelTensors.builder().mlModelTensors(lastThought).build());
Expand Down Expand Up @@ -550,16 +508,6 @@ private static Map<String, String> constructToolParams(
return toolParams;
}

/**
 * Resolves a model-produced action name to a registered tool key.
 * <p>
 * Each key of {@code tools} is compared against {@code name} case-insensitively; the comparison
 * lower-cases both sides with {@link java.util.Locale#ROOT} so the result does not depend on the
 * JVM's default locale. If several keys are contained in {@code name}, the last one in the key
 * set's iteration order is returned.
 *
 * @param tools registered tools, keyed by tool name
 * @param name  action name produced by the model
 * @return the matching tool key, or {@code name} unchanged when no key is contained in it
 */
private static String getMatchingTool(Map<String, Tool> tools, String name) {
    String toolName = name;
    for (String key : tools.keySet()) {
        if (name.toLowerCase(java.util.Locale.ROOT).contains(key.toLowerCase(java.util.Locale.ROOT))) {
            toolName = key;
        }
    }
    return toolName;
}

private static void saveTraceData(
ConversationIndexMemory conversationIndexMemory,
String memory,
Expand Down Expand Up @@ -655,21 +603,16 @@ private static List<ModelTensors> createModelTensors(String sessionId, String pa
return cotModelTensors;
}

private static String constructLLMPrompt(
Map<String, Tool> tools,
Map<String, String> parameters,
List<String> inputTools,
Map<String, String> tmpParameters
) {
String prompt = parameters.getOrDefault(PROMPT, PromptTemplate.PROMPT_TEMPLATE);
private static String constructLLMPrompt(Map<String, Tool> tools, List<String> inputTools, Map<String, String> tmpParameters) {
String prompt = tmpParameters.getOrDefault(PROMPT, PromptTemplate.PROMPT_TEMPLATE);
StringSubstitutor promptSubstitutor = new StringSubstitutor(tmpParameters, "${parameters.", "}");
prompt = promptSubstitutor.replace(prompt);
prompt = AgentUtils.addPrefixSuffixToPrompt(parameters, prompt);
prompt = AgentUtils.addToolsToPrompt(tools, parameters, inputTools, prompt);
prompt = AgentUtils.addIndicesToPrompt(parameters, prompt);
prompt = AgentUtils.addExamplesToPrompt(parameters, prompt);
prompt = AgentUtils.addChatHistoryToPrompt(parameters, prompt);
prompt = AgentUtils.addContextToPrompt(parameters, prompt);
prompt = AgentUtils.addPrefixSuffixToPrompt(tmpParameters, prompt);
prompt = AgentUtils.addToolsToPrompt(tools, tmpParameters, inputTools, prompt);
prompt = AgentUtils.addIndicesToPrompt(tmpParameters, prompt);
prompt = AgentUtils.addExamplesToPrompt(tmpParameters, prompt);
prompt = AgentUtils.addChatHistoryToPrompt(tmpParameters, prompt);
prompt = AgentUtils.addContextToPrompt(tmpParameters, prompt);
return prompt;
}

Expand Down
Loading

0 comments on commit 91e3943

Please sign in to comment.