From 2b95231cf8a6baf312a7937c925e70e9fcb987aa Mon Sep 17 00:00:00 2001 From: Zijian Date: Tue, 5 Mar 2024 02:00:43 +0000 Subject: [PATCH] Add unit tests for CreateWorkflowExecution --- common/persistence/sql/sql_execution_store.go | 21 +- .../sql/sql_execution_store_test.go | 299 ++++++++++++++++++ .../sql/sql_execution_store_util_test.go | 2 + 3 files changed, 316 insertions(+), 6 deletions(-) diff --git a/common/persistence/sql/sql_execution_store.go b/common/persistence/sql/sql_execution_store.go index 13c5c200f4e..fb6a4bd49f8 100644 --- a/common/persistence/sql/sql_execution_store.go +++ b/common/persistence/sql/sql_execution_store.go @@ -50,6 +50,10 @@ const ( type sqlExecutionStore struct { sqlStore shardID int + txExecuteShardLockedFn func(context.Context, int, string, int64, func(sqlplugin.Tx) error) error + lockCurrentExecutionIfExistsFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) + createOrUpdateCurrentExecutionFn func(context.Context, sqlplugin.Tx, p.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error + applyWorkflowSnapshotTxAsNewFn func(context.Context, sqlplugin.Tx, int, *p.InternalWorkflowSnapshot, serialization.Parser) error } var _ p.ExecutionStore = (*sqlExecutionStore)(nil) @@ -63,15 +67,20 @@ func NewSQLExecutionStore( dc *p.DynamicConfiguration, ) (p.ExecutionStore, error) { - return &sqlExecutionStore{ + store := &sqlExecutionStore{ shardID: shardID, + lockCurrentExecutionIfExistsFn: lockCurrentExecutionIfExists, + createOrUpdateCurrentExecutionFn: createOrUpdateCurrentExecution, + applyWorkflowSnapshotTxAsNewFn: applyWorkflowSnapshotTxAsNew, sqlStore: sqlStore{ db: db, logger: logger, parser: parser, dc: dc, }, - }, nil + } + store.txExecuteShardLockedFn = store.txExecuteShardLocked + return store, nil } // txExecuteShardLocked executes f under transaction and with read lock on shard row @@ -105,7 +114,7 @@ func (m *sqlExecutionStore) CreateWorkflowExecution( ) (response *p.CreateWorkflowExecutionResponse, err error) { dbShardID := sqlplugin.GetDBShardIDFromHistoryShardID(m.shardID, m.db.GetTotalNumDBShards()) - err = m.txExecuteShardLocked(ctx, dbShardID, "CreateWorkflowExecution", request.RangeID, func(tx sqlplugin.Tx) error { + err = m.txExecuteShardLockedFn(ctx, dbShardID, "CreateWorkflowExecution", request.RangeID, func(tx sqlplugin.Tx) error { response, err = m.createWorkflowExecutionTx(ctx, tx, request) return err }) @@ -136,7 +145,7 @@ func (m *sqlExecutionStore) createWorkflowExecutionTx( var err error var row *sqlplugin.CurrentExecutionsRow - if row, err = lockCurrentExecutionIfExists(ctx, tx, m.shardID, domainID, workflowID); err != nil { + if row, err = m.lockCurrentExecutionIfExistsFn(ctx, tx, m.shardID, domainID, workflowID); err != nil { return nil, err } @@ -204,7 +213,7 @@ func (m *sqlExecutionStore) createWorkflowExecutionTx( } } - if err := createOrUpdateCurrentExecution( + if err := m.createOrUpdateCurrentExecutionFn( ctx, tx, request.Mode, @@ -220,7 +229,7 @@ func (m *sqlExecutionStore) createWorkflowExecutionTx( return nil, err } - if err := applyWorkflowSnapshotTxAsNew(ctx, tx, shardID, &request.NewWorkflowSnapshot, m.parser); err != nil { + if err := m.applyWorkflowSnapshotTxAsNewFn(ctx, tx, shardID, &request.NewWorkflowSnapshot, m.parser); err != nil { return nil, err } diff --git a/common/persistence/sql/sql_execution_store_test.go b/common/persistence/sql/sql_execution_store_test.go index a720d962047..413f04d6b9e 100644 --- a/common/persistence/sql/sql_execution_store_test.go +++ b/common/persistence/sql/sql_execution_store_test.go @@ -1955,3 +1955,302 @@ func TestTxExecuteShardLocked(t *testing.T) { }) } } + +func TestCreateWorkflowExecution(t *testing.T) { + testCases := []struct { + name string + req *persistence.InternalCreateWorkflowExecutionRequest + lockCurrentExecutionIfExistsFn func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) + createOrUpdateCurrentExecutionFn func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error + applyWorkflowSnapshotTxAsNewFn func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error + wantErr bool + want *persistence.CreateWorkflowExecutionResponse + assertErr func(t *testing.T, err error) + }{ + { + name: "Success - mode brand new", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeBrandNew, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return nil, nil + }, + createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error { + return nil + }, + applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error { + return nil + }, + want: &persistence.CreateWorkflowExecutionResponse{}, + }, + { + name: "Success - mode workflow ID reuse", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeWorkflowIDReuse, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return &sqlplugin.CurrentExecutionsRow{ + State: persistence.WorkflowStateCompleted, + }, nil + }, + createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error { + return nil + }, + applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error { + return nil + }, + want: &persistence.CreateWorkflowExecutionResponse{}, + }, + { + name: "Success - mode zombie", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeZombie, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{ + State: persistence.WorkflowStateZombie, + }, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return &sqlplugin.CurrentExecutionsRow{ + RunID: serialization.MustParseUUID("abdcea69-61d5-44c3-9d55-afe23505a54a"), + }, nil + }, + createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error { + return nil + }, + applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error { + return nil + }, + want: &persistence.CreateWorkflowExecutionResponse{}, + }, + { + name: "Error - mode state validation failed", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeZombie, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{ + State: persistence.WorkflowStateCreated, + }, + }, + }, + wantErr: true, + }, + { + name: "Error - lockCurrentExecutionIfExists failed", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeBrandNew, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return nil, errors.New("some random error") + }, + wantErr: true, + }, + { + name: "Error - mode brand new", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeBrandNew, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return &sqlplugin.CurrentExecutionsRow{ + CreateRequestID: "test", + WorkflowID: "test", + RunID: serialization.MustParseUUID("abdcea69-61d5-44c3-9d55-afe23505a54a"), + State: persistence.WorkflowStateCreated, + CloseStatus: persistence.WorkflowCloseStatusNone, + LastWriteVersion: 10, + }, nil + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, &persistence.WorkflowExecutionAlreadyStartedError{ + Msg: "Workflow execution already running. WorkflowId: test", + StartRequestID: "test", + RunID: "abdcea69-61d5-44c3-9d55-afe23505a54a", + State: persistence.WorkflowStateCreated, + CloseStatus: persistence.WorkflowCloseStatusNone, + LastWriteVersion: 10, + }, err) + }, + }, + { + name: "Error - mode workflow ID reuse, version mismatch", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeWorkflowIDReuse, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return &sqlplugin.CurrentExecutionsRow{ + State: persistence.WorkflowStateCompleted, + LastWriteVersion: 10, + }, nil + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, &persistence.CurrentWorkflowConditionFailedError{ + Msg: "Workflow execution creation condition failed. WorkflowId: , LastWriteVersion: 10, PreviousLastWriteVersion: 0", + }, err) + }, + }, + { + name: "Error - mode workflow ID reuse, state mismatch", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeWorkflowIDReuse, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return &sqlplugin.CurrentExecutionsRow{ + State: persistence.WorkflowStateCreated, + }, nil + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, &persistence.CurrentWorkflowConditionFailedError{ + Msg: "Workflow execution creation condition failed. WorkflowId: , State: 0, Expected: 2", + }, err) + }, + }, + { + name: "Error - mode workflow ID reuse, run ID mismatch", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeWorkflowIDReuse, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return &sqlplugin.CurrentExecutionsRow{ + State: persistence.WorkflowStateCompleted, + RunID: serialization.MustParseUUID("abdcea69-61d5-44c3-9d55-afe23505a54a"), + }, nil + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, &persistence.CurrentWorkflowConditionFailedError{ + Msg: "Workflow execution creation condition failed. WorkflowId: , RunID: abdcea69-61d5-44c3-9d55-afe23505a54a, PreviousRunID: ", + }, err) + }, + }, + { + name: "Error - mode zombie, run ID match", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeZombie, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{ + State: persistence.WorkflowStateZombie, + }, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return &sqlplugin.CurrentExecutionsRow{}, nil + }, + wantErr: true, + }, + { + name: "Error - unknown mode", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowMode(100), + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + wantErr: true, + }, + { + name: "Error - createOrUpdateCurrentExecution failed", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeBrandNew, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return nil, nil + }, + createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error { + return errors.New("some random error") + }, + wantErr: true, + }, + { + name: "Error - applyWorkflowSnapshotTxAsNew failed", + req: &persistence.InternalCreateWorkflowExecutionRequest{ + RangeID: 1, + Mode: persistence.CreateWorkflowModeBrandNew, + NewWorkflowSnapshot: persistence.InternalWorkflowSnapshot{ + ExecutionInfo: &persistence.InternalWorkflowExecutionInfo{}, + }, + }, + lockCurrentExecutionIfExistsFn: func(context.Context, sqlplugin.Tx, int, serialization.UUID, string) (*sqlplugin.CurrentExecutionsRow, error) { + return nil, nil + }, + createOrUpdateCurrentExecutionFn: func(context.Context, sqlplugin.Tx, persistence.CreateWorkflowMode, int, serialization.UUID, string, serialization.UUID, int, int, string, int64, int64) error { + return nil + }, + applyWorkflowSnapshotTxAsNewFn: func(context.Context, sqlplugin.Tx, int, *persistence.InternalWorkflowSnapshot, serialization.Parser) error { + return errors.New("some random error") + }, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := sqlplugin.NewMockDB(ctrl) + mockDB.EXPECT().GetTotalNumDBShards().Return(1) + s := &sqlExecutionStore{ + shardID: 0, + sqlStore: sqlStore{ + db: mockDB, + logger: testlogger.New(t), + }, + txExecuteShardLockedFn: func(_ context.Context, _ int, _ string, _ int64, fn func(sqlplugin.Tx) error) error { + return fn(nil) + }, + lockCurrentExecutionIfExistsFn: tc.lockCurrentExecutionIfExistsFn, + createOrUpdateCurrentExecutionFn: tc.createOrUpdateCurrentExecutionFn, + applyWorkflowSnapshotTxAsNewFn: tc.applyWorkflowSnapshotTxAsNewFn, + } + + got, err := s.CreateWorkflowExecution(context.Background(), tc.req) + if tc.wantErr { + assert.Error(t, err, "Expected an error for test case") + if tc.assertErr != nil { + tc.assertErr(t, err) + } + } else { + assert.NoError(t, err, "Did not expect an error for test case") + assert.Equal(t, tc.want, got, "Unexpected result for test case") + } + }) + } +} diff --git a/common/persistence/sql/sql_execution_store_util_test.go b/common/persistence/sql/sql_execution_store_util_test.go index 0a38717cf66..7f2129a0630 100644 --- a/common/persistence/sql/sql_execution_store_util_test.go +++ b/common/persistence/sql/sql_execution_store_util_test.go @@ -499,6 +499,7 @@ func TestApplyWorkflowMutationTx(t *testing.T) { DeleteSignalInfos: []int64{1, 2}, UpsertSignalRequestedIDs: []string{"a", "b"}, DeleteSignalRequestedIDs: []string{"c", "d"}, + ClearBufferedEvents: true, }, mockSetup: func(mockTx *sqlplugin.MockTx, mockParser *serialization.MockParser) { mockSetupLockAndCheckNextEventID(mockTx, shardID, serialization.MustParseUUID("8be8a310-7d20-483e-a5d2-48659dc47602"), "abc", serialization.MustParseUUID("8be8a310-7d20-483e-a5d2-48659dc47603"), 9, false) @@ -510,6 +511,7 @@ func TestApplyWorkflowMutationTx(t *testing.T) { mockUpdateRequestCancelInfos(mockTx, mockParser, 1, 2, false) mockUpdateSignalInfos(mockTx, mockParser, 1, 2, false) mockUpdateSignalRequested(mockTx, mockParser, 1, 2, false) + mockDeleteBufferedEvents(mockTx, shardID, serialization.MustParseUUID("8be8a310-7d20-483e-a5d2-48659dc47602"), "abc", serialization.MustParseUUID("8be8a310-7d20-483e-a5d2-48659dc47603"), false) }, wantErr: false, },