Skip to content

Commit

Permalink
Handle ProtocolMessageCommands in WFT completion (#4328)
Browse files Browse the repository at this point in the history
ProtocolMessageCommands are references to messages in the same WFT and
serve as a way to sequence command and message processing.
  • Loading branch information
Matt McShane authored May 15, 2023
1 parent 1a479ad commit 9712711
Show file tree
Hide file tree
Showing 10 changed files with 736 additions and 66 deletions.
86 changes: 86 additions & 0 deletions common/collection/indexedtakelist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package collection

// IndexedTakeList holds a set of values that can only be observed by being
// removed from the set. It is possible for this set to contain duplicate values
// as long as each value maps to a distinct index.
type (
IndexedTakeList[K comparable, V any] struct {
values []kv[K, V]
}

kv[K comparable, V any] struct {
key K
value V
removed bool
}
)

// NewIndexedTakeList constructs a new IndexedTakeSet by applying the provided
// indexer to each of the provided values.
func NewIndexedTakeList[K comparable, V any](
values []V,
indexer func(V) K,
) *IndexedTakeList[K, V] {
ret := &IndexedTakeList[K, V]{
values: make([]kv[K, V], 0, len(values)),
}
for _, v := range values {
ret.values = append(ret.values, kv[K, V]{key: indexer(v), value: v})
}
return ret
}

// Take finds a value in this set by its key and removes it, returning the
// value.
func (itl *IndexedTakeList[K, V]) Take(key K) (V, bool) {
var zero V
for i := 0; i < len(itl.values); i++ {
kv := &itl.values[i]
if kv.key != key {
continue
}
if kv.removed {
return zero, false
}
kv.removed = true
return kv.value, true
}
return zero, false
}

// TakeRemaining removes all remaining values from this set and returns them.
func (itl *IndexedTakeList[K, V]) TakeRemaining() []V {
out := make([]V, 0, len(itl.values))
for i := 0; i < len(itl.values); i++ {
kv := &itl.values[i]
if !kv.removed {
out = append(out, kv.value)
}
}
itl.values = nil
return out
}
71 changes: 71 additions & 0 deletions common/collection/indexedtakelist_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package collection_test

import (
"testing"

"github.com/stretchr/testify/require"
"go.temporal.io/server/common/collection"
)

func TestIndexedTakeList(t *testing.T) {
t.Parallel()

type Value struct{ ID int }
values := []Value{{ID: 0}, {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}}
indexer := func(v Value) int { return v.ID }

list := collection.NewIndexedTakeList(values, indexer)

t.Run("take absent", func(t *testing.T) {
_, ok := list.Take(999)
require.False(t, ok)
})

t.Run("take present", func(t *testing.T) {
v, ok := list.Take(3)
require.True(t, ok)
require.Equal(t, 3, v.ID)
})

t.Run("take taken", func(t *testing.T) {
_, ok := list.Take(3)
require.False(t, ok)
})

t.Run("take remaining", func(t *testing.T) {
allowed := []int{0, 1, 2, 4}
remaining := list.TakeRemaining()
require.Len(t, remaining, len(allowed))
for _, v := range remaining {
require.Contains(t, allowed, v.ID)
}
})

t.Run("now empty", func(t *testing.T) {
require.Empty(t, list.TakeRemaining())
})
}
1 change: 1 addition & 0 deletions common/metrics/metric_defs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,7 @@ var (
CommandTypeUpsertWorkflowSearchAttributesCounter = NewCounterDef("upsert_workflow_search_attributes_command")
CommandTypeModifyWorkflowPropertiesCounter = NewCounterDef("modify_workflow_properties_command")
CommandTypeChildWorkflowCounter = NewCounterDef("child_workflow_command")
CommandTypeProtocolMessage = NewCounterDef("protocol_message_command")
MessageTypeRequestWorkflowExecutionUpdateCounter = NewCounterDef("request_workflow_update_message")
MessageTypeAcceptWorkflowExecutionUpdateCounter = NewCounterDef("accept_workflow_update_message")
MessageTypeRespondWorkflowExecutionUpdateCounter = NewCounterDef("respond_workflow_update_message")
Expand Down
99 changes: 99 additions & 0 deletions internal/protocol/naming.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package protocol

import (
"fmt"
"strings"

"github.com/gogo/protobuf/types"
protocolpb "go.temporal.io/api/protocol/v1"
)

type (
// Type is a protocol type of which there may be many instances like the
// update protocol or the query protocol.
Type string

// MessageType is the type of a message within a protocol.
MessageType string

constErr string
)

const (
MessageTypeUnknown = MessageType("__message_type_unknown")
TypeUnknown = Type("__protocol_type_unknown")

errNilMsg = constErr("nil message")
errNilBody = constErr("nil message body")
errProtoFmt = constErr("failed to extract protocol type")
)

// String transforms a MessageType into a string
func (mt MessageType) String() string {
return string(mt)
}

// String tranforms a Type into a string
func (pt Type) String() string {
return string(pt)
}

// Identify is a function that given a protocol message gives the specific
// message type of the body and the type of the protocol to which this message
// belongs.
func Identify(msg *protocolpb.Message) (Type, MessageType, error) {
if msg == nil {
return TypeUnknown, MessageTypeUnknown, errNilMsg
}
if msg.Body == nil {
return TypeUnknown, MessageTypeUnknown, errNilBody
}
bodyTypeName, err := types.AnyMessageName(msg.Body)
if err != nil {
return TypeUnknown, MessageTypeUnknown, err
}
msgType := MessageType(bodyTypeName)

lastDot := strings.LastIndex(bodyTypeName, ".")
if lastDot < 0 {
err := fmt.Errorf("%w: no . found in %q", errProtoFmt, bodyTypeName)
return TypeUnknown, msgType, err
}
return Type(bodyTypeName[0:lastDot]), msgType, nil
}

// IdentifyOrUnknown wraps Identify to return TypeUnknown and/or
// MessageTypeUnknown in the case where either one cannot be determined due to
// an error.
func IdentifyOrUnknown(msg *protocolpb.Message) (Type, MessageType) {
pt, mt, _ := Identify(msg)
return pt, mt
}

func (cerr constErr) Error() string {
return string(cerr)
}
77 changes: 77 additions & 0 deletions internal/protocol/naming_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package protocol_test

import (
"testing"

"github.com/gogo/protobuf/types"
"github.com/stretchr/testify/require"
protocolpb "go.temporal.io/api/protocol/v1"
"go.temporal.io/server/internal/protocol"
)

func TestNilSafety(t *testing.T) {
t.Parallel()

t.Run("nil message", func(t *testing.T) {
pt, mt := protocol.IdentifyOrUnknown(nil)
require.Equal(t, protocol.TypeUnknown, pt)
require.Equal(t, protocol.MessageTypeUnknown, mt)
})

t.Run("nil message body", func(t *testing.T) {
pt, mt := protocol.IdentifyOrUnknown(&protocolpb.Message{})
require.Equal(t, protocol.TypeUnknown, pt)
require.Equal(t, protocol.MessageTypeUnknown, mt)
})
}

func TestWithValidMessage(t *testing.T) {
t.Parallel()

body, err := types.MarshalAny(&types.Empty{})
require.NoError(t, err)

msg := protocolpb.Message{Body: body}

pt, mt := protocol.IdentifyOrUnknown(&msg)

require.Equal(t, "google.protobuf", pt.String())
require.Equal(t, "google.protobuf.Empty", mt.String())
}

func TestWithInvalidBody(t *testing.T) {
t.Parallel()

body, err := types.MarshalAny(&types.Empty{})
require.NoError(t, err)

msg := protocolpb.Message{Body: body}
msg.Body.TypeUrl = "this isn't valid"

_, _, err = protocol.Identify(&msg)
require.Error(t, err)
}
18 changes: 18 additions & 0 deletions service/history/commandChecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,24 @@ func (c *workflowSizeChecker) checkIfSearchAttributesSizeExceedsLimit(
return err
}

func (v *commandAttrValidator) validateProtocolMessageAttributes(
namespaceID namespace.ID,
attributes *commandpb.ProtocolMessageCommandAttributes,
runTimeout time.Duration,
) (enumspb.WorkflowTaskFailedCause, error) {
const failedCause = enumspb.WORKFLOW_TASK_FAILED_CAUSE_BAD_UPDATE_WORKFLOW_EXECUTION_MESSAGE

if attributes == nil {
return failedCause, serviceerror.NewInvalidArgument("ProtocolMessageCommandAttributes is not set on command.")
}

if attributes.MessageId == "" {
return failedCause, serviceerror.NewInvalidArgument("MessageID is not set on command.")
}

return enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNSPECIFIED, nil
}

func (v *commandAttrValidator) validateActivityScheduleAttributes(
namespaceID namespace.ID,
attributes *commandpb.ScheduleActivityTaskCommandAttributes,
Expand Down
4 changes: 3 additions & 1 deletion service/history/workflow/update/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@

package update

const ProtocolV1 = "temporal.api.update.v1"
import "go.temporal.io/server/internal/protocol"

const ProtocolV1 = protocol.Type("temporal.api.update.v1")
Loading

0 comments on commit 9712711

Please sign in to comment.