Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set namespace on API if not present #3953

Merged
merged 2 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions common/rpc/interceptor/namespace_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ import (
)

type (
// NamespaceValidatorInterceptor contains LengthValidationIntercept and StateValidationIntercept
TaskTokenGetter interface {
GetTaskToken() []byte
}

// NamespaceValidatorInterceptor contains NamespaceValidateIntercept and StateValidationIntercept
NamespaceValidatorInterceptor struct {
namespaceRegistry namespace.Registry
tokenSerializer common.TaskTokenSerializer
Expand Down Expand Up @@ -71,7 +75,7 @@ var (
)

var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).StateValidationIntercept
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).LengthValidationIntercept
var _ grpc.UnaryServerInterceptor = (*NamespaceValidatorInterceptor)(nil).NamespaceValidateIntercept

func NewNamespaceValidatorInterceptor(
namespaceRegistry namespace.Registry,
Expand All @@ -86,12 +90,16 @@ func NewNamespaceValidatorInterceptor(
}
}

func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept(
func (ni *NamespaceValidatorInterceptor) NamespaceValidateIntercept(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
err := ni.setNamespaceIfNotPresent(req)
if err != nil {
return nil, err
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yiminc @yycptt what if customer modify & put a namespace ID in the task token?

if returning not found error here legit?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Existing behavior will also return not found in this case I think?

}
reqWithNamespace, hasNamespace := req.(NamespaceNameGetter)
if hasNamespace {
namespaceName := namespace.Name(reqWithNamespace.GetNamespace())
Expand All @@ -103,6 +111,60 @@ func (ni *NamespaceValidatorInterceptor) LengthValidationIntercept(
return handler(ctx, req)
}

func (ni *NamespaceValidatorInterceptor) setNamespaceIfNotPresent(
req interface{},
) error {
switch request := req.(type) {
case NamespaceNameGetter:
if request.GetNamespace() == "" {
namespaceEntry, err := ni.extractNamespaceFromTaskToken(req)
if err != nil {
return err
}
ni.setNamespace(namespaceEntry, req)
}
return nil
default:
return nil
}
}

func (ni *NamespaceValidatorInterceptor) setNamespace(
namespaceEntry *namespace.Namespace,
req interface{},
) {
switch request := req.(type) {
case *workflowservice.RespondQueryTaskCompletedRequest:
if request.Namespace == "" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: looks like we no longer need this check.

request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondWorkflowTaskCompletedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondWorkflowTaskFailedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RecordActivityTaskHeartbeatRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondActivityTaskCanceledRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondActivityTaskCompletedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
case *workflowservice.RespondActivityTaskFailedRequest:
if request.Namespace == "" {
request.Namespace = namespaceEntry.Name().String()
}
}
}

// StateValidationIntercept validates:
// 1. Namespace is specified in task token if there is a `task_token` field.
// 2. Namespace is specified in request if there is a `namespace` field and no `task_token` field.
Expand Down Expand Up @@ -202,7 +264,7 @@ func (ni *NamespaceValidatorInterceptor) extractNamespaceFromRequest(req interfa
}

func (ni *NamespaceValidatorInterceptor) extractNamespaceFromTaskToken(req interface{}) (*namespace.Namespace, error) {
reqWithTaskToken, hasTaskToken := req.(interface{ GetTaskToken() []byte })
reqWithTaskToken, hasTaskToken := req.(TaskTokenGetter)
if !hasTaskToken {
return nil, nil
}
Expand Down
105 changes: 102 additions & 3 deletions common/rpc/interceptor/namespace_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"testing"

"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
enumspb "go.temporal.io/api/enums/v1"
Expand Down Expand Up @@ -684,18 +685,44 @@ func (s *namespaceValidatorSuite) Test_Intercept_SearchAttributeRequests() {
}
}

func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() {
func (s *namespaceValidatorSuite) Test_NamespaceValidateIntercept() {
nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(10))
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/temporal/random",
}
requestNamespace := namespace.FromPersistentState(
&persistence.GetNamespaceResponse{
Namespace: &persistencespb.NamespaceDetail{
Config: &persistencespb.NamespaceConfig{},
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
Info: &persistencespb.NamespaceInfo{
Id: uuid.New().String(),
Name: "namespace",
State: enumspb.NAMESPACE_STATE_REGISTERED,
},
},
})
requestNamespaceTooLong := namespace.FromPersistentState(
&persistence.GetNamespaceResponse{
Namespace: &persistencespb.NamespaceDetail{
Config: &persistencespb.NamespaceConfig{},
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
Info: &persistencespb.NamespaceInfo{
Id: uuid.New().String(),
Name: "namespaceTooLong",
State: enumspb.NAMESPACE_STATE_REGISTERED,
},
},
})
s.mockRegistry.EXPECT().GetNamespace(namespace.Name("namespace")).Return(requestNamespace, nil).AnyTimes()
s.mockRegistry.EXPECT().GetNamespace(namespace.Name("namespaceTooLong")).Return(requestNamespaceTooLong, nil).AnyTimes()

req := &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespace"}
handlerCalled := false
_, err := nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err := nvi.NamespaceValidateIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
Expand All @@ -704,10 +731,82 @@ func (s *namespaceValidatorSuite) Test_LengthValidationIntercept() {

req = &workflowservice.StartWorkflowExecutionRequest{Namespace: "namespaceTooLong"}
handlerCalled = false
_, err = nvi.LengthValidationIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
_, err = nvi.NamespaceValidateIntercept(context.Background(), req, serverInfo, func(ctx context.Context, req interface{}) (interface{}, error) {
handlerCalled = true
return &workflowservice.StartWorkflowExecutionResponse{}, nil
})
s.False(handlerCalled)
s.Error(err)
}

func (s *namespaceValidatorSuite) TestSetNamespace() {
namespaceRequestName := uuid.New().String()
namespaceEntryName := uuid.New().String()
namespaceEntry := namespace.FromPersistentState(
&persistence.GetNamespaceResponse{
Namespace: &persistencespb.NamespaceDetail{
Config: &persistencespb.NamespaceConfig{},
ReplicationConfig: &persistencespb.NamespaceReplicationConfig{},
Info: &persistencespb.NamespaceInfo{
Id: uuid.New().String(),
Name: namespaceEntryName,
State: enumspb.NAMESPACE_STATE_REGISTERED,
},
},
})

nvi := NewNamespaceValidatorInterceptor(
s.mockRegistry,
dynamicconfig.GetBoolPropertyFn(false),
dynamicconfig.GetIntPropertyFn(10),
)

queryReq := &workflowservice.RespondQueryTaskCompletedRequest{}
nvi.setNamespace(namespaceEntry, queryReq)
s.Equal(namespaceEntryName, queryReq.Namespace)
queryReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, queryReq)
s.Equal(namespaceRequestName, queryReq.Namespace)

completeWorkflowTaskReq := &workflowservice.RespondWorkflowTaskCompletedRequest{}
nvi.setNamespace(namespaceEntry, completeWorkflowTaskReq)
s.Equal(namespaceEntryName, completeWorkflowTaskReq.Namespace)
completeWorkflowTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, completeWorkflowTaskReq)
s.Equal(namespaceRequestName, completeWorkflowTaskReq.Namespace)

failWorkflowTaskReq := &workflowservice.RespondWorkflowTaskFailedRequest{}
nvi.setNamespace(namespaceEntry, failWorkflowTaskReq)
s.Equal(namespaceEntryName, failWorkflowTaskReq.Namespace)
failWorkflowTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, failWorkflowTaskReq)
s.Equal(namespaceRequestName, failWorkflowTaskReq.Namespace)

heartbeatActivityTaskReq := &workflowservice.RecordActivityTaskHeartbeatRequest{}
nvi.setNamespace(namespaceEntry, heartbeatActivityTaskReq)
s.Equal(namespaceEntryName, heartbeatActivityTaskReq.Namespace)
heartbeatActivityTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, heartbeatActivityTaskReq)
s.Equal(namespaceRequestName, heartbeatActivityTaskReq.Namespace)

