diff --git a/common/collection/indexedtakelist.go b/common/collection/indexedtakelist.go new file mode 100644 index 00000000000..772ebb093f3 --- /dev/null +++ b/common/collection/indexedtakelist.go @@ -0,0 +1,86 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package collection + +// IndexedTakeList holds a set of values that can only be observed by being +// removed from the set. It is possible for this set to contain duplicate values +// as long as each value maps to a distinct index. +type ( + IndexedTakeList[K comparable, V any] struct { + values []kv[K, V] + } + + kv[K comparable, V any] struct { + key K + value V + removed bool + } +) + +// NewIndexedTakeList constructs a new IndexedTakeSet by applying the provided +// indexer to each of the provided values. +func NewIndexedTakeList[K comparable, V any]( + values []V, + indexer func(V) K, +) *IndexedTakeList[K, V] { + ret := &IndexedTakeList[K, V]{ + values: make([]kv[K, V], 0, len(values)), + } + for _, v := range values { + ret.values = append(ret.values, kv[K, V]{key: indexer(v), value: v}) + } + return ret +} + +// Take finds a value in this set by its key and removes it, returning the +// value. +func (itl *IndexedTakeList[K, V]) Take(key K) (V, bool) { + var zero V + for i := 0; i < len(itl.values); i++ { + kv := &itl.values[i] + if kv.key != key { + continue + } + if kv.removed { + return zero, false + } + kv.removed = true + return kv.value, true + } + return zero, false +} + +// TakeRemaining removes all remaining values from this set and returns them. +func (itl *IndexedTakeList[K, V]) TakeRemaining() []V { + out := make([]V, 0, len(itl.values)) + for i := 0; i < len(itl.values); i++ { + kv := &itl.values[i] + if !kv.removed { + out = append(out, kv.value) + } + } + itl.values = nil + return out +} diff --git a/common/collection/indexedtakelist_test.go b/common/collection/indexedtakelist_test.go new file mode 100644 index 00000000000..3ba6b463b25 --- /dev/null +++ b/common/collection/indexedtakelist_test.go @@ -0,0 +1,71 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package collection_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/collection" +) + +func TestIndexedTakeList(t *testing.T) { + t.Parallel() + + type Value struct{ ID int } + values := []Value{{ID: 0}, {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}} + indexer := func(v Value) int { return v.ID } + + list := collection.NewIndexedTakeList(values, indexer) + + t.Run("take absent", func(t *testing.T) { + _, ok := list.Take(999) + require.False(t, ok) + }) + + t.Run("take present", func(t *testing.T) { + v, ok := list.Take(3) + require.True(t, ok) + require.Equal(t, 3, v.ID) + }) + + t.Run("take taken", func(t *testing.T) { + _, ok := list.Take(3) + require.False(t, ok) + }) + + t.Run("take remaining", func(t *testing.T) { + allowed := []int{0, 1, 2, 4} + remaining := list.TakeRemaining() + require.Len(t, remaining, len(allowed)) + for _, v := range remaining { + require.Contains(t, allowed, v.ID) + } + }) + + t.Run("now empty", func(t *testing.T) { + require.Empty(t, list.TakeRemaining()) + }) +} diff --git a/common/metrics/metric_defs.go b/common/metrics/metric_defs.go index 8e7488f84b1..7b1715eab43 100644 --- a/common/metrics/metric_defs.go +++ b/common/metrics/metric_defs.go @@ -1358,6 +1358,7 @@ var ( CommandTypeUpsertWorkflowSearchAttributesCounter = NewCounterDef("upsert_workflow_search_attributes_command") CommandTypeModifyWorkflowPropertiesCounter = NewCounterDef("modify_workflow_properties_command") CommandTypeChildWorkflowCounter = NewCounterDef("child_workflow_command") + CommandTypeProtocolMessage = NewCounterDef("protocol_message_command") MessageTypeRequestWorkflowExecutionUpdateCounter = NewCounterDef("request_workflow_update_message") MessageTypeAcceptWorkflowExecutionUpdateCounter = NewCounterDef("accept_workflow_update_message") MessageTypeRespondWorkflowExecutionUpdateCounter = NewCounterDef("respond_workflow_update_message") diff --git a/internal/protocol/naming.go b/internal/protocol/naming.go new file mode 100644 index 00000000000..9cd890ae3d8 --- /dev/null +++ b/internal/protocol/naming.go @@ -0,0 +1,99 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package protocol + +import ( + "fmt" + "strings" + + "github.com/gogo/protobuf/types" + protocolpb "go.temporal.io/api/protocol/v1" +) + +type ( + // Type is a protocol type of which there may be many instances like the + // update protocol or the query protocol. + Type string + + // MessageType is the type of a message within a protocol. + MessageType string + + constErr string +) + +const ( + MessageTypeUnknown = MessageType("__message_type_unknown") + TypeUnknown = Type("__protocol_type_unknown") + + errNilMsg = constErr("nil message") + errNilBody = constErr("nil message body") + errProtoFmt = constErr("failed to extract protocol type") +) + +// String transforms a MessageType into a string +func (mt MessageType) String() string { + return string(mt) +} + +// String tranforms a Type into a string +func (pt Type) String() string { + return string(pt) +} + +// Identify is a function that given a protocol message gives the specific +// message type of the body and the type of the protocol to which this message +// belongs. +func Identify(msg *protocolpb.Message) (Type, MessageType, error) { + if msg == nil { + return TypeUnknown, MessageTypeUnknown, errNilMsg + } + if msg.Body == nil { + return TypeUnknown, MessageTypeUnknown, errNilBody + } + bodyTypeName, err := types.AnyMessageName(msg.Body) + if err != nil { + return TypeUnknown, MessageTypeUnknown, err + } + msgType := MessageType(bodyTypeName) + + lastDot := strings.LastIndex(bodyTypeName, ".") + if lastDot < 0 { + err := fmt.Errorf("%w: no . found in %q", errProtoFmt, bodyTypeName) + return TypeUnknown, msgType, err + } + return Type(bodyTypeName[0:lastDot]), msgType, nil +} + +// IdentifyOrUnknown wraps Identify to return TypeUnknown and/or +// MessageTypeUnknown in the case where either one cannot be determined due to +// an error. +func IdentifyOrUnknown(msg *protocolpb.Message) (Type, MessageType) { + pt, mt, _ := Identify(msg) + return pt, mt +} + +func (cerr constErr) Error() string { + return string(cerr) +} diff --git a/internal/protocol/naming_test.go b/internal/protocol/naming_test.go new file mode 100644 index 00000000000..6dcca8b2d1f --- /dev/null +++ b/internal/protocol/naming_test.go @@ -0,0 +1,77 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package protocol_test + +import ( + "testing" + + "github.com/gogo/protobuf/types" + "github.com/stretchr/testify/require" + protocolpb "go.temporal.io/api/protocol/v1" + "go.temporal.io/server/internal/protocol" +) + +func TestNilSafety(t *testing.T) { + t.Parallel() + + t.Run("nil message", func(t *testing.T) { + pt, mt := protocol.IdentifyOrUnknown(nil) + require.Equal(t, protocol.TypeUnknown, pt) + require.Equal(t, protocol.MessageTypeUnknown, mt) + }) + + t.Run("nil message body", func(t *testing.T) { + pt, mt := protocol.IdentifyOrUnknown(&protocolpb.Message{}) + require.Equal(t, protocol.TypeUnknown, pt) + require.Equal(t, protocol.MessageTypeUnknown, mt) + }) +} + +func TestWithValidMessage(t *testing.T) { + t.Parallel() + + body, err := types.MarshalAny(&types.Empty{}) + require.NoError(t, err) + + msg := protocolpb.Message{Body: body} + + pt, mt := protocol.IdentifyOrUnknown(&msg) + + require.Equal(t, "google.protobuf", pt.String()) + require.Equal(t, "google.protobuf.Empty", mt.String()) +} + +func TestWithInvalidBody(t *testing.T) { + t.Parallel() + + body, err := types.MarshalAny(&types.Empty{}) + require.NoError(t, err) + + msg := protocolpb.Message{Body: body} + msg.Body.TypeUrl = "this isn't valid" + + _, _, err = protocol.Identify(&msg) + require.Error(t, err) +} diff --git a/service/history/commandChecker.go b/service/history/commandChecker.go index 17188310515..aa937050e5a 100644 --- a/service/history/commandChecker.go +++ b/service/history/commandChecker.go @@ -271,6 +271,24 @@ func (c *workflowSizeChecker) checkIfSearchAttributesSizeExceedsLimit( return err } +func (v *commandAttrValidator) validateProtocolMessageAttributes( + namespaceID namespace.ID, + attributes *commandpb.ProtocolMessageCommandAttributes, + runTimeout time.Duration, +) (enumspb.WorkflowTaskFailedCause, error) { + const failedCause = enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE + + if attributes == nil { + return failedCause, serviceerror.NewInvalidArgument("ProtocolMessageCommandAttributes is not set on command.") + } + + if attributes.MessageId == "" { + return failedCause, serviceerror.NewInvalidArgument("MessageID is not set on command.") + } + + return enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNSPECIFIED, nil +} + func (v *commandAttrValidator) validateActivityScheduleAttributes( namespaceID namespace.ID, attributes *commandpb.ScheduleActivityTaskCommandAttributes, diff --git a/service/history/workflow/update/version.go b/service/history/workflow/update/version.go index 3cdadaa3276..55ca0c3292e 100644 --- a/service/history/workflow/update/version.go +++ b/service/history/workflow/update/version.go @@ -24,4 +24,6 @@ package update -const ProtocolV1 = "temporal.api.update.v1" +import "go.temporal.io/server/internal/protocol" + +const ProtocolV1 = protocol.Type("temporal.api.update.v1") diff --git a/service/history/workflowTaskHandler.go b/service/history/workflowTaskHandler.go index fef4619ff52..90684d2b0e9 100644 --- a/service/history/workflowTaskHandler.go +++ b/service/history/workflowTaskHandler.go @@ -27,10 +27,8 @@ package history import ( "context" "fmt" - "strings" "time" - "github.com/gogo/protobuf/types" "github.com/pborman/uuid" commandpb "go.temporal.io/api/command/v1" commonpb "go.temporal.io/api/common/v1" @@ -40,12 +38,14 @@ import ( "go.temporal.io/api/serviceerror" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/internal/effect" + "go.temporal.io/server/internal/protocol" "go.temporal.io/server/service/history/workflow/update" "go.temporal.io/server/api/historyservice/v1" tokenspb "go.temporal.io/server/api/token/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" + "go.temporal.io/server/common/collection" "go.temporal.io/server/common/enums" "go.temporal.io/server/common/failure" "go.temporal.io/server/common/log" @@ -161,7 +161,7 @@ func newWorkflowTaskHandler( func (handler *workflowTaskHandlerImpl) handleCommands( ctx context.Context, commands []*commandpb.Command, - earlyDeliverMessages func(context.Context) error, + msgs *collection.IndexedTakeList[string, *protocolpb.Message], ) ([]workflowTaskResponseMutation, error) { if err := handler.attrValidator.validateCommandSequence( commands, @@ -172,7 +172,7 @@ func (handler *workflowTaskHandlerImpl) handleCommands( var mutations []workflowTaskResponseMutation var postActions []commandPostAction for _, command := range commands { - response, err := handler.handleCommand(ctx, command, earlyDeliverMessages) + response, err := handler.handleCommand(ctx, command, msgs) if err != nil || handler.stopProcessing { return nil, err } @@ -186,6 +186,15 @@ func (handler *workflowTaskHandlerImpl) handleCommands( } } + if handler.mutableState.IsWorkflowExecutionRunning() { + for _, msg := range msgs.TakeRemaining() { + err := handler.handleMessage(ctx, msg) + if err != nil || handler.stopProcessing { + return nil, err + } + } + } + for _, postAction := range postActions { mutation, err := postAction(ctx) if err != nil || handler.stopProcessing { @@ -203,17 +212,14 @@ func (handler *workflowTaskHandlerImpl) handleCommands( func (handler *workflowTaskHandlerImpl) handleCommand( ctx context.Context, command *commandpb.Command, - earlyDeliverMessages func(context.Context) error, + msgs *collection.IndexedTakeList[string, *protocolpb.Message], ) (*handleCommandResponse, error) { switch command.GetCommandType() { case enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK: return handler.handleCommandScheduleActivity(ctx, command.GetScheduleActivityTaskCommandAttributes()) case enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION: - if err := earlyDeliverMessages(ctx); err != nil { - return nil, err - } - return nil, handler.handleCommandCompleteWorkflow(ctx, command.GetCompleteWorkflowExecutionCommandAttributes()) + return nil, handler.handleCommandCompleteWorkflow(ctx, command.GetCompleteWorkflowExecutionCommandAttributes(), msgs) case enumspb.COMMAND_TYPE_FAIL_WORKFLOW_EXECUTION: return nil, handler.handleCommandFailWorkflow(ctx, command.GetFailWorkflowExecutionCommandAttributes()) @@ -252,62 +258,33 @@ func (handler *workflowTaskHandlerImpl) handleCommand( return nil, handler.handleCommandModifyWorkflowProperties(ctx, command.GetModifyWorkflowPropertiesCommandAttributes()) case enumspb.COMMAND_TYPE_PROTOCOL_MESSAGE: - return nil, nil + return nil, handler.handleCommandProtocolMessage(ctx, command.GetProtocolMessageCommandAttributes(), msgs) default: return nil, serviceerror.NewInvalidArgument(fmt.Sprintf("Unknown command type: %v", command.GetCommandType())) } } -func (handler *workflowTaskHandlerImpl) handleMessages( - ctx context.Context, - messages []*protocolpb.Message, -) error { - if !handler.mutableState.IsWorkflowExecutionRunning() { - // Workflow might get completed within the same WT after processing corresponding command. - // Currently, messages are used to transport updates only and updates should not be processed after workflow is completed. - // Therefore, just ignore all messages. - // If later, there are messages which can be processed after workflow is completed, - // then this check should be moved down to specific message handler. - return nil - } - - for _, message := range messages { - err := handler.handleMessage(ctx, message, handler.updateRegistry) - if err != nil || handler.stopProcessing { - return err - } - } - - return nil -} - func (handler *workflowTaskHandlerImpl) handleMessage( ctx context.Context, message *protocolpb.Message, - updateRegistry update.Registry, ) error { - bodyTypeName, err := types.AnyMessageName(message.GetBody()) + protocolType, msgType, err := protocol.Identify(message) if err != nil { return serviceerror.NewInvalidArgument(err.Error()) } if err := handler.sizeLimitChecker.checkIfPayloadSizeExceedsLimit( // TODO (alex-update): Should use MessageTypeTag here but then it needs to be another metric name too. - metrics.CommandTypeTag(bodyTypeName), + metrics.CommandTypeTag(msgType.String()), message.Body.Size(), - fmt.Sprintf("Message type %v exceeds size limit.", bodyTypeName), + fmt.Sprintf("Message type %v exceeds size limit.", msgType), ); err != nil { return handler.failWorkflow(enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, err) } - protocolTypeName := "__unknown" - if lastDot := strings.LastIndex(bodyTypeName, "."); lastDot > -1 { - protocolTypeName = bodyTypeName[0:lastDot] - } - - switch protocolTypeName { + switch protocolType { case update.ProtocolV1: - upd, ok := updateRegistry.Find(ctx, message.ProtocolInstanceId) + upd, ok := handler.updateRegistry.Find(ctx, message.ProtocolInstanceId) if !ok { return handler.failWorkflowTask( enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, @@ -320,12 +297,43 @@ func (handler *workflowTaskHandlerImpl) handleMessage( default: return handler.failWorkflowTask( enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, - serviceerror.NewInvalidArgument(fmt.Sprintf("unsupported protocol type %q", protocolTypeName))) + serviceerror.NewInvalidArgument(fmt.Sprintf("unsupported protocol type %q", protocolType))) } return nil } +func (handler *workflowTaskHandlerImpl) handleCommandProtocolMessage( + ctx context.Context, + attr *commandpb.ProtocolMessageCommandAttributes, + msgs *collection.IndexedTakeList[string, *protocolpb.Message], +) error { + handler.metricsHandler.Counter(metrics.CommandTypeProtocolMessage.GetMetricName()).Record(1) + + executionInfo := handler.mutableState.GetExecutionInfo() + namespaceID := namespace.ID(executionInfo.NamespaceId) + + if err := handler.validateCommandAttr( + func() (enumspb.WorkflowTaskFailedCause, error) { + return handler.attrValidator.validateProtocolMessageAttributes( + namespaceID, + attr, + timestamp.DurationValue(executionInfo.WorkflowRunTimeout), + ) + }, + ); err != nil || handler.stopProcessing { + return err + } + + if msg, ok := msgs.Take(attr.MessageId); ok { + return handler.handleMessage(ctx, msg) + } + return handler.failWorkflowTask( + enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, + serviceerror.NewInvalidArgument(fmt.Sprintf("ProtocolMessageCommand referenced absent message ID %q", attr.MessageId)), + ) +} + func (handler *workflowTaskHandlerImpl) handleCommandScheduleActivity( _ context.Context, attr *commandpb.ScheduleActivityTaskCommandAttributes, @@ -544,8 +552,16 @@ func (handler *workflowTaskHandlerImpl) handleCommandStartTimer( func (handler *workflowTaskHandlerImpl) handleCommandCompleteWorkflow( ctx context.Context, attr *commandpb.CompleteWorkflowExecutionCommandAttributes, + msgs *collection.IndexedTakeList[string, *protocolpb.Message], ) error { + for _, msg := range msgs.TakeRemaining() { + err := handler.handleMessage(ctx, msg) + if err != nil || handler.stopProcessing { + return err + } + } + handler.metricsHandler.Counter(metrics.CommandTypeCompleteWorkflowCounter.GetMetricName()).Record(1) if handler.hasBufferedEvents { diff --git a/service/history/workflowTaskHandlerCallbacks.go b/service/history/workflowTaskHandlerCallbacks.go index 6d15e33dc6f..729cd42d7f3 100644 --- a/service/history/workflowTaskHandlerCallbacks.go +++ b/service/history/workflowTaskHandlerCallbacks.go @@ -31,6 +31,7 @@ import ( commandpb "go.temporal.io/api/command/v1" enumspb "go.temporal.io/api/enums/v1" historypb "go.temporal.io/api/history/v1" + protocolpb "go.temporal.io/api/protocol/v1" querypb "go.temporal.io/api/query/v1" "go.temporal.io/api/serviceerror" taskqueuepb "go.temporal.io/api/taskqueue/v1" @@ -40,6 +41,7 @@ import ( "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/clock" + "go.temporal.io/server/common/collection" "go.temporal.io/server/common/definition" "go.temporal.io/server/common/failure" "go.temporal.io/server/common/log" @@ -531,33 +533,17 @@ func (handler *workflowTaskHandlerCallbacksImpl) handleWorkflowTaskCompleted( hasBufferedEvents, ) - // roll message delivery up into a one-time-only func so that we can - // either call it as part of handling a workflow execution completed - // command or we can run it after all commands have been handled. This - // solves the problem of a single workflow task completion that contains - // _both_ the workflow execution completed command _and_ an update - // completed message. - msgsDelivered := false - handleMessages := func(ctx context.Context) error { - if !msgsDelivered { - msgsDelivered = true - return workflowTaskHandler.handleMessages(ctx, request.Messages) - } - return nil - } - if responseMutations, err = workflowTaskHandler.handleCommands( ctx, request.Commands, - handleMessages, + collection.NewIndexedTakeList( + request.Messages, + func(msg *protocolpb.Message) string { return msg.Id }, + ), ); err != nil { return nil, err } - if err := handleMessages(ctx); err != nil { - return nil, err - } - // set the vars used by following logic // further refactor should also clean up the vars used below wtFailedCause = workflowTaskHandler.workflowTaskFailedCause diff --git a/service/history/workflowTaskHandler_test.go b/service/history/workflowTaskHandler_test.go new file mode 100644 index 00000000000..a1d22b8afa9 --- /dev/null +++ b/service/history/workflowTaskHandler_test.go @@ -0,0 +1,314 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package history + +import ( + "context" + "testing" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/types" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + commandpb "go.temporal.io/api/command/v1" + enumspb "go.temporal.io/api/enums/v1" + protocolpb "go.temporal.io/api/protocol/v1" + "go.temporal.io/api/serviceerror" + updatepb "go.temporal.io/api/update/v1" + persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/common/collection" + "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/persistence" + "go.temporal.io/server/internal/effect" + "go.temporal.io/server/service/history/configs" + "go.temporal.io/server/service/history/shard" + "go.temporal.io/server/service/history/workflow" + "go.temporal.io/server/service/history/workflow/update" +) + +func TestCommandProtocolMessage(t *testing.T) { + t.Parallel() + + type testconf struct { + ms *workflow.MockMutableState + updates update.Registry + handler *workflowTaskHandlerImpl + conf map[dynamicconfig.Key]any + } + + const defaultBlobSizeLimit = 1 * 1024 * 1024 + + msgCommand := func(msgID string) *commandpb.Command { + return &commandpb.Command{ + CommandType: enumspb.COMMAND_TYPE_PROTOCOL_MESSAGE, + Attributes: &commandpb.Command_ProtocolMessageCommandAttributes{ + ProtocolMessageCommandAttributes: &commandpb.ProtocolMessageCommandAttributes{ + MessageId: msgID, + }, + }, + } + } + + setup := func(t *testing.T, out *testconf, blobSizeLimit int) { + shardCtx := shard.NewMockContext(gomock.NewController(t)) + logger := log.NewNoopLogger() + metricsHandler := metrics.NoopMetricsHandler + out.conf = map[dynamicconfig.Key]any{} + out.ms = workflow.NewMockMutableState(gomock.NewController(t)) + out.ms.EXPECT().GetAcceptedWorkflowExecutionUpdateIDs(gomock.Any()).Times(1).Return(nil) + out.updates = update.NewRegistry(out.ms) + var effects effect.Buffer + config := configs.NewConfig( + dynamicconfig.NewCollection( + dynamicconfig.StaticClient(out.conf), logger), 1, false, false) + mockMeta := persistence.NewMockMetadataManager(gomock.NewController(t)) + nsReg := namespace.NewRegistry( + mockMeta, + true, + func() time.Duration { return 1 * time.Hour }, + metricsHandler, + logger, + ) + out.handler = newWorkflowTaskHandler( // 😲 + t.Name(), //identity + 123, // workflowTaskCompletedID + out.ms, + out.updates, + &effects, + newCommandAttrValidator( + nsReg, + config, + nil, // searchAttributesValidator + ), + newWorkflowSizeChecker( + workflowSizeLimits{blobSizeLimitError: blobSizeLimit}, + out.ms, + nil, //searchAttributesValidator + &persistencespb.ExecutionStats{}, + metricsHandler, + logger, + ), + logger, + nsReg, + metricsHandler, + config, + shardCtx, + nil, // searchattribute.MapperProvider + false, + ) + } + + t.Run("missing message ID", func(t *testing.T) { + var tc testconf + setup(t, &tc, defaultBlobSizeLimit) + var ( + command = msgCommand("") // blank is invalid + ) + + tc.ms.EXPECT().GetExecutionInfo().AnyTimes().Return(&persistencespb.WorkflowExecutionInfo{}) + + _, err := tc.handler.handleCommand(context.Background(), command, newMsgList()) + require.NoError(t, err) + require.NotNil(t, tc.handler.workflowTaskFailedCause) + require.Equal(t, + enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, + tc.handler.workflowTaskFailedCause.failedCause) + }) + + t.Run("message not found", func(t *testing.T) { + var tc testconf + setup(t, &tc, defaultBlobSizeLimit) + var ( + command = msgCommand("valid_but_not_found_msg_id") + ) + + tc.ms.EXPECT().GetExecutionInfo().AnyTimes().Return(&persistencespb.WorkflowExecutionInfo{}) + + _, err := tc.handler.handleCommand(context.Background(), command, newMsgList()) + require.NoError(t, err) + require.NotNil(t, tc.handler.workflowTaskFailedCause) + require.Equal(t, + enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, + tc.handler.workflowTaskFailedCause.failedCause) + }) + + t.Run("message too large", func(t *testing.T) { + var tc testconf + t.Log("setting max blob size to zero") + setup(t, &tc, 0) + var ( + msgID = t.Name() + "-message-id" + command = msgCommand(msgID) // blank is invalid + msg = &protocolpb.Message{ + Id: msgID, + ProtocolInstanceId: "does_not_matter", + Body: mustMarshalAny(t, &types.Empty{}), + } + ) + + tc.ms.EXPECT().GetExecutionInfo().AnyTimes().Return(&persistencespb.WorkflowExecutionInfo{}) + tc.ms.EXPECT().GetExecutionState().AnyTimes().Return(&persistencespb.WorkflowExecutionState{}) + + _, err := tc.handler.handleCommand(context.Background(), command, newMsgList(msg)) + require.NoError(t, err) + require.NotNil(t, tc.handler.workflowTaskFailedCause) + require.Equal(t, + enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, + tc.handler.workflowTaskFailedCause.failedCause) + require.ErrorContains(t, tc.handler.workflowTaskFailedCause.causeErr, "exceeds size limit") + }) + + t.Run("message for unsupported protocol", func(t *testing.T) { + var tc testconf + setup(t, &tc, defaultBlobSizeLimit) + var ( + msgID = t.Name() + "-message-id" + command = msgCommand(msgID) // blank is invalid + msg = &protocolpb.Message{ + Id: msgID, + ProtocolInstanceId: "does_not_matter", + Body: mustMarshalAny(t, &types.Empty{}), + } + ) + + tc.ms.EXPECT().GetExecutionInfo().AnyTimes().Return(&persistencespb.WorkflowExecutionInfo{}) + tc.ms.EXPECT().GetExecutionState().AnyTimes().Return(&persistencespb.WorkflowExecutionState{}) + + _, err := tc.handler.handleCommand(context.Background(), command, newMsgList(msg)) + require.NoError(t, err) + require.NotNil(t, tc.handler.workflowTaskFailedCause) + require.Equal(t, + enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, + tc.handler.workflowTaskFailedCause.failedCause) + var invalidArg *serviceerror.InvalidArgument + require.ErrorAs(t, tc.handler.workflowTaskFailedCause.causeErr, &invalidArg) + require.ErrorContains(t, tc.handler.workflowTaskFailedCause.causeErr, "protocol type") + }) + + t.Run("update not found", func(t *testing.T) { + var tc testconf + setup(t, &tc, defaultBlobSizeLimit) + var ( + msgID = t.Name() + "-message-id" + command = msgCommand(msgID) // blank is invalid + msg = &protocolpb.Message{ + Id: msgID, + ProtocolInstanceId: "will not be found", + Body: mustMarshalAny(t, &updatepb.Acceptance{}), + } + ) + + tc.ms.EXPECT().GetExecutionInfo().AnyTimes().Return(&persistencespb.WorkflowExecutionInfo{}) + tc.ms.EXPECT().GetExecutionState().AnyTimes().Return(&persistencespb.WorkflowExecutionState{}) + tc.ms.EXPECT().GetUpdateInfo(gomock.Any(), msg.ProtocolInstanceId).Return(nil, false) + + _, err := tc.handler.handleCommand(context.Background(), command, newMsgList(msg)) + require.NoError(t, err) + require.NotNil(t, tc.handler.workflowTaskFailedCause) + require.Equal(t, + enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, + tc.handler.workflowTaskFailedCause.failedCause) + var notfound *serviceerror.NotFound + require.ErrorAs(t, tc.handler.workflowTaskFailedCause.causeErr, ¬found) + }) + + t.Run("deliver message failure", func(t *testing.T) { + var tc testconf + setup(t, &tc, defaultBlobSizeLimit) + var ( + updateID = t.Name() + "-update-id" + msgID = t.Name() + "-message-id" + command = msgCommand(msgID) // blank is invalid + msg = &protocolpb.Message{ + Id: msgID, + ProtocolInstanceId: updateID, + Body: mustMarshalAny(t, &updatepb.Acceptance{}), + } + ) + tc.ms.EXPECT().GetExecutionInfo().AnyTimes().Return(&persistencespb.WorkflowExecutionInfo{}) + tc.ms.EXPECT().GetExecutionState().AnyTimes().Return(&persistencespb.WorkflowExecutionState{}) + tc.ms.EXPECT().GetUpdateInfo(gomock.Any(), updateID).Return(nil, false) + + t.Log("create the expected protocol instance") + _, _, err := tc.updates.FindOrCreate(context.Background(), updateID) + require.NoError(t, err) + + t.Log("delivering an acceptance message to an update in the admitted state should cause a protocol error") + _, err = tc.handler.handleCommand(context.Background(), command, newMsgList(msg)) + require.NoError(t, err) + require.NotNil(t, tc.handler.workflowTaskFailedCause) + require.Equal(t, + enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE, + tc.handler.workflowTaskFailedCause.failedCause) + var gotErr *serviceerror.InvalidArgument + require.ErrorAs(t, tc.handler.workflowTaskFailedCause.causeErr, &gotErr) + }) + + t.Run("deliver message success", func(t *testing.T) { + var tc testconf + setup(t, &tc, defaultBlobSizeLimit) + var ( + updateID = t.Name() + "-update-id" + msgID = t.Name() + "-message-id" + command = msgCommand(msgID) // blank is invalid + msg = &protocolpb.Message{ + Id: msgID, + ProtocolInstanceId: updateID, + Body: mustMarshalAny(t, &updatepb.Request{ + Meta: &updatepb.Meta{UpdateId: updateID}, + Input: &updatepb.Input{Name: "not_empty"}, + }), + } + msgs = newMsgList(msg) + ) + tc.ms.EXPECT().GetExecutionInfo().AnyTimes().Return(&persistencespb.WorkflowExecutionInfo{}) + tc.ms.EXPECT().GetExecutionState().AnyTimes().Return(&persistencespb.WorkflowExecutionState{}) + tc.ms.EXPECT().GetUpdateInfo(gomock.Any(), updateID).Return(nil, false) + + t.Log("create the expected protocol instance") + _, _, err := tc.updates.FindOrCreate(context.Background(), updateID) + require.NoError(t, err) + + _, err = tc.handler.handleCommand(context.Background(), command, msgs) + require.NoError(t, err, + "delivering a request message to an update in the admitted state should succeed") + require.Nil(t, tc.handler.workflowTaskFailedCause) + }) +} + +func newMsgList(msgs ...*protocolpb.Message) *collection.IndexedTakeList[string, *protocolpb.Message] { + return collection.NewIndexedTakeList(msgs, func(msg *protocolpb.Message) string { return msg.Id }) +} + +func mustMarshalAny(t *testing.T, pb proto.Message) *types.Any { + t.Helper() + a, err := types.MarshalAny(pb) + require.NoError(t, err) + return a +}