diff --git a/chain/cosmos/poll.go b/chain/cosmos/poll.go index f77724d82..8302f756a 100644 --- a/chain/cosmos/poll.go +++ b/chain/cosmos/poll.go @@ -10,7 +10,7 @@ import ( // PollForProposalStatus attempts to find a proposal with matching ID and status. func PollForProposalStatus(ctx context.Context, chain *CosmosChain, startHeight, maxHeight uint64, proposalID string, status string) (ProposalResponse, error) { var zero ProposalResponse - doPoll := func(ctx context.Context, height uint64) (any, error) { + doPoll := func(ctx context.Context, height uint64) (ProposalResponse, error) { p, err := chain.QueryProposal(ctx, proposalID) if err != nil { return zero, err @@ -20,10 +20,6 @@ func PollForProposalStatus(ctx context.Context, chain *CosmosChain, startHeight, } return *p, nil } - bp := test.BlockPoller{CurrentHeight: chain.Height, PollFunc: doPoll} - p, err := bp.DoPoll(ctx, startHeight, maxHeight) - if err != nil { - return zero, err - } - return p.(ProposalResponse), nil + bp := test.BlockPoller[ProposalResponse]{CurrentHeight: chain.Height, PollFunc: doPoll} + return bp.DoPoll(ctx, startHeight, maxHeight) } diff --git a/test/poll_for_state.go b/test/poll_for_state.go index 5155c4af5..cf1f63325 100644 --- a/test/poll_for_state.go +++ b/test/poll_for_state.go @@ -12,23 +12,26 @@ import ( var ErrNotFound = errors.New("not found") -type BlockPoller struct { +type BlockPoller[T any] struct { CurrentHeight func(ctx context.Context) (uint64, error) - PollFunc func(ctx context.Context, height uint64) (any, error) + PollFunc func(ctx context.Context, height uint64) (T, error) } -func (p BlockPoller) DoPoll(ctx context.Context, startHeight, maxHeight uint64) (any, error) { +func (p BlockPoller[T]) DoPoll(ctx context.Context, startHeight, maxHeight uint64) (T, error) { if maxHeight < startHeight { panic("maxHeight must be greater than or equal to startHeight") } - var pollErr error + var ( + pollErr error + zero T + ) cursor := startHeight for cursor <= maxHeight { curHeight, err := p.CurrentHeight(ctx) if err != nil { - return nil, err + return zero, err } if cursor > curHeight { continue @@ -44,7 +47,7 @@ func (p BlockPoller) DoPoll(ctx context.Context, startHeight, maxHeight uint64) return found, nil } - return nil, pollErr + return zero, pollErr } // ChainAcker is a chain that can get its acknowledgements at a specified height @@ -60,7 +63,7 @@ type ChainAcker interface { func PollForAck(ctx context.Context, chain ChainAcker, startHeight, maxHeight uint64, packet ibc.Packet) (ibc.PacketAcknowledgement, error) { var zero ibc.PacketAcknowledgement pollError := &packetPollError{targetPacket: packet} - poll := func(ctx context.Context, height uint64) (any, error) { + poll := func(ctx context.Context, height uint64) (ibc.PacketAcknowledgement, error) { acks, err := chain.Acknowledgements(ctx, height) if err != nil { return zero, err @@ -74,13 +77,13 @@ func PollForAck(ctx context.Context, chain ChainAcker, startHeight, maxHeight ui return zero, ErrNotFound } - poller := BlockPoller{CurrentHeight: chain.Height, PollFunc: poll} + poller := BlockPoller[ibc.PacketAcknowledgement]{CurrentHeight: chain.Height, PollFunc: poll} found, err := poller.DoPoll(ctx, startHeight, maxHeight) if err != nil { pollError.SetErr(err) return zero, pollError } - return found.(ibc.PacketAcknowledgement), nil + return found, nil } // ChainTimeouter is a chain that can get its timeouts at a specified height @@ -94,7 +97,7 @@ type ChainTimeouter interface { func PollForTimeout(ctx context.Context, chain ChainTimeouter, startHeight, maxHeight uint64, packet ibc.Packet) (ibc.PacketTimeout, error) { pollError := &packetPollError{targetPacket: packet} var zero ibc.PacketTimeout - poll := func(ctx context.Context, height uint64) (any, error) { + poll := func(ctx context.Context, height uint64) (ibc.PacketTimeout, error) { timeouts, err := chain.Timeouts(ctx, height) if err != nil { return zero, err @@ -108,13 +111,13 @@ func PollForTimeout(ctx context.Context, chain ChainTimeouter, startHeight, maxH return zero, ErrNotFound } - poller := BlockPoller{CurrentHeight: chain.Height, PollFunc: poll} + poller := BlockPoller[ibc.PacketTimeout]{CurrentHeight: chain.Height, PollFunc: poll} found, err := poller.DoPoll(ctx, startHeight, maxHeight) if err != nil { pollError.SetErr(err) return zero, pollError } - return found.(ibc.PacketTimeout), nil + return found, nil } type packetPollError struct {