Skip to content

Commit b3b2f2d

Browse files
committed
Add request level parameters for system_prompt and user_instructions.
Signed-off-by: Austin Lee <[email protected]>
1 parent 833c9c1 commit b3b2f2d

File tree

6 files changed

+163
-19
lines changed

6 files changed

+163
-19
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

+12
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
171171
+ " \"generative_qa_parameters\": {\n"
172172
+ " \"llm_model\": \"%s\",\n"
173173
+ " \"llm_question\": \"%s\",\n"
174+
+ " \"system_prompt\": \"%s\",\n"
175+
+ " \"user_instructions\": \"%s\",\n"
174176
+ " \"context_size\": %d,\n"
175177
+ " \"message_size\": %d,\n"
176178
+ " \"timeout\": %d\n"
@@ -188,6 +190,8 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
188190
+ " \"llm_model\": \"%s\",\n"
189191
+ " \"llm_question\": \"%s\",\n"
190192
+ " \"memory_id\": \"%s\",\n"
193+
+ " \"system_prompt\": \"%s\",\n"
194+
+ " \"user_instructions\": \"%s\",\n"
191195
+ " \"context_size\": %d,\n"
192196
+ " \"message_size\": %d,\n"
193197
+ " \"timeout\": %d\n"
@@ -308,6 +312,8 @@ public void testBM25WithOpenAI() throws Exception {
308312
requestParameters.match = "president";
309313
requestParameters.llmModel = OPENAI_MODEL;
310314
requestParameters.llmQuestion = "who is lincoln";
315+
requestParameters.systemPrompt = "You are great at answering questions";
316+
requestParameters.userInstructions = "Follow my instructions as best you can";
311317
requestParameters.contextSize = 5;
312318
requestParameters.interactionSize = 5;
313319
requestParameters.timeout = 60;
@@ -527,6 +533,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
527533
requestParameters.match,
528534
requestParameters.llmModel,
529535
requestParameters.llmQuestion,
536+
requestParameters.systemPrompt,
537+
requestParameters.userInstructions,
530538
requestParameters.contextSize,
531539
requestParameters.interactionSize,
532540
requestParameters.timeout
@@ -541,6 +549,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
541549
requestParameters.llmModel,
542550
requestParameters.llmQuestion,
543551
requestParameters.conversationId,
552+
requestParameters.systemPrompt,
553+
requestParameters.userInstructions,
544554
requestParameters.contextSize,
545555
requestParameters.interactionSize,
546556
requestParameters.timeout
@@ -581,6 +591,8 @@ static class SearchRequestParameters {
581591
String match;
582592
String llmModel;
583593
String llmQuestion;
594+
String systemPrompt;
595+
String userInstructions;
584596
int contextSize;
585597
int interactionSize;
586598
int timeout;

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

+12
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,18 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
143143
}
144144
List<String> searchResults = getSearchResults(response, topN);
145145

146+
// See if the prompt is being overridden at the request level.
147+
String effectiveSystemPrompt = systemPrompt;
148+
String effectiveUserInstructions = userInstructions;
149+
if (params.getSystemPrompt() != null) {
150+
effectiveSystemPrompt = params.getSystemPrompt();
151+
}
152+
if (params.getUserInstructions() != null) {
153+
effectiveUserInstructions = params.getUserInstructions();
154+
}
155+
log.info("system_prompt: {}", effectiveSystemPrompt);
156+
log.info("user_instructions: {}", effectiveUserInstructions);
157+
146158
start = Instant.now();
147159
try {
148160
ChatCompletionOutput output = llm

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

+27
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.opensearch.core.xcontent.ToXContentObject;
3030
import org.opensearch.core.xcontent.XContentBuilder;
3131
import org.opensearch.core.xcontent.XContentParser;
32+
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
3233

3334
import com.google.common.base.Preconditions;
3435

@@ -70,13 +71,19 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
7071
// from a remote inference endpoint before timing out the request.
7172
private static final ParseField TIMEOUT = new ParseField("timeout");
7273

74+
private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);
75+
76+
private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);
77+
7378
public static final int SIZE_NULL_VALUE = -1;
7479

