Skip to content

Commit

Permalink
Fix handling of batch signal input payloads (#4374)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjameswh authored Jun 2, 2023
1 parent d619d79 commit b34d21a
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 6 deletions.
37 changes: 33 additions & 4 deletions service/frontend/workflow_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ import (
"go.temporal.io/server/common/persistence/serialization"
"go.temporal.io/server/common/persistence/visibility"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/common/persistence/visibility/store"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/primitives/timestamp"
"go.temporal.io/server/common/rpc"
Expand Down Expand Up @@ -3707,15 +3708,43 @@ func (wh *WorkflowHandler) StartBatchOperation(
}

// Validate concurrent batch operation
maxConcurrentBatchOperation := wh.config.MaxConcurrentBatchOperation(request.GetNamespace())
countResp, err := wh.CountWorkflowExecutions(ctx, &workflowservice.CountWorkflowExecutionsRequest{
Namespace: request.GetNamespace(),
Query: batcher.OpenBatchOperationQuery,
})
if err != nil {
return nil, err
openBatchOperationCount := 0
if err == nil {
openBatchOperationCount = int(countResp.GetCount())
} else {
if !errors.Is(err, store.OperationNotSupportedErr) {
return nil, err
}
// Some std visibility stores don't yet support CountWorkflowExecutions, even though some
// batch operations are still possible on those store (eg. by specyfing a list of Executions
// rather than a VisibilityQuery). Fallback to ListOpenWorkflowExecutions in these cases.
// TODO: Remove this once all std visibility stores support CountWorkflowExecutions.
nextPageToken := []byte{}
for nextPageToken != nil && openBatchOperationCount < maxConcurrentBatchOperation {
listResp, err := wh.ListOpenWorkflowExecutions(ctx, &workflowservice.ListOpenWorkflowExecutionsRequest{
Namespace: request.GetNamespace(),
Filters: &workflowservice.ListOpenWorkflowExecutionsRequest_TypeFilter{
TypeFilter: &filterpb.WorkflowTypeFilter{
Name: batcher.BatchWFTypeName,
},
},
MaximumPageSize: int32(maxConcurrentBatchOperation - openBatchOperationCount),
NextPageToken: nextPageToken,
})
if err != nil {
return nil, err
}
openBatchOperationCount += len(listResp.Executions)
nextPageToken = listResp.NextPageToken
}
}
if countResp.GetCount() >= int64(wh.config.MaxConcurrentBatchOperation(request.GetNamespace())) {
return nil, serviceerror.NewUnavailable("Max concurrent batch operations is reached")
if openBatchOperationCount >= maxConcurrentBatchOperation {
return nil, serviceerror.NewResourceExhausted(enumspb.RESOURCE_EXHAUSTED_CAUSE_CONCURRENT_LIMIT, "Max concurrent batch operations is reached")
}

namespaceID, err := wh.namespaceRegistry.GetNamespaceID(namespace.Name(request.GetNamespace()))
Expand Down
62 changes: 61 additions & 1 deletion service/frontend/workflow_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ import (
"go.temporal.io/server/common/payloads"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/common/persistence/visibility/store"
"go.temporal.io/server/common/persistence/visibility/store/elasticsearch"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/primitives/timestamp"
Expand Down Expand Up @@ -2224,7 +2225,7 @@ func (s *workflowHandlerSuite) TestStartBatchOperation_Signal() {
s.NoError(err)
}

func (s *workflowHandlerSuite) TestStartBatchOperation_WorkflowExecutions_Singal() {
func (s *workflowHandlerSuite) TestStartBatchOperation_WorkflowExecutions_Signal() {
testNamespace := namespace.Name("test-namespace")
namespaceID := namespace.ID(uuid.New())
executions := []*commonpb.WorkflowExecution{
Expand Down Expand Up @@ -2289,6 +2290,65 @@ func (s *workflowHandlerSuite) TestStartBatchOperation_WorkflowExecutions_Singal
s.NoError(err)
}

func (s *workflowHandlerSuite) TestStartBatchOperation_WorkflowExecutions_TooMany() {
testNamespace := namespace.Name("test-namespace")
namespaceID := namespace.ID(uuid.New())
executions := []*commonpb.WorkflowExecution{
{
WorkflowId: uuid.New(),
RunId: uuid.New(),
},
}
reason := "reason"
identity := "identity"
config := s.newConfig()
wh := s.getWorkflowHandler(config)
s.mockNamespaceCache.EXPECT().GetNamespaceID(gomock.Any()).Return(namespaceID, nil).AnyTimes()
// Simulate std visibility, which does not support CountWorkflowExecutions
// TODO: remove this once every visibility implementation supports CountWorkflowExecutions
s.mockVisibilityMgr.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, store.OperationNotSupportedErr)
s.mockVisibilityMgr.EXPECT().ListOpenWorkflowExecutionsByType(
gomock.Any(),
gomock.Any(),
).DoAndReturn(
func(
_ context.Context,
request *manager.ListWorkflowExecutionsByTypeRequest,
) (*manager.ListWorkflowExecutionsResponse, error) {
s.Equal(testNamespace, request.Namespace)
s.Equal(batcher.BatchWFTypeName, request.WorkflowTypeName)
s.Equal(int(config.MaxConcurrentBatchOperation(testNamespace.String())), request.PageSize)
s.Equal([]byte{}, request.NextPageToken)
return &manager.ListWorkflowExecutionsResponse{
Executions: []*workflowpb.WorkflowExecutionInfo{
{
Execution: &commonpb.WorkflowExecution{
WorkflowId: testWorkflowID,
RunId: testRunID,
},
},
},
NextPageToken: nil,
}, nil
},
)

request := &workflowservice.StartBatchOperationRequest{
Namespace: testNamespace.String(),
JobId: uuid.New(),
Operation: &workflowservice.StartBatchOperationRequest_CancellationOperation{
CancellationOperation: &batchpb.BatchOperationCancellation{
Identity: identity,
},
},
Reason: reason,
Executions: executions,
}

_, err := wh.StartBatchOperation(context.Background(), request)
s.EqualError(err, "Max concurrent batch operations is reached")
}

func (s *workflowHandlerSuite) TestStartBatchOperation_InvalidRequest() {
request := &workflowservice.StartBatchOperationRequest{
Namespace: "",
Expand Down
11 changes: 10 additions & 1 deletion service/worker/batcher/activities.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,16 @@ func startTaskProcessor(
case BatchTypeSignal:
err = processTask(ctx, limiter, task,
func(workflowID, runID string) error {
return sdkClient.SignalWorkflow(ctx, workflowID, runID, batchParams.SignalParams.SignalName, batchParams.SignalParams.Input)
_, err := frontendClient.SignalWorkflowExecution(ctx, &workflowservice.SignalWorkflowExecutionRequest{
Namespace: batchParams.Namespace,
WorkflowExecution: &commonpb.WorkflowExecution{
WorkflowId: workflowID,
RunId: runID,
},
SignalName: batchParams.SignalParams.SignalName,
Input: batchParams.SignalParams.Input,
})
return err
})
case BatchTypeDelete:
err = processTask(ctx, limiter, task,
Expand Down
55 changes: 55 additions & 0 deletions tests/client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"github.com/pborman/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.temporal.io/api/batch/v1"
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
historypb "go.temporal.io/api/history/v1"
Expand Down Expand Up @@ -1514,3 +1515,57 @@ func (s *clientIntegrationSuite) assertHistory(wid, rid string, expected []enums

s.Equal(expected, events)
}

func (s *clientIntegrationSuite) TestBatchSignal() {

type myData struct {
Stuff string
Things []int
}

workflowFn := func(ctx workflow.Context) (myData, error) {
var receivedData myData
workflow.GetSignalChannel(ctx, "my-signal").Receive(ctx, &receivedData)
return receivedData, nil
}
s.worker.RegisterWorkflow(workflowFn)

workflowRun, err := s.sdkClient.ExecuteWorkflow(context.Background(), sdkclient.StartWorkflowOptions{
ID: uuid.New(),
TaskQueue: s.taskQueue,
WorkflowExecutionTimeout: 10 * time.Second,
}, workflowFn)
s.NoError(err)

input1 := myData{
Stuff: "here's some data",
Things: []int{7, 8, 9},
}
inputPayloads, err := converter.GetDefaultDataConverter().ToPayloads(input1)
s.NoError(err)

_, err = s.sdkClient.WorkflowService().StartBatchOperation(context.Background(), &workflowservice.StartBatchOperationRequest{
Namespace: s.namespace,
Operation: &workflowservice.StartBatchOperationRequest_SignalOperation{
SignalOperation: &batch.BatchOperationSignal{
Signal: "my-signal",
Input: inputPayloads,
},
},
Executions: []*commonpb.WorkflowExecution{
{
WorkflowId: workflowRun.GetID(),
RunId: workflowRun.GetRunID(),
},
},
JobId: uuid.New(),
Reason: "test",
})
s.NoError(err)

var returnedData myData
err = workflowRun.Get(context.Background(), &returnedData)
s.NoError(err)

s.Equal(input1, returnedData)
}

0 comments on commit b34d21a

Please sign in to comment.