Skip to content

Commit

Permalink
test3
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 12, 2024
1 parent 89f9746 commit d8bc43a
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ public static String extractModelResponseJson(String text) {
if (matcher.find()) {
return matcher.group(1);
} else {
Pattern pattern2 = Pattern.compile("\\{(?:[^{}]|\\{(?:[^{}]|\\{[^{}]*\\})*\\})*\\}");
Matcher matcher2 = pattern2.matcher(text);
// Find the JSON content
if (matcher2.find()) {
return matcher2.group();
}
throw new IllegalArgumentException("Model output is invalid");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PROMPT_TEMPLATE_ASK_AGAIN;

import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -32,8 +31,6 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -178,6 +175,7 @@ private void runReAct(
String question = parameters.get(MLAgentExecutor.QUESTION);
String parentInteractionId = parameters.get(MLAgentExecutor.PARENT_INTERACTION_ID);
boolean verbose = parameters.containsKey("verbose") && Boolean.parseBoolean(parameters.get("verbose"));
boolean traceDisabled = parameters.containsKey("disable_trace") && Boolean.parseBoolean(parameters.get("disable_trace"));
Map<String, String> tmpParameters = new HashMap<>();
if (llm.getParameters() != null) {
tmpParameters.putAll(llm.getParameters());
Expand Down Expand Up @@ -245,6 +243,7 @@ private void runReAct(
AtomicReference<String> lastThought = new AtomicReference<>();
AtomicReference<String> lastAction = new AtomicReference<>();
AtomicReference<String> lastActionInput = new AtomicReference<>();
AtomicReference<String> lastToolSelectionResponse = new AtomicReference<>();
Map<String, Object> additionalInfo = new ConcurrentHashMap<>();

StepListener<?> lastStepListener = null;
Expand All @@ -265,7 +264,6 @@ private void runReAct(




//////////////////////////////////////////////////////////////////////////////////////////
// start
//////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -275,31 +273,16 @@ private void runReAct(
ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput();
Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
String thoughtResponse = "";
String llmReasoningResponse = null;
if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {
llmReasoningResponse = (String) dataAsMap.get("response");
String llmReasoningResponse = (String) dataAsMap.get("response");
thoughtResponse = extractModelResponseJson(llmReasoningResponse);
dataAsMap = gson.fromJson(thoughtResponse, Map.class);
}
lastToolSelectionResponse.set(thoughtResponse);
String thought = String.valueOf(dataAsMap.get("thought"));
String action = String.valueOf(dataAsMap.get("action"));
String actionInput = parseInputFromLLMReturn(dataAsMap);
String finalAnswer = (String) dataAsMap.get("final_answer");
// if (!dataAsMap.containsKey("thought")) {//TODO: check if we can remove this if block
// String response = (String) dataAsMap.get("response");
// Pattern pattern = Pattern.compile("```json(.*?)```", Pattern.DOTALL);
// Matcher matcher = pattern.matcher(response);
// if (matcher.find()) {
// String jsonBlock = matcher.group(1);
// Map map = gson.fromJson(jsonBlock, Map.class);
// thought = String.valueOf(map.get("thought"));
// action = String.valueOf(map.get("action"));
// actionInput = parseInputFromLLMReturn(map);
// finalAnswer = (String) map.get("final_answer");
// } else {
// finalAnswer = response;
// }
// }

if (finalI == 0 && !thought.contains("Thought:")) {
sessionMsgAnswerBuilder.append("Thought: ");
Expand All @@ -311,78 +294,56 @@ private void runReAct(
modelTensors.addAll(tmpModelTensorOutput.getMlModelOutputs());

if (conversationIndexMemory != null && finalAnswer == null) {
ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(llmReasoningResponse).finalAnswer(false).sessionId(sessionId).build();
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null);
ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(thoughtResponse).finalAnswer(false).sessionId(sessionId).build();
if (!traceDisabled) {
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM");
}
}
if (finalAnswer != null) {
finalAnswer = finalAnswer.trim();
String finalAnswer2 = finalAnswer;
// Composite execution response and reply.
final ActionListener<Boolean> executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> {
cotModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections.singletonList(ModelTensor.builder().name("response").result(finalAnswer2).build())
)
.build()
);

List<ModelTensors> finalModelTensors = new ArrayList<>();
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
List
.of(
ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build(),
ModelTensor
.builder()
.name(MLAgentExecutor.PARENT_INTERACTION_ID)
.result(parentInteractionId)
.build()
)
)
.build()
);
finalModelTensors
.add(
ModelTensors
.builder()
.mlModelTensors(
Collections
.singletonList(
ModelTensor
.builder()
.name("response")
.dataAsMap(
ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)
)
.build()
)
)
.build()
);
getFinalAnswer.set(true);
if (verbose) {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(cotModelTensors).build());
} else {
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(finalModelTensors).build());
}
}, listener::onFailure));
// Sending execution response by internalListener is after the trace and answer saving.
final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(2, executionListener);
if (conversationIndexMemory != null) {
String finalAnswer1 = finalAnswer;
// Create final trace message.
ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(finalAnswer1).finalAnswer(true).sessionId(sessionId).build();
// Save last trace and update final answer in parallel.
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), null, ActionListener.<CreateInteractionResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
conversationIndexMemory.getMemoryManager().updateInteraction(parentInteractionId, ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo), ActionListener.<UpdateResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
String copyOfFinalAnswer = finalAnswer;
ActionListener saveTraceListener = ActionListener.wrap(r->{
conversationIndexMemory.getMemoryManager().updateInteraction(parentInteractionId,
Map.of(AI_RESPONSE_FIELD, copyOfFinalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo),
ActionListener.wrap(res ->{
returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, getFinalAnswer, additionalInfo, copyOfFinalAnswer);
}, e-> {
listener.onFailure(e);
}));
}, e-> {
listener.onFailure(e);
});
saveMessage(conversationIndexMemory, question, finalAnswer, sessionId, parentInteractionId, traceNumber, true, traceDisabled, saveTraceListener);
// // Create final trace message.
// ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(finalAnswer1).finalAnswer(true).sessionId(sessionId).build();
// // Save last trace and update final answer in parallel.
// if (!traceDisabled) {
// conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", ActionListener.<CreateInteractionResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
// }
// conversationIndexMemory.getMemoryManager().updateInteraction(parentInteractionId, ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo), ActionListener.<UpdateResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
} else {
returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, getFinalAnswer, additionalInfo, finalAnswer);
}
return;

