29
29
import org .opensearch .core .xcontent .ToXContentObject ;
30
30
import org .opensearch .core .xcontent .XContentBuilder ;
31
31
import org .opensearch .core .xcontent .XContentParser ;
32
+ import org .opensearch .searchpipelines .questionanswering .generative .GenerativeQAProcessorConstants ;
32
33
33
34
import com .google .common .base .Preconditions ;
34
35
@@ -70,13 +71,19 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
70
71
// from a remote inference endpoint before timing out the request.
71
72
private static final ParseField TIMEOUT = new ParseField ("timeout" );
72
73
74
+ private static final ParseField SYSTEM_PROMPT = new ParseField (GenerativeQAProcessorConstants .CONFIG_NAME_SYSTEM_PROMPT );
75
+
76
+ private static final ParseField USER_INSTRUCTIONS = new ParseField (GenerativeQAProcessorConstants .CONFIG_NAME_USER_INSTRUCTIONS );
77
+
73
78
public static final int SIZE_NULL_VALUE = -1 ;
74
79
75
80
static {
76
81
PARSER = new ObjectParser <>("generative_qa_parameters" , GenerativeQAParameters ::new );
77
82
PARSER .declareString (GenerativeQAParameters ::setConversationId , CONVERSATION_ID );
78
83
PARSER .declareString (GenerativeQAParameters ::setLlmModel , LLM_MODEL );
79
84
PARSER .declareString (GenerativeQAParameters ::setLlmQuestion , LLM_QUESTION );
85
+ PARSER .declareStringOrNull (GenerativeQAParameters ::setSystemPrompt , SYSTEM_PROMPT );
86
+ PARSER .declareStringOrNull (GenerativeQAParameters ::setUserInstructions , USER_INSTRUCTIONS );
80
87
PARSER .declareIntOrNull (GenerativeQAParameters ::setContextSize , SIZE_NULL_VALUE , CONTEXT_SIZE );
81
88
PARSER .declareIntOrNull (GenerativeQAParameters ::setInteractionSize , SIZE_NULL_VALUE , INTERACTION_SIZE );
82
89
PARSER .declareIntOrNull (GenerativeQAParameters ::setTimeout , SIZE_NULL_VALUE , TIMEOUT );
@@ -106,10 +113,20 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
106
113
@ Getter
107
114
private Integer timeout ;
108
115
116
+ @ Setter
117
+ @ Getter
118
+ private String systemPrompt ;
119
+
120
+ @ Setter
121
+ @ Getter
122
+ private String userInstructions ;
123
+
109
124
public GenerativeQAParameters (
110
125
String conversationId ,
111
126
String llmModel ,
112
127
String llmQuestion ,
128
+ String systemPrompt ,
129
+ String userInstructions ,
113
130
Integer contextSize ,
114
131
Integer interactionSize ,
115
132
Integer timeout
@@ -121,6 +138,8 @@ public GenerativeQAParameters(
121
138
// for question rewriting.
122
139
Preconditions .checkArgument (!Strings .isNullOrEmpty (llmQuestion ), LLM_QUESTION .getPreferredName () + " must be provided." );
123
140
this .llmQuestion = llmQuestion ;
141
+ this .systemPrompt = systemPrompt ;
142
+ this .userInstructions = userInstructions ;
124
143
this .contextSize = (contextSize == null ) ? SIZE_NULL_VALUE : contextSize ;
125
144
this .interactionSize = (interactionSize == null ) ? SIZE_NULL_VALUE : interactionSize ;
126
145
this .timeout = (timeout == null ) ? SIZE_NULL_VALUE : timeout ;
@@ -130,6 +149,8 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
130
149
this .conversationId = input .readOptionalString ();
131
150
this .llmModel = input .readOptionalString ();
132
151
this .llmQuestion = input .readString ();
152
+ this .systemPrompt = input .readOptionalString ();
153
+ this .userInstructions = input .readOptionalString ();
133
154
this .contextSize = input .readInt ();
134
155
this .interactionSize = input .readInt ();
135
156
this .timeout = input .readInt ();
@@ -141,6 +162,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
141
162
.field (CONVERSATION_ID .getPreferredName (), this .conversationId )
142
163
.field (LLM_MODEL .getPreferredName (), this .llmModel )
143
164
.field (LLM_QUESTION .getPreferredName (), this .llmQuestion )
165
+ .field (SYSTEM_PROMPT .getPreferredName (), this .systemPrompt )
166
+ .field (USER_INSTRUCTIONS .getPreferredName (), this .userInstructions )
144
167
.field (CONTEXT_SIZE .getPreferredName (), this .contextSize )
145
168
.field (INTERACTION_SIZE .getPreferredName (), this .interactionSize )
146
169
.field (TIMEOUT .getPreferredName (), this .timeout );
@@ -153,6 +176,8 @@ public void writeTo(StreamOutput out) throws IOException {
153
176
154
177
Preconditions .checkNotNull (llmQuestion , "llm_question must not be null." );
155
178
out .writeString (llmQuestion );
179
+ out .writeOptionalString (systemPrompt );
180
+ out .writeOptionalString (userInstructions );
156
181
out .writeInt (contextSize );
157
182
out .writeInt (interactionSize );
158
183
out .writeInt (timeout );
@@ -175,6 +200,8 @@ public boolean equals(Object o) {
175
200
return Objects .equals (this .conversationId , other .getConversationId ())
176
201
&& Objects .equals (this .llmModel , other .getLlmModel ())
177
202
&& Objects .equals (this .llmQuestion , other .getLlmQuestion ())
203
+ && Objects .equals (this .systemPrompt , other .getSystemPrompt ())
204
+ && Objects .equals (this .userInstructions , other .getUserInstructions ())
178
205
&& (this .contextSize == other .getContextSize ())
179
206
&& (this .interactionSize == other .getInteractionSize ())
180
207
&& (this .timeout == other .getTimeout ());
0 commit comments