Skip to content

Commit

Permalink
tune pattern
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 12, 2024
1 parent d8bc43a commit d65eaa4
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,39 @@ public static String addContextToPrompt(Map<String, String> parameters, String p
return prompt;
}

public static List<String> MODEL_RESPONSE_PATTERNS = List.of(
"\\{\\s*\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?\\}",
"\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"action\"\\s*:\\s*\".*?\"\\s*,\\s*\"action_input\"\\s*:\\s*\".*?\"\\s*}",
"\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"final_answer\"\\s*:\\s*\".*?\"\\s*}"
);
public static String extractModelResponseJson(String text) {
Pattern pattern = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher matcher = pattern.matcher(text);
Pattern pattern1 = Pattern.compile("```json\\s*([\\s\\S]+?)\\s*```");
Matcher matcher1 = pattern1.matcher(text);

if (matcher.find()) {
return matcher.group(1);
if (matcher1.find()) {
return matcher1.group(1);
} else {
Pattern pattern2 = Pattern.compile("\\{(?:[^{}]|\\{(?:[^{}]|\\{[^{}]*\\})*\\})*\\}");
Matcher matcher2 = pattern2.matcher(text);
// Find the JSON content
if (matcher2.find()) {
return matcher2.group();
for (String p : MODEL_RESPONSE_PATTERNS) {
Pattern pattern = Pattern.compile(p);
Matcher matcher = pattern.matcher(text);
if (matcher.find()) {
return matcher.group();
}
}
//// Pattern pattern2 = Pattern.compile("\\{(?:[^{}]|\\{(?:[^{}]|\\{[^{}]*\\})*\\})*\\}");
// Pattern pattern2 = Pattern.compile("\\{\\s*\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?\\}");
// Pattern pattern3 = Pattern.compile("\\{\\s*\"thought\"\\s*:\\s*\".*?\"\\s*,\\s*\"final_answer\"\\s*:\\s*\".*?\"\\s*}");
//
//// Pattern pattern2 = Pattern.compile("\\{\\s*(\"thought\":.*?\\s*,\\s*\"action\":.*?\\s*,\\s*\"action_input\":.*?|\"thought\":.*?\\s*,\\s*\"final_answer\":.*?)\\}");
// Matcher matcher2 = pattern2.matcher(text);
// Matcher matcher3 = pattern3.matcher(text);
// // Find the JSON content
// if (matcher2.find()) {
// return matcher2.group();
// }
// if (matcher3.find()) {
// return matcher3.group();
// }
throw new IllegalArgumentException("Model output is invalid");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.RESPONSE_FORMAT_INSTRUCTION;
Expand All @@ -19,6 +20,8 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -272,33 +275,35 @@ private void runReAct(
MLTaskResponse llmResponse = (MLTaskResponse) output;
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
String thoughtResponse = "";
String thoughtResponse = null;
String finalAnswer = null;
if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {
String llmReasoningResponse = (String) dataAsMap.get("response");
thoughtResponse = extractModelResponseJson(llmReasoningResponse);
dataAsMap = gson.fromJson(thoughtResponse, Map.class);
try {
thoughtResponse = extractModelResponseJson(llmReasoningResponse);
} catch (IllegalArgumentException e) {
thoughtResponse = llmReasoningResponse;
finalAnswer = llmReasoningResponse;
System.out.println("0000000000 ylwudddebug1: get final answer directly : " + finalAnswer);
}
if (isJson(thoughtResponse)) {
dataAsMap = gson.fromJson(thoughtResponse, Map.class);
}
} else {
try {
Map<String, ?> finalDataAsMap = dataAsMap;
thoughtResponse = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(finalDataAsMap));
} catch (Exception e) {
e.printStackTrace();
}
}
lastToolSelectionResponse.set(thoughtResponse);
String thought = String.valueOf(dataAsMap.get("thought"));
String action = String.valueOf(dataAsMap.get("action"));
String actionInput = parseInputFromLLMReturn(dataAsMap);
String finalAnswer = (String) dataAsMap.get("final_answer");

if (finalI == 0 && !thought.contains("Thought:")) {
sessionMsgAnswerBuilder.append("Thought: ");
if (dataAsMap.containsKey("final_answer")) {
finalAnswer = String.valueOf(dataAsMap.get("final_answer"));
}
sessionMsgAnswerBuilder.append(thought);
lastThought.set(thought);
cotModelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build())).build());
// TODO: check if verbose
modelTensors.addAll(tmpModelTensorOutput.getMlModelOutputs());

if (conversationIndexMemory != null && finalAnswer == null) {
ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(thoughtResponse).finalAnswer(false).sessionId(sessionId).build();
if (!traceDisabled) {
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM");
}
}
if (finalAnswer != null) {
finalAnswer = finalAnswer.trim();
if (conversationIndexMemory != null) {
Expand All @@ -315,35 +320,29 @@ private void runReAct(
listener.onFailure(e);
});
saveMessage(conversationIndexMemory, question, finalAnswer, sessionId, parentInteractionId, traceNumber, true, traceDisabled, saveTraceListener);
// // Create final trace message.
// ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(finalAnswer1).finalAnswer(true).sessionId(sessionId).build();
// // Save last trace and update final answer in parallel.
// if (!traceDisabled) {
// conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", ActionListener.<CreateInteractionResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
// }
// conversationIndexMemory.getMemoryManager().updateInteraction(parentInteractionId, ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo), ActionListener.<UpdateResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
} else {
returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, getFinalAnswer, additionalInfo, finalAnswer);
}
return;
}

