diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 072c6e714aa..4f90ce4afaf 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -77,6 +77,7 @@ import ( "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/visibility" "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/common/persistence/visibility/store" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/rpc" @@ -3707,15 +3708,43 @@ func (wh *WorkflowHandler) StartBatchOperation( } // Validate concurrent batch operation + maxConcurrentBatchOperation := wh.config.MaxConcurrentBatchOperation(request.GetNamespace()) countResp, err := wh.CountWorkflowExecutions(ctx, &workflowservice.CountWorkflowExecutionsRequest{ Namespace: request.GetNamespace(), Query: batcher.OpenBatchOperationQuery, }) - if err != nil { - return nil, err + openBatchOperationCount := 0 + if err == nil { + openBatchOperationCount = int(countResp.GetCount()) + } else { + if !errors.Is(err, store.OperationNotSupportedErr) { + return nil, err + } + // Some std visibility stores don't yet support CountWorkflowExecutions, even though some + // batch operations are still possible on those store (eg. by specyfing a list of Executions + // rather than a VisibilityQuery). Fallback to ListOpenWorkflowExecutions in these cases. + // TODO: Remove this once all std visibility stores support CountWorkflowExecutions. + nextPageToken := []byte{} + for nextPageToken != nil && openBatchOperationCount < maxConcurrentBatchOperation { + listResp, err := wh.ListOpenWorkflowExecutions(ctx, &workflowservice.ListOpenWorkflowExecutionsRequest{ + Namespace: request.GetNamespace(), + Filters: &workflowservice.ListOpenWorkflowExecutionsRequest_TypeFilter{ + TypeFilter: &filterpb.WorkflowTypeFilter{ + Name: batcher.BatchWFTypeName, + }, + }, + MaximumPageSize: int32(maxConcurrentBatchOperation - openBatchOperationCount), + NextPageToken: nextPageToken, + }) + if err != nil { + return nil, err + } + openBatchOperationCount += len(listResp.Executions) + nextPageToken = listResp.NextPageToken + } } - if countResp.GetCount() >= int64(wh.config.MaxConcurrentBatchOperation(request.GetNamespace())) { - return nil, serviceerror.NewUnavailable("Max concurrent batch operations is reached") + if openBatchOperationCount >= maxConcurrentBatchOperation { + return nil, serviceerror.NewResourceExhausted(enumspb.RESOURCE_EXHAUSTED_CAUSE_CONCURRENT_LIMIT, "Max concurrent batch operations is reached") } namespaceID, err := wh.namespaceRegistry.GetNamespaceID(namespace.Name(request.GetNamespace())) diff --git a/service/frontend/workflow_handler_test.go b/service/frontend/workflow_handler_test.go index 8a2aadece1c..14ed361773d 100644 --- a/service/frontend/workflow_handler_test.go +++ b/service/frontend/workflow_handler_test.go @@ -72,6 +72,7 @@ import ( "go.temporal.io/server/common/payloads" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/common/persistence/visibility/store" "go.temporal.io/server/common/persistence/visibility/store/elasticsearch" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/primitives/timestamp" @@ -2224,7 +2225,7 @@ func (s *workflowHandlerSuite) TestStartBatchOperation_Signal() { s.NoError(err) } -func (s *workflowHandlerSuite) TestStartBatchOperation_WorkflowExecutions_Singal() { +func (s *workflowHandlerSuite) TestStartBatchOperation_WorkflowExecutions_Signal() { testNamespace := namespace.Name("test-namespace") namespaceID := namespace.ID(uuid.New()) executions := []*commonpb.WorkflowExecution{ @@ -2289,6 +2290,65 @@ func (s *workflowHandlerSuite) TestStartBatchOperation_WorkflowExecutions_Singal s.NoError(err) } +func (s *workflowHandlerSuite) TestStartBatchOperation_WorkflowExecutions_TooMany() { + testNamespace := namespace.Name("test-namespace") + namespaceID := namespace.ID(uuid.New()) + executions := []*commonpb.WorkflowExecution{ + { + WorkflowId: uuid.New(), + RunId: uuid.New(), + }, + } + reason := "reason" + identity := "identity" + config := s.newConfig() + wh := s.getWorkflowHandler(config) + s.mockNamespaceCache.EXPECT().GetNamespaceID(gomock.Any()).Return(namespaceID, nil).AnyTimes() + // Simulate std visibility, which does not support CountWorkflowExecutions + // TODO: remove this once every visibility implementation supports CountWorkflowExecutions + s.mockVisibilityMgr.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, store.OperationNotSupportedErr) + s.mockVisibilityMgr.EXPECT().ListOpenWorkflowExecutionsByType( + gomock.Any(), + gomock.Any(), + ).DoAndReturn( + func( + _ context.Context, + request *manager.ListWorkflowExecutionsByTypeRequest, + ) (*manager.ListWorkflowExecutionsResponse, error) { + s.Equal(testNamespace, request.Namespace) + s.Equal(batcher.BatchWFTypeName, request.WorkflowTypeName) + s.Equal(int(config.MaxConcurrentBatchOperation(testNamespace.String())), request.PageSize) + s.Equal([]byte{}, request.NextPageToken) + return &manager.ListWorkflowExecutionsResponse{ + Executions: []*workflowpb.WorkflowExecutionInfo{ + { + Execution: &commonpb.WorkflowExecution{ + WorkflowId: testWorkflowID, + RunId: testRunID, + }, + }, + }, + NextPageToken: nil, + }, nil + }, + ) + + request := &workflowservice.StartBatchOperationRequest{ + Namespace: testNamespace.String(), + JobId: uuid.New(), + Operation: &workflowservice.StartBatchOperationRequest_CancellationOperation{ + CancellationOperation: &batchpb.BatchOperationCancellation{ + Identity: identity, + }, + }, + Reason: reason, + Executions: executions, + } + + _, err := wh.StartBatchOperation(context.Background(), request) + s.EqualError(err, "Max concurrent batch operations is reached") +} + func (s *workflowHandlerSuite) TestStartBatchOperation_InvalidRequest() { request := &workflowservice.StartBatchOperationRequest{ Namespace: "", diff --git a/service/worker/batcher/activities.go b/service/worker/batcher/activities.go index 78543cd9fb5..09dc4635710 100644 --- a/service/worker/batcher/activities.go +++ b/service/worker/batcher/activities.go @@ -244,7 +244,16 @@ func startTaskProcessor( case BatchTypeSignal: err = processTask(ctx, limiter, task, func(workflowID, runID string) error { - return sdkClient.SignalWorkflow(ctx, workflowID, runID, batchParams.SignalParams.SignalName, batchParams.SignalParams.Input) + _, err := frontendClient.SignalWorkflowExecution(ctx, &workflowservice.SignalWorkflowExecutionRequest{ + Namespace: batchParams.Namespace, + WorkflowExecution: &commonpb.WorkflowExecution{ + WorkflowId: workflowID, + RunId: runID, + }, + SignalName: batchParams.SignalParams.SignalName, + Input: batchParams.SignalParams.Input, + }) + return err }) case BatchTypeDelete: err = processTask(ctx, limiter, task, diff --git a/tests/client_integration_test.go b/tests/client_integration_test.go index 82fc7d254cd..5262b6e9e38 100644 --- a/tests/client_integration_test.go +++ b/tests/client_integration_test.go @@ -40,6 +40,7 @@ import ( "github.com/pborman/uuid" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "go.temporal.io/api/batch/v1" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" historypb "go.temporal.io/api/history/v1" @@ -1514,3 +1515,57 @@ func (s *clientIntegrationSuite) assertHistory(wid, rid string, expected []enums s.Equal(expected, events) } + +func (s *clientIntegrationSuite) TestBatchSignal() { + + type myData struct { + Stuff string + Things []int + } + + workflowFn := func(ctx workflow.Context) (myData, error) { + var receivedData myData + workflow.GetSignalChannel(ctx, "my-signal").Receive(ctx, &receivedData) + return receivedData, nil + } + s.worker.RegisterWorkflow(workflowFn) + + workflowRun, err := s.sdkClient.ExecuteWorkflow(context.Background(), sdkclient.StartWorkflowOptions{ + ID: uuid.New(), + TaskQueue: s.taskQueue, + WorkflowExecutionTimeout: 10 * time.Second, + }, workflowFn) + s.NoError(err) + + input1 := myData{ + Stuff: "here's some data", + Things: []int{7, 8, 9}, + } + inputPayloads, err := converter.GetDefaultDataConverter().ToPayloads(input1) + s.NoError(err) + + _, err = s.sdkClient.WorkflowService().StartBatchOperation(context.Background(), &workflowservice.StartBatchOperationRequest{ + Namespace: s.namespace, + Operation: &workflowservice.StartBatchOperationRequest_SignalOperation{ + SignalOperation: &batch.BatchOperationSignal{ + Signal: "my-signal", + Input: inputPayloads, + }, + }, + Executions: []*commonpb.WorkflowExecution{ + { + WorkflowId: workflowRun.GetID(), + RunId: workflowRun.GetRunID(), + }, + }, + JobId: uuid.New(), + Reason: "test", + }) + s.NoError(err) + + var returnedData myData + err = workflowRun.Get(context.Background(), &returnedData) + s.NoError(err) + + s.Equal(input1, returnedData) +}