Skip to content

Commit

Permalink
re-add feature flag checks and tests to transport layer
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Aug 31, 2023
1 parent d0b3670 commit 6927c45
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -38,6 +41,7 @@ public class CreateConversationTransportAction extends HandledTransportAction<Cr

private ConversationalMemoryHandler cmHandler;
private Client client;
private ClusterService clusterService;

/**
* Constructor
Expand All @@ -51,15 +55,27 @@ public CreateConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(CreateConversationAction.NAME, transportService, actionFilters, CreateConversationRequest::new);
this.cmHandler = cmHandler;
this.client = client;
this.clusterService = clusterService;
}

@Override
protected void doExecute(Task task, CreateConversationRequest request, ActionListener<CreateConversationResponse> actionListener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
)
);
return;
}
String name = request.getName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateConversationResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -38,6 +41,7 @@ public class CreateInteractionTransportAction extends HandledTransportAction<Cre

private ConversationalMemoryHandler cmHandler;
private Client client;
private ClusterService clusterService;

/**
* Constructor
Expand All @@ -51,15 +55,27 @@ public CreateInteractionTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(CreateInteractionAction.NAME, transportService, actionFilters, CreateInteractionRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.clusterService = clusterService;
}

@Override
protected void doExecute(Task task, CreateInteractionRequest request, ActionListener<CreateInteractionResponse> actionListener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
)
);
return;
}
String cid = request.getConversationId();
String inp = request.getInput();
String rsp = request.getResponse();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -38,6 +41,7 @@ public class DeleteConversationTransportAction extends HandledTransportAction<De

private Client client;
private ConversationalMemoryHandler cmHandler;
private ClusterService clusterService;

/**
* Constructor
Expand All @@ -51,15 +55,27 @@ public DeleteConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(DeleteConversationAction.NAME, transportService, actionFilters, DeleteConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.clusterService = clusterService;
}

@Override
public void doExecute(Task task, DeleteConversationRequest request, ActionListener<DeleteConversationResponse> listener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
listener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
)
);
return;
}
String conversationId = request.getId();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<DeleteConversationResponse> internalListener = ActionListener.runBefore(listener, () -> context.restore());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@

import java.util.List;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand All @@ -41,6 +44,7 @@ public class GetConversationsTransportAction extends HandledTransportAction<GetC

private Client client;
private ConversationalMemoryHandler cmHandler;
private ClusterService clusterService;

/**
* Constructor
Expand All @@ -54,15 +58,27 @@ public GetConversationsTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(GetConversationsAction.NAME, transportService, actionFilters, GetConversationsRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.clusterService = clusterService;
}

@Override
public void doExecute(Task task, GetConversationsRequest request, ActionListener<GetConversationsResponse> actionListener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
)
);
return;
}
int maxResults = request.getMaxResults();
int from = request.getFrom();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@

import java.util.List;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
Expand All @@ -41,6 +44,7 @@ public class GetInteractionsTransportAction extends HandledTransportAction<GetIn

private Client client;
private ConversationalMemoryHandler cmHandler;
private ClusterService clusterService;

/**
* Constructor
Expand All @@ -54,15 +58,27 @@ public GetInteractionsTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
Client client,
ClusterService clusterService
) {
super(GetInteractionsAction.NAME, transportService, actionFilters, GetInteractionsRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.clusterService = clusterService;
}

@Override
public void doExecute(Task task, GetInteractionsRequest request, ActionListener<GetInteractionsResponse> actionListener) {
if (!clusterService.getSettings().getAsBoolean(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, false)) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME
)
);
return;
}
int maxResults = request.getMaxResults();
int from = request.getFrom();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -89,12 +90,13 @@ public void setup() throws IOException {
this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class);

this.request = new CreateConversationRequest("test");
this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client));
this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

Settings settings = Settings.builder().build();
Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build();
this.threadContext = new ThreadContext(settings);
when(this.client.threadPool()).thenReturn(this.threadPool);
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
when(this.clusterService.getSettings()).thenReturn(settings);
}

public void testCreateConversation() {
Expand Down Expand Up @@ -144,4 +146,12 @@ public void testDoExecuteFails_thenFail() {
assert (argCaptor.getValue().getMessage().equals("Test doExecute Error"));
}

public void testFeatureDisabled_ThenFail() {
when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY);
action.doExecute(null, request, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled."));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -89,12 +90,13 @@ public void setup() throws IOException {
this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class);

this.request = new CreateInteractionRequest("test-cid", "input", "pt", "response", "origin", "metadata");
this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client));
this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

Settings settings = Settings.builder().build();
Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build();
this.threadContext = new ThreadContext(settings);
when(this.client.threadPool()).thenReturn(this.threadPool);
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
when(this.clusterService.getSettings()).thenReturn(settings);
}

public void testCreateInteraction() {
Expand Down Expand Up @@ -134,4 +136,12 @@ public void testDoExecuteFails_thenFail() {
assert (argCaptor.getValue().getMessage().equals("Failure in doExecute"));
}

public void testFeatureDisabled_ThenFail() {
when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY);
action.doExecute(null, request, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled."));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -89,12 +90,13 @@ public void setup() throws IOException {
this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class);

this.request = new DeleteConversationRequest("test");
this.action = spy(new DeleteConversationTransportAction(transportService, actionFilters, cmHandler, client));
this.action = spy(new DeleteConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService));

Settings settings = Settings.builder().build();
Settings settings = Settings.builder().put(ConversationalIndexConstants.MEMORY_FEATURE_FLAG_NAME, true).build();
this.threadContext = new ThreadContext(settings);
when(this.client.threadPool()).thenReturn(this.threadPool);
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
when(this.clusterService.getSettings()).thenReturn(settings);
}

public void testDeleteConversation() {
Expand Down Expand Up @@ -130,4 +132,12 @@ public void testdoExecuteFails_thenFail() {
assert (argCaptor.getValue().getMessage().equals("Test doExecute Error"));
}

public void testFeatureDisabled_ThenFail() {
when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY);
action.doExecute(null, request, actionListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled."));
}

}
Loading

0 comments on commit 6927c45

Please sign in to comment.