cancelActivityTaskReq := &workflowservice.RespondActivityTaskCanceledRequest{}
nvi.setNamespace(namespaceEntry, cancelActivityTaskReq)
s.Equal(namespaceEntryName, cancelActivityTaskReq.Namespace)
cancelActivityTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, cancelActivityTaskReq)
s.Equal(namespaceRequestName, cancelActivityTaskReq.Namespace)

completeActivityTaskReq := &workflowservice.RespondActivityTaskCompletedRequest{}
nvi.setNamespace(namespaceEntry, completeActivityTaskReq)
s.Equal(namespaceEntryName, completeActivityTaskReq.Namespace)
completeActivityTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, completeActivityTaskReq)
s.Equal(namespaceRequestName, completeActivityTaskReq.Namespace)

failActivityTaskReq := &workflowservice.RespondActivityTaskFailedRequest{}
nvi.setNamespace(namespaceEntry, failActivityTaskReq)
s.Equal(namespaceEntryName, failActivityTaskReq.Namespace)
failActivityTaskReq.Namespace = namespaceRequestName
nvi.setNamespace(namespaceEntry, failActivityTaskReq)
s.Equal(namespaceRequestName, failActivityTaskReq.Namespace)
}
2 changes: 1 addition & 1 deletion service/frontend/fx.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func GrpcServerOptionsProvider(
interceptors := []grpc.UnaryServerInterceptor{
// Service Error Interceptor should be the most outer interceptor on error handling
rpc.ServiceErrorInterceptor,
namespaceValidatorInterceptor.LengthValidationIntercept,
namespaceValidatorInterceptor.NamespaceValidateIntercept,
namespaceLogInterceptor.Intercept, // TODO: Deprecate this with a outer custom interceptor
grpc.UnaryServerInterceptor(traceInterceptor),
metrics.NewServerMetricsContextInjectorInterceptor(),
Expand Down