Skip to content

Commit

Permalink
add/update multihop timeout logic with basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dshiell committed Jan 13, 2023
1 parent f979196 commit 581eccc
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 175 deletions.
190 changes: 33 additions & 157 deletions modules/core/04-channel/keeper/multihop_timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"

clienttypes "github.com/cosmos/ibc-go/v6/modules/core/02-client/types"
"github.com/cosmos/ibc-go/v6/modules/core/04-channel/types"
Expand Down Expand Up @@ -95,163 +96,35 @@ func (suite *MultihopTestSuite) TestTimeoutPacket() {

// TestTimeoutOnClose tests the call TimeoutOnClose on chainA by closing the corresponding
// channel on chainB after the packet commitment has been created.
/*func (suite *KeeperTestSuite) TestTimeoutOnClose() {
func (suite *MultihopTestSuite) TestTimeoutOnClose() {
var (
path *ibctesting.Path
packet types.Packet
packet *types.Packet
chanCap *capabilitytypes.Capability
nextSeqRecv uint64
ordered bool
err error
)

testCases := []testCase{
{"success: ORDERED", func() {
ordered = true
path.SetChannelOrdered()
suite.coordinator.Setup(path)
timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext())
timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano())
testCases := []timeoutTestCase{
{"success: ORDERED", true, func() {
timeoutHeight := clienttypes.GetSelfHeight(suite.Z().Chain.GetContext())
timeoutTimestamp := uint64(suite.Z().Chain.GetContext().BlockTime().UnixNano())

sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData)
packet, err = suite.A().SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
path.EndpointB.SetChannelClosed()
suite.Z().SetChannelClosed()
// need to update chainA's client representing chainB to prove missing ack
path.EndpointA.UpdateClient()
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, timeoutTimestamp)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.A().UpdateAllClients()
chanCap = suite.A().Chain.GetChannelCapability(suite.A().ChannelConfig.PortID, suite.A().ChannelID)
}, true},
{"success: UNORDERED", func() {
ordered = false
suite.coordinator.Setup(path)
timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext())
sequence, err := path.EndpointA.SendPacket(timeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData)
{"success: UNORDERED", false, func() {
timeoutHeight := clienttypes.GetSelfHeight(suite.Z().Chain.GetContext())
packet, err = suite.A().SendPacket(timeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
path.EndpointB.SetChannelClosed()
suite.Z().SetChannelClosed()
// need to update chainA's client representing chainB to prove missing ack
path.EndpointA.UpdateClient()
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, disabledTimeoutTimestamp)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.A().UpdateAllClients()
chanCap = suite.A().Chain.GetChannelCapability(suite.A().ChannelConfig.PortID, suite.A().ChannelID)
}, true},
{"channel not found", func() {
// use wrong channel naming
suite.coordinator.Setup(path)
packet = types.NewPacket(ibctesting.MockPacketData, 1, ibctesting.InvalidID, ibctesting.InvalidID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp)
}, false},
{"packet dest port ≠ channel counterparty port", func() {
suite.coordinator.Setup(path)
// use wrong port for dest
packet = types.NewPacket(ibctesting.MockPacketData, 1, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, ibctesting.InvalidID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false},
{"packet dest channel ID ≠ channel counterparty channel ID", func() {
suite.coordinator.Setup(path)
// use wrong channel for dest
packet = types.NewPacket(ibctesting.MockPacketData, 1, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, ibctesting.InvalidID, defaultTimeoutHeight, disabledTimeoutTimestamp)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false},
{"connection not found", func() {
// pass channel check
suite.chainA.App.GetIBCKeeper().ChannelKeeper.SetChannel(
suite.chainA.GetContext(),
path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID,
types.NewChannel(types.OPEN, types.ORDERED, types.NewCounterparty(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID), []string{connIDA}, path.EndpointA.ChannelConfig.Version),
)
packet = types.NewPacket(ibctesting.MockPacketData, 1, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp)
// create chancap
suite.chainA.CreateChannelCapability(suite.chainA.GetSimApp().ScopedIBCMockKeeper, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false},
{"packet hasn't been sent ORDERED", func() {
path.SetChannelOrdered()
suite.coordinator.Setup(path)
packet = types.NewPacket(ibctesting.MockPacketData, 1, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, clienttypes.GetSelfHeight(suite.chainB.GetContext()), uint64(suite.chainB.GetContext().BlockTime().UnixNano()))
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false},
{"packet already received ORDERED", func() {
path.SetChannelOrdered()
nextSeqRecv = 2
ordered = true
suite.coordinator.Setup(path)
timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext())
timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano())
sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
path.EndpointB.SetChannelClosed()
// need to update chainA's client representing chainB to prove missing ack
path.EndpointA.UpdateClient()
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, timeoutTimestamp)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false},
{"channel verification failed ORDERED", func() {
ordered = true
path.SetChannelOrdered()
suite.coordinator.Setup(path)
timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext())
timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano())
sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, timeoutTimestamp)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false},
{"next seq receive verification failed ORDERED", func() {
// set ordered to false providing the wrong proof for ORDERED case
ordered = false
path.SetChannelOrdered()
suite.coordinator.Setup(path)
timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext())
timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano())
sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
path.EndpointB.SetChannelClosed()
path.EndpointA.UpdateClient()
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, clienttypes.GetSelfHeight(suite.chainB.GetContext()), uint64(suite.chainB.GetContext().BlockTime().UnixNano()))
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false},
{"packet ack verification failed", func() {
// set ordered to true providing the wrong proof for UNORDERED case
ordered = true
suite.coordinator.Setup(path)
timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext())
sequence, err := path.EndpointA.SendPacket(timeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
path.EndpointB.SetChannelClosed()
path.EndpointA.UpdateClient()
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, disabledTimeoutTimestamp)
chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
}, false},
{"channel capability not found ORDERED", func() {
ordered = true
path.SetChannelOrdered()
suite.coordinator.Setup(path)
timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext())
timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano())
sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
path.EndpointB.SetChannelClosed()
// need to update chainA's client representing chainB to prove missing ack
path.EndpointA.UpdateClient()
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, clienttypes.GetSelfHeight(suite.chainB.GetContext()), uint64(suite.chainB.GetContext().BlockTime().UnixNano()))
chanCap = capabilitytypes.NewCapability(100)
}, false},
}

for i, tc := range testCases {
Expand All @@ -260,24 +133,28 @@ func (suite *MultihopTestSuite) TestTimeoutPacket() {
var proof []byte

suite.SetupTest() // reset
nextSeqRecv = 1 // must be explicitly changed
path = ibctesting.NewPath(suite.chainA, suite.chainB)
if tc.orderedChannel {
suite.chanPath.SetChannelOrdered()
}
nextSeqRecv = 1 // must be explicitly changed
suite.SetupChannels() // setup multihop channels

tc.malleate()

channelKey := host.ChannelKey(packet.GetDestPort(), packet.GetDestChannel())
unorderedPacketKey := host.PacketReceiptKey(packet.GetDestPort(), packet.GetDestChannel(), packet.GetSequence())
orderedPacketKey := host.NextSequenceRecvKey(packet.GetDestPort(), packet.GetDestChannel())
proofClosed := suite.Z().QueryChannelProof()
proofHeight := suite.A().GetClientState().GetLatestHeight()

proofClosed, proofHeight := suite.chainB.QueryProof(channelKey)
if ordered {
proof, _ = suite.chainB.QueryProof(orderedPacketKey)
if tc.orderedChannel {
key := host.NextSequenceRecvKey(packet.GetDestPort(), packet.GetDestChannel())
val := sdk.Uint64ToBigEndian(nextSeqRecv)
proof = suite.Z().QueryMultihopProof(key, val, fmt.Sprintf("ordered packet timeout: %s", packet.String()))
} else {
proof, _ = suite.chainB.QueryProof(unorderedPacketKey)
// proof of absence of packet receipt
key := host.PacketReceiptKey(packet.GetDestPort(), packet.GetDestChannel(), packet.GetSequence())
proof = suite.Z().QueryMultihopProof(key, []byte(nil), fmt.Sprintf("unordered packet timeout: %s", packet.String()))
}

err := suite.chainA.App.GetIBCKeeper().ChannelKeeper.TimeoutOnClose(suite.chainA.GetContext(), chanCap, packet, proof, proofClosed, proofHeight, nextSeqRecv)
err := suite.A().Chain.App.GetIBCKeeper().ChannelKeeper.TimeoutOnClose(suite.A().Chain.GetContext(), chanCap, packet, proof, proofClosed, proofHeight, nextSeqRecv)

if tc.expPass {
suite.Require().NoError(err)
Expand All @@ -287,4 +164,3 @@ func (suite *MultihopTestSuite) TestTimeoutPacket() {
})
}
}
*/
119 changes: 101 additions & 18 deletions modules/core/04-channel/keeper/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,6 @@ func (k Keeper) TimeoutOnClose(
)
}

