Skip to content

Commit

Permalink
items from first channel upgrades security audit session (cosmos#5664)
Browse files Browse the repository at this point in the history
* some items from the security audit

* update tests for cancelling when channel is in FLUSHCOMPLETE

* remove duplicate test

* nit: Move wrapping of returned error to direct return statement.

Do not create a new errChanUpgradeFailed for logging/returning, it obfuscates the logic and does not
yield much benefit.

* nit: consistency in err return cbs.OnChanUpgradeTry

Wrap the error before returning as is done in ChannelUpgradeInit.

* add comment

* nit: Call application callbacks before performing any state change operations.

* chore: Add documentation indicating to app devs that callbacks are invoked _before_ core state is written.

* upgrade error

* nit: Add note on handling wrapped error in IsUpgradeError docu string, add fmt.Errorf wrap test case.

* Update modules/core/04-channel/types/upgrade_test.go

* chore: Use consistent logging in writeUpgradeAck + writeUpgradeConfirm

Move log into the if to only log if channel state actually changes. Could be moved outside for both cases
but might then be logging when channel state doesn't change.

* Apply suggestions from code review

Co-authored-by: Cian Hatton <[email protected]>

* nit: review comments from Cian.

- Prefer using SetChannelState over look-up channel, modify, set.
- Add deeply test case with error wrapped multiple times.

* nit: Address Colin's comment

* nit: just linter thangs.

* Apply suggestions from code review

Co-authored-by: colin axnér <[email protected]>

* nit: Fix linter flatten if else.

---------

Co-authored-by: DimitrisJim <[email protected]>
Co-authored-by: Charly <[email protected]>
Co-authored-by: Charly <[email protected]>
Co-authored-by: Cian Hatton <[email protected]>
Co-authored-by: colin axnér <[email protected]>
  • Loading branch information
6 people authored Jan 23, 2024
1 parent 11e63f4 commit 41501a2
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 45 deletions.
3 changes: 2 additions & 1 deletion docs/docs/01-ibc/06-channel-upgrades.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ simd tx ibc channel prune-acknowledgements [port] [channel] [limit]

## IBC App Recommendations

IBC application callbacks should be primarily used to validate data fields and do compatibility checks.
IBC application callbacks should be primarily used to validate data fields and do compatibility checks. Application developers
should be aware that callbacks will be invoked before any core ibc state changes are written.

`OnChanUpgradeInit` should validate the proposed version, order, and connection hops, and should return the application version to upgrade to.

Expand Down
15 changes: 13 additions & 2 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,10 @@ func (k Keeper) WriteUpgradeAckChannel(ctx sdk.Context, portID, channelID string
}

if !k.HasInflightPackets(ctx, portID, channelID) {
previousState := channel.State
channel.State = types.FLUSHCOMPLETE
k.SetChannel(ctx, portID, channelID, channel)
k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", channel.State)
}

upgrade, found := k.GetUpgrade(ctx, portID, channelID)
Expand All @@ -361,7 +363,6 @@ func (k Keeper) WriteUpgradeAckChannel(ctx sdk.Context, portID, channelID string
k.SetUpgrade(ctx, portID, channelID, upgrade)
k.SetCounterpartyUpgrade(ctx, portID, channelID, counterpartyUpgrade)

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "state", channel.State.String())
return channel, upgrade
}

Expand Down Expand Up @@ -628,7 +629,17 @@ func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, err
return errorsmod.Wrap(commitmenttypes.ErrInvalidProof, "cannot submit an empty error receipt proof unless the sender is authorized to cancel upgrades AND channel is not in FLUSHCOMPLETE")
}

