Skip to content

Commit

Permalink
Fix checksum validation for SQL implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll committed Mar 15, 2024
1 parent 3091e41 commit bf45163
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 12 deletions.
1 change: 1 addition & 0 deletions common/persistence/data_manager_interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ type (
DomainID string
Execution types.WorkflowExecution
DomainName string
RangeID int64
}

// GetWorkflowExecutionResponse is the response to GetWorkflowExecutionRequest
Expand Down
1 change: 1 addition & 0 deletions common/persistence/data_store_interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ type (
InternalGetWorkflowExecutionRequest struct {
DomainID string
Execution types.WorkflowExecution
// RangeID is the shard range ID supplied by the caller. The SQL store's
// GetWorkflowExecution compares it with the range ID read from the shards
// table when the loaded state carries checksum data, and returns a
// ShardOwnershipLostError if the two differ.
RangeID int64
}

// InternalGetWorkflowExecutionResponse is the response to GetWorkflowExecution for Persistence Interface
Expand Down
9 changes: 6 additions & 3 deletions common/persistence/executionManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,12 @@ func (m *executionManagerImpl) SerializeWorkflowMutation(
if err != nil {
return nil, err
}
checksumData, err := m.serializer.SerializeChecksum(input.Checksum, common.EncodingTypeJSON)
if err != nil {
return nil, err
var checksumData *DataBlob
if len(input.Checksum.Value) > 0 {
checksumData, err = m.serializer.SerializeChecksum(input.Checksum, common.EncodingTypeJSON)
if err != nil {
return nil, err
}
}

return &InternalWorkflowMutation{
Expand Down
2 changes: 2 additions & 0 deletions common/persistence/persistence-tests/persistenceTestBase.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ func (s *TestBase) GetWorkflowExecutionInfoWithStats(ctx context.Context, domain
response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{
DomainID: domainID,
Execution: workflowExecution,
RangeID: s.ShardInfo.RangeID,
})
if err != nil {
return nil, nil, err
Expand All @@ -490,6 +491,7 @@ func (s *TestBase) GetWorkflowExecutionInfo(ctx context.Context, domainID string
response, err := s.ExecutionManager.GetWorkflowExecution(ctx, &persistence.GetWorkflowExecutionRequest{
DomainID: domainID,
Execution: workflowExecution,
RangeID: s.ShardInfo.RangeID,
})
if err != nil {
return nil, err
Expand Down
1 change: 1 addition & 0 deletions common/persistence/serializer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ func TestSerializers(t *testing.T) {
{
name: "checksum",
payloads: map[string]any{
"empty": checksum.Checksum{},
"normal": generateChecksum(),
},
serializeFn: func(payload any, encoding common.EncodingType) (*DataBlob, error) {
Expand Down
37 changes: 28 additions & 9 deletions common/persistence/sql/sql_execution_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/collection"
"github.com/uber/cadence/common/log"
"github.com/uber/cadence/common/persistence"
p "github.com/uber/cadence/common/persistence"
"github.com/uber/cadence/common/persistence/serialization"
"github.com/uber/cadence/common/persistence/sql/sqlplugin"
Expand Down Expand Up @@ -307,60 +308,60 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
var bufferedEvents []*p.DataBlob
var signalsRequested map[string]struct{}

g, ctx := errgroup.WithContext(ctx)
g, childCtx := errgroup.WithContext(ctx)

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
executions, e = m.getExecutions(ctx, request, domainID, wfID, runID)
executions, e = m.getExecutions(childCtx, request, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
activityInfos, e = getActivityInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
timerInfos, e = getTimerInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
childExecutionInfos, e = getChildExecutionInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
requestCancelInfos, e = getRequestCancelInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
signalInfos, e = getSignalInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
childCtx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
bufferedEvents, e = getBufferedEvents(
ctx, m.db, m.shardID, domainID, wfID, runID)
childCtx, m.db, m.shardID, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer func() { recoverPanic(recover(), &e) }()
signalsRequested, e = getSignalsRequested(
ctx, m.db, m.shardID, domainID, wfID, runID)
childCtx, m.db, m.shardID, domainID, wfID, runID)
return e
})

Expand All @@ -375,6 +376,24 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
Message: fmt.Sprintf("GetWorkflowExecution: failed. Error: %v", err),
}
}
// If we have a checksum, we need to make sure the rangeID did not change.
// If the rangeID changed, the shard ownership might have changed and the
// workflow might have been updated while we were reading the data, so the data
// we read might not come from a consistent view and checksum validation might
// fail. In that case, we need to return an error.
if state.ChecksumData != nil {
row, err := m.db.SelectFromShards(ctx, &sqlplugin.ShardsFilter{ShardID: int64(m.shardID)})
if err != nil {
return nil, convertCommonErrors(m.db, "GetWorkflowExecution", "", err)
}
if row.RangeID != request.RangeID {
return nil, &persistence.ShardOwnershipLostError{
ShardID: m.shardID,
Msg: fmt.Sprintf("GetWorkflowExecution failed. Previous rangeID: %v, new rangeID: %v", request.RangeID, row.RangeID),
}
}
}

state.ActivityInfos = activityInfos
state.TimerInfos = timerInfos
state.ChildExecutionInfos = childExecutionInfos
Expand Down
112 changes: 112 additions & 0 deletions common/persistence/sql/sql_execution_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2951,6 +2951,7 @@ func TestGetWorkflowExecution(t *testing.T) {
mockSetup func(*sqlplugin.MockDB, *serialization.MockParser)
want *persistence.InternalGetWorkflowExecutionResponse
wantErr bool
assertErr func(t *testing.T, err error)
}{
{
name: "Success case",
Expand All @@ -2960,6 +2961,7 @@ func TestGetWorkflowExecution(t *testing.T) {
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
RangeID: 1,
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
Expand Down Expand Up @@ -3180,6 +3182,9 @@ func TestGetWorkflowExecution(t *testing.T) {
Control: []byte("test control"),
RequestID: "test-signal-request-id",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(&sqlplugin.ShardsRow{
RangeID: 1,
}, nil)
},
want: &persistence.InternalGetWorkflowExecutionResponse{
State: &persistence.InternalWorkflowMutableState{
Expand Down Expand Up @@ -3342,6 +3347,110 @@ func TestGetWorkflowExecution(t *testing.T) {
},
wantErr: false,
},
{
name: "Error - Shard owner changed",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
{
ShardID: 0,
DomainID: serialization.MustParseUUID("ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d"),
WorkflowID: "test-workflow-id",
RunID: serialization.MustParseUUID("ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f"),
NextEventID: 101,
LastWriteVersion: 11,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}, nil)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
parser.EXPECT().WorkflowExecutionInfoFromBlob(gomock.Any(), gomock.Any()).Return(&serialization.WorkflowExecutionInfo{
Checksum: []byte("test-checksum"),
ChecksumEncoding: "test-checksum-encoding",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(&sqlplugin.ShardsRow{
RangeID: 1,
}, nil)
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.IsType(t, &persistence.ShardOwnershipLostError{}, err)
},
},
{
name: "Error - failed to get shard",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return([]sqlplugin.ExecutionsRow{
{
ShardID: 0,
DomainID: serialization.MustParseUUID("ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d"),
WorkflowID: "test-workflow-id",
RunID: serialization.MustParseUUID("ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f"),
NextEventID: 101,
LastWriteVersion: 11,
Data: []byte("test data"),
DataEncoding: "thriftrw",
},
}, nil)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
parser.EXPECT().WorkflowExecutionInfoFromBlob(gomock.Any(), gomock.Any()).Return(&serialization.WorkflowExecutionInfo{
Checksum: []byte("test-checksum"),
ChecksumEncoding: "test-checksum-encoding",
}, nil)
db.EXPECT().SelectFromShards(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes()
},
wantErr: true,
},
{
name: "Error - SelectFromExecutions no row",
req: &persistence.InternalGetWorkflowExecutionRequest{
DomainID: "ff9c8a3f-0e4f-4d3e-a4d2-6f5f8f3f7d9d",
Execution: types.WorkflowExecution{
WorkflowID: "test-workflow-id",
RunID: "ee8d7b6e-876c-4b1e-9b6e-5e3e3c6b6b3f",
},
},
mockSetup: func(db *sqlplugin.MockDB, parser *serialization.MockParser) {
db.EXPECT().SelectFromExecutions(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromActivityInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromTimerInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromChildExecutionInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromRequestCancelInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalInfoMaps(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromSignalsRequestedSets(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
db.EXPECT().SelectFromBufferedEvents(gomock.Any(), gomock.Any()).Return(nil, sql.ErrNoRows)
},
wantErr: true,
assertErr: func(t *testing.T, err error) {
assert.IsType(t, &types.EntityNotExistsError{}, err)
},
},
{
name: "Error - SelectFromExecutions failed",
req: &persistence.InternalGetWorkflowExecutionRequest{
Expand Down Expand Up @@ -3538,6 +3647,9 @@ func TestGetWorkflowExecution(t *testing.T) {
resp, err := s.GetWorkflowExecution(context.Background(), tc.req)
if tc.wantErr {
assert.Error(t, err, "Expected an error for test case")
if tc.assertErr != nil {
tc.assertErr(t, err)
}
} else {
assert.NoError(t, err, "Did not expect an error for test case")
assert.Equal(t, tc.want, resp, "Response mismatch")
Expand Down
3 changes: 3 additions & 0 deletions service/history/execution/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,9 @@ func (c *contextImpl) getWorkflowExecutionWithRetry(
case *types.EntityNotExistsError:
// it is possible that the workflow does not exist
return nil, err
case *persistence.ShardOwnershipLostError:
// shard is stolen, should stop processing the workflow
return nil, err
default:
c.logger.Error(
"Persistent fetch operation failure",
Expand Down
2 changes: 2 additions & 0 deletions service/history/shard/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,8 @@ func (s *contextImpl) GetWorkflowExecution(
if s.isClosed() {
return nil, ErrShardClosed
}
currentRangeID := s.getRangeID()
request.RangeID = currentRangeID
return s.executionManager.GetWorkflowExecution(ctx, request)
}

Expand Down
3 changes: 3 additions & 0 deletions service/history/shard/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ func TestGetWorkflowExecution(t *testing.T) {
mockExecutionMgr := &mocks.ExecutionManager{}
shardContext := &contextImpl{
executionManager: mockExecutionMgr,
shardInfo: &persistence.ShardInfo{
RangeID: 12,
},
}
if tc.isClosed {
shardContext.closed = 1
Expand Down

0 comments on commit bf45163

Please sign in to comment.