// TODO: add multihop logic

connectionEnd, found := k.connectionKeeper.GetConnection(ctx, channel.ConnectionHops[0])
if !found {
return sdkerrors.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
Expand All @@ -299,20 +297,65 @@ func (k Keeper) TimeoutOnClose(
return sdkerrors.Wrapf(types.ErrInvalidPacket, "packet commitment bytes are not equal: got (%v), expected (%v)", commitment, packetCommitment)
}

counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()}
var mProof types.MsgMultihopProofs
if len(channel.ConnectionHops) > 1 {
if err := k.cdc.Unmarshal(proofClosed, &mProof); err != nil {
return err
}
}

// connectionHops Z --> A
var counterpartyHops []string
if len(channel.ConnectionHops) > 1 {
var err error
counterpartyHops, err = mProof.GetCounterpartyHops(k.cdc, &connectionEnd)
if err != nil {
return err
}
} else {
counterpartyHops = []string{connectionEnd.GetCounterparty().GetConnectionID()}
}

counterparty := types.NewCounterparty(packet.GetSourcePort(), packet.GetSourceChannel())
expectedChannel := types.NewChannel(
types.CLOSED, channel.Ordering, counterparty, counterpartyHops, channel.Version,
)

// check that the opposing channel end has closed
if err := k.connectionKeeper.VerifyChannelState(
ctx, connectionEnd, proofHeight, proofClosed,
channel.Counterparty.PortId, channel.Counterparty.ChannelId,
expectedChannel,
); err != nil {
return err
if len(channel.ConnectionHops) > 1 {

// expected value bytes
value, err := expectedChannel.Marshal()
if err != nil {
return err
}

// verify multihop proof
consensusState, found := k.clientKeeper.GetClientConsensusState(ctx, connectionEnd.ClientId, proofHeight)
if !found {
return sdkerrors.Wrapf(clienttypes.ErrConsensusStateNotFound,
"consensus state not found for client id: %s", connectionEnd.ClientId)
}

multihopConnectionEnd, err := mProof.GetMultihopConnectionEnd(k.cdc)
if err != nil {
return err
}

key := host.ChannelPath(counterparty.PortId, counterparty.ChannelId)
prefix := multihopConnectionEnd.GetCounterparty().GetPrefix()

if err := mh.VerifyMultihopProof(k.cdc, consensusState, channel.ConnectionHops, proofClosed, prefix, key, value); err != nil {
return err
}
} else {
// check that the opposing channel end has closed
if err := k.connectionKeeper.VerifyChannelState(
ctx, connectionEnd, proofHeight, proofClosed,
channel.Counterparty.PortId, channel.Counterparty.ChannelId,
expectedChannel,
); err != nil {
return err
}
}

var err error
Expand All @@ -324,15 +367,55 @@ func (k Keeper) TimeoutOnClose(
}

// check that the recv sequence is as claimed
err = k.connectionKeeper.VerifyNextSequenceRecv(
ctx, connectionEnd, proofHeight, proof,
packet.GetDestPort(), packet.GetDestChannel(), nextSequenceRecv,
)
if len(channel.ConnectionHops) > 1 {
// verify multihop proof
consensusState, found := k.clientKeeper.GetClientConsensusState(ctx, connectionEnd.ClientId, proofHeight)
if !found {
err = sdkerrors.Wrapf(clienttypes.ErrConsensusStateNotFound,
"consensus state not found for client id: %s", connectionEnd.ClientId)
}
key := host.NextSequenceRecvPath(packet.GetSourcePort(), packet.GetSourceChannel())
prefix := connectionEnd.GetCounterparty().GetPrefix()
val := sdk.Uint64ToBigEndian(nextSequenceRecv)
err = mh.VerifyMultihopProof(k.cdc, consensusState, channel.ConnectionHops, proof, prefix, key, val)
} else {
err = k.connectionKeeper.VerifyNextSequenceRecv(
ctx, connectionEnd, proofHeight, proof,
packet.GetDestPort(), packet.GetDestChannel(), nextSequenceRecv,
)
}
case types.UNORDERED:
err = k.connectionKeeper.VerifyPacketReceiptAbsence(
ctx, connectionEnd, proofHeight, proof,
packet.GetDestPort(), packet.GetDestChannel(), packet.GetSequence(),
)
if len(channel.ConnectionHops) > 1 {
// verify multihop proof
consensusState, found := k.clientKeeper.GetClientConsensusState(ctx, connectionEnd.ClientId, proofHeight)
if !found {
err = sdkerrors.Wrapf(clienttypes.ErrConsensusStateNotFound,
"consensus state not found for client id: %s", connectionEnd.ClientId)
}
key := host.PacketReceiptPath(
packet.GetSourcePort(),
packet.GetSourceChannel(),
packet.GetSequence(),
)
prefix := connectionEnd.GetCounterparty().GetPrefix()
var value []byte = nil
clientState, found := k.clientKeeper.GetClientState(ctx, connectionEnd.Counterparty.ClientId)

////////////////////////////////////////////////////////////////////////////////////////////////
// NOTE: If the clientState is found and type is virtual the do a timeout inclusion check. This
// is a hack to work around the fact that virtual chains may not support non-inclusion proofs.
////////////////////////////////////////////////////////////////////////////////////////////////
if found && clientState.ClientType() == "virtual" {
value = commitment
}

err = mh.VerifyMultihopProof(k.cdc, consensusState, channel.ConnectionHops, proof, prefix, key, value)
} else {
err = k.connectionKeeper.VerifyPacketReceiptAbsence(
ctx, connectionEnd, proofHeight, proof,
packet.GetDestPort(), packet.GetDestChannel(), packet.GetSequence(),
)
}
default:
panic(sdkerrors.Wrapf(types.ErrInvalidChannelOrdering, channel.Ordering.String()))
}
Expand Down

0 comments on commit 581eccc

Please sign in to comment.