Skip to content

Commit a066434

Browse files
committed
Add request level parameters for system_prompt and user_instructions.
Signed-off-by: Austin Lee <[email protected]>
1 parent 7c7330d commit a066434

File tree

6 files changed

+380
-20
lines changed

6 files changed

+380
-20
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+12
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
171171
+ " \"generative_qa_parameters\": {\n"
172172
+ " \"llm_model\": \"%s\",\n"
173173
+ " \"llm_question\": \"%s\",\n"
174+
+ " \"system_prompt\": \"%s\",\n"
175+
+ " \"user_instructions\": \"%s\",\n"
174176
+ " \"context_size\": %d,\n"
175177
+ " \"message_size\": %d,\n"
176178
+ " \"timeout\": %d\n"
@@ -188,6 +190,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
188190
+ " \"llm_model\": \"%s\",\n"
189191
+ " \"llm_question\": \"%s\",\n"
190192
+ " \"memory_id\": \"%s\",\n"
193+
+ " \"system_prompt\": \"%s\",\n"
194+
+ " \"user_instructions\": \"%s\",\n"
191195
+ " \"context_size\": %d,\n"
192196
+ " \"message_size\": %d,\n"
193197
+ " \"timeout\": %d\n"
@@ -283,6 +287,8 @@ public void testBM25WithOpenAI() throws Exception {
283287
requestParameters.match = "president";
284288
requestParameters.llmModel = OPENAI_MODEL;
285289
requestParameters.llmQuestion = "who is lincoln";
290+
requestParameters.systemPrompt = "You are great at answering questions";
291+
requestParameters.userInstructions = "Follow my instructions as best you can";
286292
requestParameters.contextSize = 5;
287293
requestParameters.interactionSize = 5;
288294
requestParameters.timeout = 60;
@@ -502,6 +508,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
502508
requestParameters.match,
503509
requestParameters.llmModel,
504510
requestParameters.llmQuestion,
511+
requestParameters.systemPrompt,
512+
requestParameters.userInstructions,
505513
requestParameters.contextSize,
506514
requestParameters.interactionSize,
507515
requestParameters.timeout
@@ -516,6 +524,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
516524
requestParameters.llmModel,
517525
requestParameters.llmQuestion,
518526
requestParameters.conversationId,
527+
requestParameters.systemPrompt,
528+
requestParameters.userInstructions,
519529
requestParameters.contextSize,
520530
requestParameters.interactionSize,
521531
requestParameters.timeout
@@ -556,6 +566,8 @@ static class SearchRequestParameters {
556566
String match;
557567
String llmModel;
558568
String llmQuestion;
569+
String systemPrompt;
570+
String userInstructions;
559571
int contextSize;
560572
int interactionSize;
561573
int timeout;

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+10
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,16 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
143143
}
144144
List<String> searchResults = getSearchResults(response, topN);
145145

146+
// See if the prompt is being overridden at the request level.
147+
String effectiveSystemPrompt = systemPrompt;
148+
String effectiveUserInstructions = userInstructions;
149+
if (params.getSystemPrompt() != null) {
150+
effectiveSystemPrompt = params.getSystemPrompt();
151+
}
152+
if (params.getUserInstructions() != null) {
153+
effectiveUserInstructions = params.getUserInstructions();
154+
}
155+
146156
start = Instant.now();
147157
try {
148158
ChatCompletionOutput output = llm

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+27
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import org.opensearch.core.xcontent.ToXContentObject;
3030
import org.opensearch.core.xcontent.XContentBuilder;
3131
import org.opensearch.core.xcontent.XContentParser;
32+
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
3233

3334
import com.google.common.base.Preconditions;
3435

@@ -70,13 +71,19 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
7071
// from a remote inference endpoint before timing out the request.
7172
private static final ParseField TIMEOUT = new ParseField("timeout");
7273

74+
private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);
75+
76+
private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);
77+
7378
public static final int SIZE_NULL_VALUE = -1;
7479

7580
static {
7681
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
7782
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
7883
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
7984
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
85+
PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
86+
PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
8087
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
8188
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
8289
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
@@ -106,10 +113,20 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
106113
@Getter
107114
private Integer timeout;
108115

116+
@Setter
117+
@Getter
118+
private String systemPrompt;
119+
120+
@Setter
121+
@Getter
122+
private String userInstructions;
123+
109124
public GenerativeQAParameters(
110125
String conversationId,
111126
String llmModel,
112127
String llmQuestion,
128+
String systemPrompt,
129+
String userInstructions,
113130
Integer contextSize,
114131
Integer interactionSize,
115132
Integer timeout
@@ -121,6 +138,8 @@ public GenerativeQAParameters(
121138
// for question rewriting.
122139
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
123140
this.llmQuestion = llmQuestion;
141+
this.systemPrompt = systemPrompt;
142+
this.userInstructions = userInstructions;
124143
this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize;
125144
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
126145
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
@@ -130,6 +149,8 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
130149
this.conversationId = input.readOptionalString();
131150
this.llmModel = input.readOptionalString();
132151
this.llmQuestion = input.readString();
152+
this.systemPrompt = input.readOptionalString();
153+
this.userInstructions = input.readOptionalString();
133154
this.contextSize = input.readInt();
134155
this.interactionSize = input.readInt();
135156
this.timeout = input.readInt();
@@ -141,6 +162,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
141162
.field(CONVERSATION_ID.getPreferredName(), this.conversationId)
142163
.field(LLM_MODEL.getPreferredName(), this.llmModel)
143164
.field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
165+
.field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt)
166+
.field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
144167
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
145168
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
146169
.field(TIMEOUT.getPreferredName(), this.timeout);
@@ -153,6 +176,8 @@ public void writeTo(StreamOutput out) throws IOException {
153176

154177
Preconditions.checkNotNull(llmQuestion, "llm_question must not be null.");
155178
out.writeString(llmQuestion);
179+
out.writeOptionalString(systemPrompt);
180+
out.writeOptionalString(userInstructions);
156181
out.writeInt(contextSize);
157182
out.writeInt(interactionSize);
158183
out.writeInt(timeout);
@@ -175,6 +200,8 @@ public boolean equals(Object o) {
175200
return Objects.equals(this.conversationId, other.getConversationId())
176201
&& Objects.equals(this.llmModel, other.getLlmModel())
177202
&& Objects.equals(this.llmQuestion, other.getLlmQuestion())
203+
&& Objects.equals(this.systemPrompt, other.getSystemPrompt())
204+
&& Objects.equals(this.userInstructions, other.getUserInstructions())
178205
&& (this.contextSize == other.getContextSize())
179206
&& (this.interactionSize == other.getInteractionSize())
180207
&& (this.timeout == other.getTimeout());

0 commit comments

Comments
 (0)