From 2d210e234ac9445ff13408efd4b56e569c1e725b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?colin=20axn=C3=A9r?= <25233464+colin-axner@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:18:03 +0100 Subject: [PATCH] refactor!: remove GetState() on connection interface (#5769) * rm: GetState() on connection interface * lint * lint --- .../interchain_accounts/upgrades_test.go | 7 +-- e2e/tests/upgrades/upgrade_test.go | 3 +- e2e/testsuite/sanitize/messages.go | 3 +- .../core/03-connection/types/connection.go | 5 -- modules/core/04-channel/keeper/handshake.go | 35 ++++--------- modules/core/04-channel/keeper/packet.go | 14 ++--- modules/core/04-channel/keeper/upgrade.go | 51 ++++++++----------- modules/core/exported/connection.go | 1 - 8 files changed, 40 insertions(+), 79 deletions(-) diff --git a/e2e/tests/interchain_accounts/upgrades_test.go b/e2e/tests/interchain_accounts/upgrades_test.go index 46342dcdcbd..a14641b8336 100644 --- a/e2e/tests/interchain_accounts/upgrades_test.go +++ b/e2e/tests/interchain_accounts/upgrades_test.go @@ -7,16 +7,17 @@ import ( "testing" "time" - sdkmath "cosmossdk.io/math" + "github.com/cosmos/gogoproto/proto" "github.com/strangelove-ventures/interchaintest/v8" "github.com/strangelove-ventures/interchaintest/v8/ibc" test "github.com/strangelove-ventures/interchaintest/v8/testutil" testifysuite "github.com/stretchr/testify/suite" + sdkmath "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" - "github.com/cosmos/gogoproto/proto" "github.com/cosmos/ibc-go/e2e/testsuite" "github.com/cosmos/ibc-go/e2e/testvalues" @@ -162,7 +163,7 @@ func (s *InterchainAccountsChannelUpgradesTestSuite) TestChannelUpgrade_ICAChann Memo: "e2e", } - timeout := uint64(1) + timeout := uint64(1) msgSendTx := controllertypes.NewMsgSendTx(controllerAddress, ibctesting.FirstConnectionID, timeout, packetData) resp := s.BroadcastMessages( diff --git a/e2e/tests/upgrades/upgrade_test.go b/e2e/tests/upgrades/upgrade_test.go index 4e00bc5e07b..6dd0219efc3 100644 --- a/e2e/tests/upgrades/upgrade_test.go +++ b/e2e/tests/upgrades/upgrade_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types" - "github.com/cosmos/gogoproto/proto" interchaintest "github.com/strangelove-ventures/interchaintest/v8" "github.com/strangelove-ventures/interchaintest/v8/chain/cosmos" @@ -29,6 +27,7 @@ import ( "github.com/cosmos/ibc-go/e2e/testsuite" "github.com/cosmos/ibc-go/e2e/testvalues" feetypes "github.com/cosmos/ibc-go/v8/modules/apps/29-fee/types" + transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types" v7migrations "github.com/cosmos/ibc-go/v8/modules/core/02-client/migrations/v7" clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types" connectiontypes "github.com/cosmos/ibc-go/v8/modules/core/03-connection/types" diff --git a/e2e/testsuite/sanitize/messages.go b/e2e/testsuite/sanitize/messages.go index ac475314f4e..7cf901d6913 100644 --- a/e2e/testsuite/sanitize/messages.go +++ b/e2e/testsuite/sanitize/messages.go @@ -5,10 +5,9 @@ import ( govtypesv1 "github.com/cosmos/cosmos-sdk/x/gov/types/v1" grouptypes "github.com/cosmos/cosmos-sdk/x/group" + "github.com/cosmos/ibc-go/e2e/semverutil" icacontrollertypes "github.com/cosmos/ibc-go/v8/modules/apps/27-interchain-accounts/controller/types" channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" - - "github.com/cosmos/ibc-go/e2e/semverutil" ) var ( diff --git a/modules/core/03-connection/types/connection.go b/modules/core/03-connection/types/connection.go index d9c1a554437..e341088d7e7 100644 --- a/modules/core/03-connection/types/connection.go +++ b/modules/core/03-connection/types/connection.go @@ -22,11 +22,6 @@ func NewConnectionEnd(state State, clientID string, counterparty Counterparty, v } } -// GetState implements the Connection interface -func (c ConnectionEnd) GetState() int32 { - return int32(c.State) -} - // GetClientID implements the Connection interface func (c ConnectionEnd) GetClientID() string { return c.ClientId diff --git a/modules/core/04-channel/keeper/handshake.go b/modules/core/04-channel/keeper/handshake.go index fe647727676..eebefb94717 100644 --- a/modules/core/04-channel/keeper/handshake.go +++ b/modules/core/04-channel/keeper/handshake.go @@ -131,11 +131,8 @@ func (k Keeper) ChanOpenTry( return "", nil, errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, connectionHops[0]) } - if connectionEnd.GetState() != int32(connectiontypes.OPEN) { - return "", nil, errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(), - ) + if connectionEnd.State != connectiontypes.OPEN { + return "", nil, errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State) } getVersions := connectionEnd.GetVersions() @@ -242,11 +239,8 @@ func (k Keeper) ChanOpenAck( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connectionEnd.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(), - ) + if connectionEnd.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State) } counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()} @@ -321,11 +315,8 @@ func (k Keeper) ChanOpenConfirm( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connectionEnd.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(), - ) + if connectionEnd.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State) } counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()} @@ -405,11 +396,8 @@ func (k Keeper) ChanCloseInit( return errorsmod.Wrapf(clienttypes.ErrClientNotActive, "client (%s) status is %s", connectionEnd.ClientId, status) } - if connectionEnd.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(), - ) + if connectionEnd.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State) } k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", channel.State.String(), "new-state", types.CLOSED.String()) @@ -453,11 +441,8 @@ func (k Keeper) ChanCloseConfirm( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connectionEnd.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(), - ) + if connectionEnd.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State) } counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()} diff --git a/modules/core/04-channel/keeper/packet.go b/modules/core/04-channel/keeper/packet.go index f6a68c1e123..794590cfad9 100644 --- a/modules/core/04-channel/keeper/packet.go +++ b/modules/core/04-channel/keeper/packet.go @@ -167,11 +167,8 @@ func (k Keeper) RecvPacket( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connectionEnd.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(), - ) + if connectionEnd.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State) } // check if packet timed out by comparing it with the latest height of the chain @@ -400,11 +397,8 @@ func (k Keeper) AcknowledgePacket( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connectionEnd.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(), - ) + if connectionEnd.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State) } commitment := k.GetPacketCommitment(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 1dae7f7a862..8e3a9e5ede0 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -98,10 +98,8 @@ func (k Keeper) ChanUpgradeTry( return types.Channel{}, types.Upgrade{}, errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connection.GetState() != int32(connectiontypes.OPEN) { - return types.Channel{}, types.Upgrade{}, errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(), - ) + if connection.State != connectiontypes.OPEN { + return types.Channel{}, types.Upgrade{}, errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State) } // construct expected counterparty channel from information in state @@ -276,8 +274,8 @@ func (k Keeper) ChanUpgradeAck( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connection.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String()) + if connection.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State) } counterpartyHops := []string{connection.GetCounterparty().GetConnectionID()} @@ -412,8 +410,8 @@ func (k Keeper) ChanUpgradeConfirm( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connection.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String()) + if connection.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State) } counterpartyHops := []string{connection.GetCounterparty().GetConnectionID()} @@ -507,8 +505,8 @@ func (k Keeper) ChanUpgradeOpen( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connection.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String()) + if connection.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State) } var counterpartyChannel types.Channel @@ -524,8 +522,8 @@ func (k Keeper) ChanUpgradeOpen( return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, upgrade.Fields.ConnectionHops[0]) } - if upgradeConnection.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(upgradeConnection.GetState()).String()) + if upgradeConnection.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", upgradeConnection.State) } // The counterparty upgrade sequence must be greater than or equal to @@ -675,11 +673,8 @@ func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, err return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connection.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(), - ) + if connection.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State) } if err := k.connectionKeeper.VerifyChannelUpgradeError( @@ -746,11 +741,8 @@ func (k Keeper) ChanUpgradeTimeout( ) } - if connection.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(), - ) + if connection.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State) } proofTimestamp, err := k.connectionKeeper.GetTimestampAtHeight(ctx, connection, proofHeight) @@ -853,8 +845,8 @@ func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string, upgrade return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0]) } - if connection.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String()) + if connection.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State) } channel.State = types.FLUSHING @@ -896,10 +888,10 @@ func (k Keeper) checkForUpgradeCompatibility(ctx sdk.Context, upgradeFields, cou return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, upgradeFields.ConnectionHops[0]) } - if connection.GetState() != int32(connectiontypes.OPEN) { + if connection.State != connectiontypes.OPEN { // NOTE: this error is expected to be unreachable as the proposed upgrade connectionID should have been // validated in the upgrade INIT and TRY handlers - return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "expected proposed connection to be OPEN (got %s)", connectiontypes.State(connection.GetState()).String()) + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "expected proposed connection to be OPEN (got %s)", connection.State) } // connectionHops can change in a channelUpgrade, however both sides must still be each other's counterparty. @@ -930,11 +922,8 @@ func (k Keeper) validateSelfUpgradeFields(ctx sdk.Context, proposedUpgrade types return errorsmod.Wrapf(connectiontypes.ErrConnectionNotFound, "failed to retrieve connection: %s", connectionID) } - if connection.GetState() != int32(connectiontypes.OPEN) { - return errorsmod.Wrapf( - connectiontypes.ErrInvalidConnectionState, - "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(), - ) + if connection.State != connectiontypes.OPEN { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State) } getVersions := connection.GetVersions() diff --git a/modules/core/exported/connection.go b/modules/core/exported/connection.go index a8341a255a9..f4353ffc278 100644 --- a/modules/core/exported/connection.go +++ b/modules/core/exported/connection.go @@ -6,7 +6,6 @@ const LocalhostConnectionID string = "connection-localhost" // ConnectionI describes the required methods for a connection. type ConnectionI interface { GetClientID() string - GetState() int32 GetCounterparty() CounterpartyConnectionI GetDelayPeriod() uint64 ValidateBasic() error