7580
static {
7681
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
7782
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
7883
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
7984
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
85+
PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
86+
PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
8087
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
8188
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
8289
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
@@ -106,10 +113,20 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
106113
@Getter
107114
private Integer timeout;
108115

116+
@Setter
117+
@Getter
118+
private String systemPrompt;
119+
120+
@Setter
121+
@Getter
122+
private String userInstructions;
123+
109124
public GenerativeQAParameters(
110125
String conversationId,
111126
String llmModel,
112127
String llmQuestion,
128+
String systemPrompt,
129+
String userInstructions,
113130
Integer contextSize,
114131
Integer interactionSize,
115132
Integer timeout
@@ -121,6 +138,8 @@ public GenerativeQAParameters(
121138
// for question rewriting.
122139
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
123140
this.llmQuestion = llmQuestion;
141+
this.systemPrompt = systemPrompt;
142+
this.userInstructions = userInstructions;
124143
this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize;
125144
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
126145
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
@@ -130,6 +149,8 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
130149
this.conversationId = input.readOptionalString();
131150
this.llmModel = input.readOptionalString();
132151
this.llmQuestion = input.readString();
152+
this.systemPrompt = input.readOptionalString();
153+
this.userInstructions = input.readOptionalString();
133154
this.contextSize = input.readInt();
134155
this.interactionSize = input.readInt();
135156
this.timeout = input.readInt();
@@ -141,6 +162,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
141162
.field(CONVERSATION_ID.getPreferredName(), this.conversationId)
142163
.field(LLM_MODEL.getPreferredName(), this.llmModel)
143164
.field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
165+
.field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt)
166+
.field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
144167
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
145168
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
146169
.field(TIMEOUT.getPreferredName(), this.timeout);
@@ -153,6 +176,8 @@ public void writeTo(StreamOutput out) throws IOException {
153176

154177
Preconditions.checkNotNull(llmQuestion, "llm_question must not be null.");
155178
out.writeString(llmQuestion);
179+
out.writeOptionalString(systemPrompt);
180+
out.writeOptionalString(userInstructions);
156181
out.writeInt(contextSize);
157182
out.writeInt(interactionSize);
158183
out.writeInt(timeout);
@@ -175,6 +200,8 @@ public boolean equals(Object o) {
175200
return Objects.equals(this.conversationId, other.getConversationId())
176201
&& Objects.equals(this.llmModel, other.getLlmModel())
177202
&& Objects.equals(this.llmQuestion, other.getLlmQuestion())
203+
&& Objects.equals(this.systemPrompt, other.getSystemPrompt())
204+
&& Objects.equals(this.userInstructions, other.getUserInstructions())
178205
&& (this.contextSize == other.getContextSize())
179206
&& (this.interactionSize == other.getInteractionSize())
180207
&& (this.timeout == other.getTimeout());

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java

+40-4
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,16 @@ public void testProcessResponseNoSearchHits() throws Exception {
106106

107107
SearchRequest request = new SearchRequest(); // mock(SearchRequest.class);
108108
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class);
109-
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null);
109+
GenerativeQAParameters params = new GenerativeQAParameters(
110+
"12345",
111+
"llm_model",
112+
"You are kind.",
113+
"system_prompt",
114+
"user_instructions",
115+
null,
116+
null,
117+
null
118+
);
110119
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
111120
extBuilder.setParams(params);
112121
request.source(sourceBuilder);
@@ -170,7 +179,16 @@ public void testProcessResponse() throws Exception {
170179

171180
SearchRequest request = new SearchRequest();
172181
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
173-
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null);
182+
GenerativeQAParameters params = new GenerativeQAParameters(
183+
"12345",
184+
"llm_model",
185+
"You are kind.",
186+
"system_promt",
187+
"user_insturctions",
188+
null,
189+
null,
190+
null
191+
);
174192
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
175193
extBuilder.setParams(params);
176194
request.source(sourceBuilder);
@@ -245,7 +263,16 @@ public void testProcessResponseSmallerContextSize() throws Exception {
245263
SearchRequest request = new SearchRequest();
246264
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
247265
int contextSize = 5;
248-
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", contextSize, null, null);
266+
GenerativeQAParameters params = new GenerativeQAParameters(
267+
"12345",
268+
"llm_model",
269+
"You are kind.",
270+
"system_prompt",
271+
"user_instructions",
272+
contextSize,
273+
null,
274+
null
275+
);
249276
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
250277
extBuilder.setParams(params);
251278
request.source(sourceBuilder);
@@ -319,7 +346,16 @@ public void testProcessResponseMissingContextField() throws Exception {
319346

320347
SearchRequest request = new SearchRequest();
321348
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
322-
GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind.", null, null, null);
349+
GenerativeQAParameters params = new GenerativeQAParameters(
350+
"12345",
351+
"llm_model",
352+
"You are kind.",
353+
"system_prompt",
354+
"user_instructions",
355+
null,
356+
null,
357+
null
358+
);
323359
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
324360
extBuilder.setParams(params);
325361
request.source(sourceBuilder);

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java

+17-8
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,16 @@ public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase {
3939

4040
public void testCtor() throws IOException {
4141
GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder();
42-
GenerativeQAParameters parameters = new GenerativeQAParameters("conversation_id", "model_id", "question", null, null, null);
42+
GenerativeQAParameters parameters = new GenerativeQAParameters(
43+
"conversation_id",
44+
"model_id",
45+
"question",
46+
"system_promtp",
47+
"user_instructions",
48+
null,
49+
null,
50+
null
51+
);
4352
builder.setParams(parameters);
4453
assertEquals(parameters, builder.getParams());
4554

@@ -79,8 +88,8 @@ public int read() throws IOException {
7988
}
8089

8190
public void testMiscMethods() throws IOException {
82-
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null);
83-
GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", null, null, null);
91+
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null);
92+
GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d", "s", "u", null, null, null);
8493
GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder();
8594
GenerativeQAParamExtBuilder builder2 = new GenerativeQAParamExtBuilder();
8695
builder1.setParams(param1);
@@ -92,7 +101,7 @@ public void testMiscMethods() throws IOException {
92101

93102
StreamOutput so = mock(StreamOutput.class);
94103
builder1.writeTo(so);
95-
verify(so, times(2)).writeOptionalString(any());
104+
verify(so, times(4)).writeOptionalString(any());
96105
verify(so, times(1)).writeString(any());
97106
}
98107

@@ -105,7 +114,7 @@ public void testParse() throws IOException {
105114
}
106115

107116
public void testXContentRoundTrip() throws IOException {
108-
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null);
117+
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null);
109118
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
110119
extBuilder.setParams(param1);
111120
XContentType xContentType = randomFrom(XContentType.values());
@@ -120,7 +129,7 @@ public void testXContentRoundTrip() throws IOException {
120129
}
121130

122131
public void testXContentRoundTripAllValues() throws IOException {
123-
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3);
132+
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3);
124133
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
125134
extBuilder.setParams(param1);
126135
XContentType xContentType = randomFrom(XContentType.values());
@@ -131,7 +140,7 @@ public void testXContentRoundTripAllValues() throws IOException {
131140
}
132141

133142
public void testStreamRoundTrip() throws IOException {
134-
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", null, null, null);
143+
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null);
135144
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
136145
extBuilder.setParams(param1);
137146
BytesStreamOutput bso = new BytesStreamOutput();
@@ -145,7 +154,7 @@ public void testStreamRoundTrip() throws IOException {
145154
}
146155

147156
public void testStreamRoundTripAllValues() throws IOException {
148-
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", 1, 2, 3);
157+
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3);
149158
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
150159
extBuilder.setParams(param1);
151160
BytesStreamOutput bso = new BytesStreamOutput();

0 commit comments

Comments
 (0)