// Composite execution response and reply.
// final ActionListener<Boolean> executionListener = ActionListener.notifyOnce(ActionListener.wrap(r -> {
// extractedTmep(sessionId, listener, parentInteractionId, verbose, cotModelTensors, getFinalAnswer, additionalInfo, finalAnswer2);
// }, listener::onFailure));
// // Sending execution response by internalListener is after the trace and answer saving.
// final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(traceDisabled? 1:2, executionListener);
// if (conversationIndexMemory != null) {
// String finalAnswer1 = finalAnswer;
// // Create final trace message.
// ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type(memory.getType()).question(question).response(finalAnswer1).finalAnswer(true).sessionId(sessionId).build();
// // Save last trace and update final answer in parallel.
// if (!traceDisabled) {
// conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", ActionListener.<CreateInteractionResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
// }
// conversationIndexMemory.getMemoryManager().updateInteraction(parentInteractionId, ImmutableMap.of(AI_RESPONSE_FIELD, finalAnswer1, ADDITIONAL_INFO_FIELD, additionalInfo), ActionListener.<UpdateResponse>wrap(groupedListener::onResponse, groupedListener::onFailure));
// }
// return;
}

lastAction.set(action);
Expand Down Expand Up @@ -456,15 +417,17 @@ private void runReAct(
modelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(ModelTensor.builder().dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + outputToOutputString(output))).build())).build());

String toolResponse = tmpParameters.get(TOOL_RESPONSE);
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(Map.of("observation", outputToOutputString(output), "tool_name", lastAction.get()), "${parameters.", "}");
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(Map.of("llm_tool_selection_response", /*"```json\\n" +*/ lastToolSelectionResponse.get() /*+ "\n```"*/, "tool_name", lastAction.get(),"tool_input", lastActionInput.get(),"observation", outputToOutputString(output)), "${parameters.", "}");
toolResponse = toolResponseSubstitutor.replace(toolResponse);
scratchpadBuilder.append(toolResponse).append("\n\n");
if (conversationIndexMemory != null) {
// String res = "Action: " + lastAction.get() + "\nAction Input: " + lastActionInput + "\nObservation: " + result;
ConversationIndexMessage msgTemp = ConversationIndexMessage.conversationIndexMessageBuilder().type("ReAct").question(lastActionInput.get()).response(outputToOutputString(output)).finalAnswer(false).sessionId(sessionId).build();
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), lastAction.get());
if (!traceDisabled) {
conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), lastAction.get());
}
}
StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder + PROMPT_TEMPLATE_ASK_AGAIN), "${parameters.", "}");
StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder /*+ PROMPT_TEMPLATE_ASK_AGAIN*/), "${parameters.", "}");
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());

Expand Down Expand Up @@ -502,6 +465,39 @@ private void runReAct(
client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener);
}