lastToolSelectionResponse.set(thoughtResponse);

if (finalI == 0 && !thought.contains("Thought:")) {
sessionMsgAnswerBuilder.append("Thought: ");
}
sessionMsgAnswerBuilder.append(thought);
lastThought.set(thought);

// Composite execution response and reply.
// final ActionListener<Boolean> executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> {
// extractedTmep(sessionId, listener, parentInteractionId, verbose, cotModelTensors, getFinalAnswer, additionalInfo, finalAnswer2);
// }, listener::onFailure));
// // Sending execution response by internalListener is after the trace and answer saving.
// final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(traceDisabled? 1:2, executionListener);
// if (conversationIndexMemory != null) {
// String finalAnswer1 = finalAnswer;
// // Create final trace message.
// ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(finalAnswer1).finalAnswer(true).sessionId(sessionId).build();
// // Save last trace and update final answer in parallel.
// if (!traceDisabled) {
// conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", ActionListener.<CreateInteractionResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
// }
// conversationIndexMemory.getMemoryManager().updateInteraction(parentInteractionId, ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo), ActionListener.<UpdateResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
// }
// return;
cotModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build())).build());
// TODO: check if verbose
modelTensors.addAll(tmpModelTensorOutput.getMlModelOutputs());

if (conversationIndexMemory != null) {
ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(thoughtResponse).finalAnswer(false).sessionId(sessionId).build();
if (!traceDisabled) {
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM");
}
}

lastAction.set(action);
Expand Down Expand Up @@ -466,11 +465,11 @@ private void runReAct(
}

private static void returnFinalResponse(String sessionId, ActionListener<Object> listener, String parentInteractionId, boolean verbose, List<ModelTensors> cotModelTensors, AtomicBoolean getFinalAnswer, Map<String, Object> additionalInfo, String finalAnswer2) {
cotModelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer2).build())).build());
cotModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name("response").result(finalAnswer2).build())).build());

List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(), ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build())).build());
finalModelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)).build())).build());
finalModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)).build())).build());
getFinalAnswer.set(true);
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ public class PromptTemplate {
"Human:RESPONSE FORMAT INSTRUCTIONS\n----------------------------\nOutput a JSON markdown code snippet containing a valid JSON object in one of two formats:\n\n**Option 1:**\nUse this if you want the human to use a tool.\nMarkdown code snippet formatted in the following schema:\n\n```json\n{\n \"thought\": string, // think about what to do next: if you know the final answer just return \"Now I know the final answer\", otherwise suggest which tool to use.\n \"action\": string, // The action to take. Must be one of these tool names: [${parameters.tool_names}], do NOT use any other name for action except the tool names.\n \"action_input\": string // The input to the action. May be a stringified object.\n}\n```\n\n**Option #2:**\nUse this if you want to respond directly and conversationally to the human. Markdown code snippet formatted in the following schema:\n\n```json\n{\n \"thought\": \"Now I know the final answer\",\n \"final_answer\": string, // summarize and return the final answer in a sentence with details, don't just return a number or a word.\n}\n```";
public static final String PROMPT_TEMPLATE_SUFFIX =
"Human:TOOLS\n------\nAssistant can ask Human to use tools to look up information that may be helpful in answering the users original question. The tool response will be listed in \"TOOL RESPONSE of {tool name}:\". If TOOL RESPONSE is enough to answer human's question, Assistant should avoid rerun the same tool. \nAssistant should NEVER suggest run a tool with same input if it's already in TOOL RESPONSE. \nThe tools the human can use are:\n\n${parameters.tool_descriptions}\n\n${parameters.chat_history}\n\n${parameters.prompt.format_instruction}\n\n\nHuman:USER'S INPUT\n--------------------\nHere is the user's input :\n${parameters.question}\n\n${parameters.scratchpad}";
public static final String PROMPT_TEMPLATE = "\n\nHuman:${parameters.prompt.prefix}\n\n${parameters.prompt.suffix}\n\nAssistant:";
public static final String PROMPT_TEMPLATE = "\n\nHuman:${parameters.prompt.prefix}\n\n${parameters.prompt.suffix}\n\nHuman: follow RESPONSE FORMAT INSTRUCTIONS\n\nAssistant:";
public static final String PROMPT_TEMPLATE_TOOL_RESPONSE =
"Assistant:\n---------------------\n${parameters.llm_tool_selection_response}\n\nHuman: TOOL RESPONSE of ${parameters.tool_name}: \n---------------------\nTool input:\n${parameters.tool_input}\n\nTool output:\n${parameters.observation}\n\n";
public static final String PROMPT_TEMPLATE_ASK_AGAIN = "\n\nUSER'S INPUT\n--------------------\n\nOkay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names.";
//public static final String s = "USER'S INPUT\n--------------------\n\nOkay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else.";
// Human: follow RESPONSE FORMAT INSTRUCTIONS
}
Loading

0 comments on commit d65eaa4

Please sign in to comment.