diff --git a/modules/core/04-channel/keeper/timeout.go b/modules/core/04-channel/keeper/timeout.go index b954ed4e8ad..2cf23767c73 100644 --- a/modules/core/04-channel/keeper/timeout.go +++ b/modules/core/04-channel/keeper/timeout.go @@ -153,6 +153,11 @@ func (k Keeper) TimeoutExecuted( k.deletePacketCommitment(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) + if channel.FlushStatus == types.FLUSHING && !k.hasInflightPackets(ctx, packet.GetSourcePort(), packet.GetSourceChannel()) { + channel.FlushStatus = types.FLUSHCOMPLETE + k.SetChannel(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), channel) + } + if channel.Ordering == types.ORDERED { channel.Close() k.SetChannel(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), channel) diff --git a/modules/core/04-channel/keeper/timeout_test.go b/modules/core/04-channel/keeper/timeout_test.go index e2f8ebc9305..1d3b33b0936 100644 --- a/modules/core/04-channel/keeper/timeout_test.go +++ b/modules/core/04-channel/keeper/timeout_test.go @@ -13,6 +13,7 @@ import ( host "github.com/cosmos/ibc-go/v7/modules/core/24-host" "github.com/cosmos/ibc-go/v7/modules/core/exported" ibctesting "github.com/cosmos/ibc-go/v7/testing" + "github.com/cosmos/ibc-go/v7/testing/mock" ) // TestTimeoutPacket test the TimeoutPacket call on chainA by ensuring the timeout has passed @@ -246,8 +247,9 @@ func (suite *KeeperTestSuite) TestTimeoutPacket() { } } -// TestTimeoutExectued verifies that packet commitments are deleted on chainA after the -// channel capabilities are verified. +// TestTimeoutExecuted verifies that packet commitments are deleted on chainA after the +// channel capabilities are verified. In addition, the test verifies that the channel state +// after a timeout is updated accordingly. func (suite *KeeperTestSuite) TestTimeoutExecuted() { var ( path *ibctesting.Path @@ -255,38 +257,111 @@ func (suite *KeeperTestSuite) TestTimeoutExecuted() { chanCap *capabilitytypes.Capability ) - testCases := []testCase{ - {"success ORDERED", func() { - path.SetChannelOrdered() - suite.coordinator.Setup(path) + testCases := []struct { + msg string + malleate func() + expResult func(packetCommitment []byte, err error) + }{ + { + "success ORDERED", + func() { + path.SetChannelOrdered() + suite.coordinator.Setup(path) + + timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext()) + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano()) + + sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData) + suite.Require().NoError(err) - timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext()) - timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano()) + packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, timeoutTimestamp) + chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + }, + func(packetCommitment []byte, err error) { + suite.Require().NoError(err) + suite.Require().Nil(packetCommitment) + + // Check channel has been closed and flush status is set to NOTINFLUSH + channel := path.EndpointA.GetChannel() + suite.Require().Equal(channel.State, types.CLOSED) + suite.Require().Equal(channel.FlushStatus, types.NOTINFLUSH) + }, + }, + { + "success UNORDERED channel in FLUSHING state", + func() { + suite.coordinator.Setup(path) + + timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext()) + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano()) + + sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData) + suite.Require().NoError(err) - sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData) - suite.Require().NoError(err) + packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, timeoutTimestamp) + chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) - packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, timeoutTimestamp) - chanCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) - }, true}, - {"channel not found", func() { - // use wrong channel naming - suite.coordinator.Setup(path) - packet = types.NewPacket(ibctesting.MockPacketData, 1, ibctesting.InvalidID, ibctesting.InvalidID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp) - }, false}, - {"incorrect capability ORDERED", func() { - path.SetChannelOrdered() - suite.coordinator.Setup(path) + // Move channel to FLUSHING state + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion - timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext()) - timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano()) + err = path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) - sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData) - suite.Require().NoError(err) + err = path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) - packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, timeoutTimestamp) - chanCap = capabilitytypes.NewCapability(100) - }, false}, + err = path.EndpointA.ChanUpgradeAck() + suite.Require().NoError(err) + }, + func(packetCommitment []byte, err error) { + suite.Require().NoError(err) + suite.Require().Nil(packetCommitment) + + // Check flush status has been set to FLUSHCOMPLETE + channel := path.EndpointA.GetChannel() + suite.Require().Equal(channel.State, types.ACKUPGRADE) + suite.Require().Equal(channel.FlushStatus, types.FLUSHCOMPLETE) + }, + }, + { + "channel not found", + func() { + // use wrong channel naming + suite.coordinator.Setup(path) + packet = types.NewPacket(ibctesting.MockPacketData, 1, ibctesting.InvalidID, ibctesting.InvalidID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp) + }, + func(packetCommitment []byte, err error) { + suite.Require().Error(err) + suite.Require().ErrorIs(err, types.ErrChannelNotFound) + + // packet never sent. + suite.Require().Nil(packetCommitment) + }, + }, + { + "incorrect capability ORDERED", + func() { + path.SetChannelOrdered() + suite.coordinator.Setup(path) + + timeoutHeight := clienttypes.GetSelfHeight(suite.chainB.GetContext()) + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().UnixNano()) + + sequence, err := path.EndpointA.SendPacket(timeoutHeight, timeoutTimestamp, ibctesting.MockPacketData) + suite.Require().NoError(err) + + packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, timeoutTimestamp) + chanCap = capabilitytypes.NewCapability(100) + }, + func(packetCommitment []byte, err error) { + suite.Require().Error(err) + suite.Require().ErrorIs(err, types.ErrChannelCapabilityNotFound) + + // packet sent, never deleted. + suite.Require().NotNil(packetCommitment) + }, + }, } for i, tc := range testCases { @@ -300,12 +375,7 @@ func (suite *KeeperTestSuite) TestTimeoutExecuted() { err := suite.chainA.App.GetIBCKeeper().ChannelKeeper.TimeoutExecuted(suite.chainA.GetContext(), chanCap, packet) pc := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetPacketCommitment(suite.chainA.GetContext(), packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) - if tc.expPass { - suite.NoError(err) - suite.Nil(pc) - } else { - suite.Error(err) - } + tc.expResult(pc, err) }) } }