/**
 * Marks the run as finished and sends the final answer to {@code listener}.
 *
 * <p>The answer is always appended to {@code cotModelTensors} as a tensor named
 * {@code "response"}. The payload then depends on {@code verbose}: when true the
 * whole {@code cotModelTensors} list is returned; otherwise a compact two-entry
 * list is returned, holding the memory id / parent interaction id tensors followed
 * by a {@code "response"} tensor whose map carries the answer and {@code additionalInfo}.
 *
 * @param sessionId           value stored in the {@code MLAgentExecutor.MEMORY_ID} tensor
 * @param listener            receives the assembled {@code ModelTensorOutput}
 * @param parentInteractionId value stored in the {@code MLAgentExecutor.PARENT_INTERACTION_ID} tensor
 * @param verbose             selects the step-by-step output over the compact one
 * @param cotModelTensors     accumulated step tensors; mutated by appending the answer
 * @param getFinalAnswer      flag set to {@code true} just before the listener is notified
 * @param additionalInfo      extra data placed under {@code ADDITIONAL_INFO_FIELD}
 * @param finalAnswer2        the final answer text
 */
private static void returnFinalResponse(String sessionId, ActionListener<Object> listener, String parentInteractionId, boolean verbose, List<ModelTensors> cotModelTensors, AtomicBoolean getFinalAnswer, Map<String, Object> additionalInfo, String finalAnswer2) {
    // The step-by-step output always ends with the plain answer tensor.
    ModelTensor answerTensor = ModelTensor.builder().name("response").result(finalAnswer2).build();
    cotModelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(answerTensor)).build());

    // Compact output: identifiers first, then the answer with its additional info.
    ModelTensor memoryIdTensor = ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result(sessionId).build();
    ModelTensor parentIdTensor = ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result(parentInteractionId).build();
    ModelTensor responseTensor = ModelTensor
        .builder()
        .name("response")
        .dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo))
        .build();

    List<ModelTensors> finalModelTensors = new ArrayList<>();
    finalModelTensors.add(ModelTensors.builder().mlModelTensors(List.of(memoryIdTensor, parentIdTensor)).build());
    finalModelTensors.add(ModelTensors.builder().mlModelTensors(Collections.singletonList(responseTensor)).build());

    getFinalAnswer.set(true);

    List<ModelTensors> selectedOutputs = verbose ? cotModelTensors : finalModelTensors;
    listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(selectedOutputs).build());
}

/**
 * Saves one LLM trace message to conversation memory, or skips the save when
 * tracing is disabled.
 *
 * <p>When {@code traceDisabled} is true nothing is written and {@code listener} is
 * completed immediately with {@code true}. Otherwise a {@code ConversationIndexMessage}
 * is built and saved with the next trace number and the origin label {@code "LLM"};
 * {@code listener} is passed to {@code memory.save} for the result.
 *
 * @param memory              conversation memory to save into; only dereferenced when tracing is enabled
 * @param question            question recorded on the message
 * @param finalAnswer         text recorded as the message response
 * @param sessionId           session id recorded on the message
 * @param parentInteractionId interaction the trace is attached to
 * @param traceNumber         trace counter; incremented only when a trace is actually saved
 * @param isFinalAnswer       whether the message is flagged as the final answer
 * @param traceDisabled       when true, skip persistence and acknowledge immediately
 * @param listener            completion callback; left as a raw type because it receives a
 *                            {@code Boolean} on the skip path and is passed to
 *                            {@code memory.save} on the save path
 */
@SuppressWarnings({ "rawtypes", "unchecked" }) // the raw listener is part of the existing signature, see above
private void saveMessage(
    ConversationIndexMemory memory,
    String question,
    String finalAnswer,
    String sessionId,
    String parentInteractionId,
    AtomicInteger traceNumber,
    boolean isFinalAnswer,
    boolean traceDisabled,
    ActionListener listener
) {
    if (traceDisabled) {
        // Nothing to persist: do not build a message that would be discarded.
        listener.onResponse(true);
        return;
    }
    ConversationIndexMessage msgTemp = ConversationIndexMessage
        .conversationIndexMessageBuilder()
        .type(memory.getType())
        .question(question)
        .response(finalAnswer)
        .finalAnswer(isFinalAnswer)
        .sessionId(sessionId)
        .build();
    memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener);
}

private GroupedActionListener<ActionResponse> createGroupedListener(final int size, final ActionListener<Boolean> listener) {
return new GroupedActionListener<>(new ActionListener<Collection<ActionResponse>>() {
@Override
Expand Down
Loading

0 comments on commit d8bc43a

Please sign in to comment.