// the error receipt should also have a sequence greater than or equal to the current upgrade sequence.
// REPLAY PROTECTION: The error receipt MUST have a sequence greater than or equal to the current upgrade sequence,
// except when the channel state is FLUSHCOMPLETE, in which case the sequences MUST match. This is required
// to guarantee that when a counterparty successfully completes an upgrade and moves to OPEN, this channel
// cannot cancel its upgrade. Without this strict sequence check, it would be possible for the counterparty
// to complete its upgrade, move to OPEN, initiate a new upgrade (and thus increment the upgrade sequence) and
// then cancel the new upgrade, all in the same block. This results in a valid error receipt being written at channel.UpgradeSequence + 1.
// The desired behaviour in this circumstance is for this channel to complete its current upgrade despite proof
// of an error receipt at a greater upgrade sequence
if channel.State == types.FLUSHCOMPLETE && errorReceipt.Sequence != channel.UpgradeSequence {
return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be equal to current upgrade sequence (%d) when the channel is in FLUSHCOMPLETE", errorReceipt.Sequence, channel.UpgradeSequence)
}
if errorReceipt.Sequence < channel.UpgradeSequence {
return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be greater than or equal to current upgrade sequence (%d)", errorReceipt.Sequence, channel.UpgradeSequence)
}
Expand Down
69 changes: 50 additions & 19 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1573,35 +1573,24 @@ func (suite *KeeperTestSuite) TestChanUpgradeCancel() {
malleate func()
expError error
}{
{
name: "success with flush complete state",
malleate: func() {},
expError: nil,
},
{
name: "success with flushing state",
malleate: func() {
channel := path.EndpointA.GetChannel()
channel.State = types.FLUSHING
path.EndpointA.SetChannel(channel)
},
expError: nil,
},
{
name: "upgrade cannot be cancelled in FLUSHCOMPLETE with invalid error receipt",
malleate: func() {
errorReceiptProof = nil
},
expError: commitmenttypes.ErrInvalidProof,
},
{
name: "upgrade cannot be cancelled in FLUSHCOMPLETE with error receipt sequence less than channel upgrade sequence",
name: "success with flush complete state",
malleate: func() {
err := path.EndpointA.SetChannelState(types.FLUSHCOMPLETE)
suite.Require().NoError(err)

var ok bool
errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
suite.Require().True(ok)

errorReceipt.Sequence = path.EndpointA.GetChannel().UpgradeSequence - 1
// the error receipt upgrade sequence and the channel upgrade sequence must match
errorReceipt.Sequence = path.EndpointA.GetChannel().UpgradeSequence

suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt)

Expand All @@ -1612,7 +1601,17 @@ func (suite *KeeperTestSuite) TestChanUpgradeCancel() {
upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey)
},
expError: types.ErrInvalidUpgradeSequence,
expError: nil,
},
{
name: "upgrade cannot be cancelled in FLUSHCOMPLETE with invalid error receipt",
malleate: func() {
err := path.EndpointA.SetChannelState(types.FLUSHCOMPLETE)
suite.Require().NoError(err)

errorReceiptProof = nil
},
expError: commitmenttypes.ErrInvalidProof,
},
{
name: "channel not found",
Expand Down Expand Up @@ -1648,6 +1647,38 @@ func (suite *KeeperTestSuite) TestChanUpgradeCancel() {
},
expError: types.ErrInvalidUpgradeSequence,
},
{
name: "error receipt sequence greater than channel upgrade sequence when channel in FLUSHCOMPLETE",
malleate: func() {
err := path.EndpointA.SetChannelState(types.FLUSHCOMPLETE)
suite.Require().NoError(err)
},
expError: types.ErrInvalidUpgradeSequence,
},
{
name: "error receipt sequence smaller than channel upgrade sequence when channel in FLUSHCOMPLETE",
malleate: func() {
channel := path.EndpointA.GetChannel()
channel.State = types.FLUSHCOMPLETE
path.EndpointA.SetChannel(channel)

var ok bool
errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
suite.Require().True(ok)

errorReceipt.Sequence = path.EndpointA.GetChannel().UpgradeSequence - 1

suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt)

suite.coordinator.CommitBlock(suite.chainB)

suite.Require().NoError(path.EndpointA.UpdateClient())

upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey)
},
expError: types.ErrInvalidUpgradeSequence,
},
{
name: "connection not found",
malleate: func() {
Expand Down Expand Up @@ -1732,7 +1763,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeCancel() {
suite.Require().True(ok)

channel = path.EndpointA.GetChannel()
channel.State = types.FLUSHCOMPLETE
channel.State = types.FLUSHING
path.EndpointA.SetChannel(channel)

tc.malleate()
Expand Down
15 changes: 12 additions & 3 deletions modules/core/04-channel/types/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,17 @@ func (u *UpgradeError) GetErrorReceipt() ErrorReceipt {
}
}

// IsUpgradeError returns true if err is of type UpgradeError, otherwise false.
// IsUpgradeError returns true if err is of type UpgradeError or contained
// in the error chain of err and false otherwise.
func IsUpgradeError(err error) bool {
_, ok := err.(*UpgradeError)
return ok
for {
_, ok := err.(*UpgradeError)
if ok {
return true
}

if err = errors.Unwrap(err); err == nil {
return false
}
}
}
31 changes: 31 additions & 0 deletions modules/core/04-channel/types/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package types_test

import (
"errors"
"fmt"

errorsmod "cosmossdk.io/errors"

Expand Down Expand Up @@ -185,13 +186,43 @@ func (suite *TypesTestSuite) TestIsUpgradeError() {
func() {},
true,
},
{
"true with wrapped upgrade err",
func() {
upgradeError := types.NewUpgradeError(1, types.ErrInvalidChannel)
err = errorsmod.Wrap(upgradeError, "wrapped upgrade error")
},
true,
},
{
"true with Errorf wrapped upgrade error",
func() {
err = fmt.Errorf("%w", types.NewUpgradeError(1, types.ErrInvalidChannel))
},
true,
},
{
"true with nested Errorf wrapped upgrade error",
func() {
err = fmt.Errorf("%w", fmt.Errorf("%w", fmt.Errorf("%w", types.NewUpgradeError(1, types.ErrInvalidChannel))))
},
true,
},
{
"false with non upgrade error",
func() {
err = errors.New("error")
},
false,
},
{
"false with wrapped non upgrade error",
func() {
randomErr := errors.New("error")
err = errorsmod.Wrap(randomErr, "wrapped random error")
},
false,
},
{
"false with nil error",
func() {
Expand Down
27 changes: 11 additions & 16 deletions modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,7 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh
upgradeVersion, err := cbs.OnChanUpgradeTry(ctx, msg.PortId, msg.ChannelId, upgrade.Fields.Ordering, upgrade.Fields.ConnectionHops, upgrade.Fields.Version)
if err != nil {
ctx.Logger().Error("channel upgrade try callback failed", "port-id", msg.PortId, "channel-id", msg.ChannelId, "error", err.Error())
return nil, err
return nil, errorsmod.Wrapf(err, "channel upgrade try callback failed for port ID: %s, channel ID: %s", msg.PortId, msg.ChannelId)
}

channel, upgrade = k.ChannelKeeper.WriteUpgradeTryChannel(ctx, msg.PortId, msg.ChannelId, upgrade, upgradeVersion)
Expand Down Expand Up @@ -873,27 +873,26 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh

err = k.ChannelKeeper.ChanUpgradeAck(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
errChanUpgradeFailed := errorsmod.Wrap(err, "channel upgrade ack failed")
ctx.Logger().Error("channel upgrade ack failed", "error", errChanUpgradeFailed)
ctx.Logger().Error("channel upgrade ack failed", "error", errorsmod.Wrap(err, "channel upgrade ack failed"))
if channeltypes.IsUpgradeError(err) {
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)
cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)

// NOTE: a FAILURE result is returned to the client and an error receipt is written to state.
// This signals to the relayer to begin the cancel upgrade handshake subprotocol.
return &channeltypes.MsgChannelUpgradeAckResponse{Result: channeltypes.FAILURE}, nil
}

// NOTE: an error is returned to baseapp and transaction state is not committed.
return nil, errChanUpgradeFailed
return nil, errorsmod.Wrap(err, "channel upgrade ack failed")
}

cacheCtx, writeFn := ctx.CacheContext()
err = cbs.OnChanUpgradeAck(cacheCtx, msg.PortId, msg.ChannelId, msg.CounterpartyUpgrade.Fields.Version)
if err != nil {
ctx.Logger().Error("channel upgrade ack callback failed", "port-id", msg.PortId, "channel-id", msg.ChannelId, "error", err.Error())
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)
cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)

return &channeltypes.MsgChannelUpgradeAckResponse{Result: channeltypes.FAILURE}, nil
}
Expand Down Expand Up @@ -936,8 +935,8 @@ func (k Keeper) ChannelUpgradeConfirm(goCtx context.Context, msg *channeltypes.M
if err != nil {
ctx.Logger().Error("channel upgrade confirm failed", "error", errorsmod.Wrap(err, "channel upgrade confirm failed"))
if channeltypes.IsUpgradeError(err) {
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)
cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)

// NOTE: a FAILURE result is returned to the client and an error receipt is written to state.
// This signals to the relayer to begin the cancel upgrade handshake subprotocol.
Expand All @@ -954,14 +953,14 @@ func (k Keeper) ChannelUpgradeConfirm(goCtx context.Context, msg *channeltypes.M

// Move channel to OPEN state if both chains have finished flushing in-flight packets.
// Counterparty channel state has been verified in ChanUpgradeConfirm.
if msg.CounterpartyChannelState == channeltypes.FLUSHCOMPLETE && !k.ChannelKeeper.HasInflightPackets(ctx, msg.PortId, msg.ChannelId) {
if channel.State == channeltypes.FLUSHCOMPLETE && msg.CounterpartyChannelState == channeltypes.FLUSHCOMPLETE {
upgrade, found := k.ChannelKeeper.GetUpgrade(ctx, msg.PortId, msg.ChannelId)
if !found {
return nil, errorsmod.Wrapf(channeltypes.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", msg.PortId, msg.ChannelId)
}

channel := k.ChannelKeeper.WriteUpgradeOpenChannel(ctx, msg.PortId, msg.ChannelId)
cbs.OnChanUpgradeOpen(ctx, msg.PortId, msg.ChannelId, upgrade.Fields.Ordering, upgrade.Fields.ConnectionHops, upgrade.Fields.Version)
channel := k.ChannelKeeper.WriteUpgradeOpenChannel(ctx, msg.PortId, msg.ChannelId)

ctx.Logger().Info("channel upgrade open succeeded", "port-id", msg.PortId, "channel-id", msg.ChannelId)
keeper.EmitChannelUpgradeOpenEvent(ctx, msg.PortId, msg.ChannelId, channel)
Expand Down Expand Up @@ -1005,7 +1004,6 @@ func (k Keeper) ChannelUpgradeOpen(goCtx context.Context, msg *channeltypes.MsgC
}

cbs.OnChanUpgradeOpen(ctx, msg.PortId, msg.ChannelId, upgrade.Fields.Ordering, upgrade.Fields.ConnectionHops, upgrade.Fields.Version)

channel := k.ChannelKeeper.WriteUpgradeOpenChannel(ctx, msg.PortId, msg.ChannelId)

ctx.Logger().Info("channel upgrade open succeeded", "port-id", msg.PortId, "channel-id", msg.ChannelId)
Expand Down Expand Up @@ -1042,9 +1040,8 @@ func (k Keeper) ChannelUpgradeTimeout(goCtx context.Context, msg *channeltypes.M
return nil, errorsmod.Wrapf(err, "could not timeout upgrade for channel: %s", msg.ChannelId)
}

channel, upgrade := k.ChannelKeeper.WriteUpgradeTimeoutChannel(ctx, msg.PortId, msg.ChannelId)

cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)
channel, upgrade := k.ChannelKeeper.WriteUpgradeTimeoutChannel(ctx, msg.PortId, msg.ChannelId)

ctx.Logger().Info("channel upgrade timeout callback succeeded: portID %s, channelID %s", msg.PortId, msg.ChannelId)
keeper.EmitChannelUpgradeTimeoutEvent(ctx, msg.PortId, msg.ChannelId, channel, upgrade)
Expand Down Expand Up @@ -1089,9 +1086,8 @@ func (k Keeper) ChannelUpgradeCancel(goCtx context.Context, msg *channeltypes.Ms
return nil, errorsmod.Wrapf(channeltypes.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", msg.PortId, msg.ChannelId)
}

k.ChannelKeeper.WriteUpgradeCancelChannel(ctx, msg.PortId, msg.ChannelId, channel.UpgradeSequence)

cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)
k.ChannelKeeper.WriteUpgradeCancelChannel(ctx, msg.PortId, msg.ChannelId, channel.UpgradeSequence)

ctx.Logger().Info("channel upgrade cancel succeeded", "port-id", msg.PortId, "channel-id", msg.ChannelId)

Expand All @@ -1111,9 +1107,8 @@ func (k Keeper) ChannelUpgradeCancel(goCtx context.Context, msg *channeltypes.Ms
return nil, errorsmod.Wrapf(channeltypes.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", msg.PortId, msg.ChannelId)
}

k.ChannelKeeper.WriteUpgradeCancelChannel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt.Sequence)

cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)
k.ChannelKeeper.WriteUpgradeCancelChannel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt.Sequence)

ctx.Logger().Info("channel upgrade cancel succeeded", "port-id", msg.PortId, "channel-id", msg.ChannelId)

Expand Down
Loading

0 comments on commit 41501a2

Please sign in to comment.