Skip to content

Commit

Permalink
Refactor poll to use generics
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidNix committed Oct 14, 2022
1 parent cdb1669 commit f6d3fd9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
10 changes: 3 additions & 7 deletions chain/cosmos/poll.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
// PollForProposalStatus attempts to find a proposal with matching ID and status.
func PollForProposalStatus(ctx context.Context, chain *CosmosChain, startHeight, maxHeight uint64, proposalID string, status string) (ProposalResponse, error) {
var zero ProposalResponse
doPoll := func(ctx context.Context, height uint64) (any, error) {
doPoll := func(ctx context.Context, height uint64) (ProposalResponse, error) {
p, err := chain.QueryProposal(ctx, proposalID)
if err != nil {
return zero, err
Expand All @@ -20,10 +20,6 @@ func PollForProposalStatus(ctx context.Context, chain *CosmosChain, startHeight,
}
return *p, nil
}
bp := test.BlockPoller{CurrentHeight: chain.Height, PollFunc: doPoll}
p, err := bp.DoPoll(ctx, startHeight, maxHeight)
if err != nil {
return zero, err
}
return p.(ProposalResponse), nil
bp := test.BlockPoller[ProposalResponse]{CurrentHeight: chain.Height, PollFunc: doPoll}
return bp.DoPoll(ctx, startHeight, maxHeight)
}
27 changes: 15 additions & 12 deletions test/poll_for_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,26 @@ import (

var ErrNotFound = errors.New("not found")

type BlockPoller struct {
type BlockPoller[T any] struct {
CurrentHeight func(ctx context.Context) (uint64, error)
PollFunc func(ctx context.Context, height uint64) (any, error)
PollFunc func(ctx context.Context, height uint64) (T, error)
}

func (p BlockPoller) DoPoll(ctx context.Context, startHeight, maxHeight uint64) (any, error) {
func (p BlockPoller[T]) DoPoll(ctx context.Context, startHeight, maxHeight uint64) (T, error) {
if maxHeight < startHeight {
panic("maxHeight must be greater than or equal to startHeight")
}

var pollErr error
var (
pollErr error
zero T
)

cursor := startHeight
for cursor <= maxHeight {
curHeight, err := p.CurrentHeight(ctx)
if err != nil {
return nil, err
return zero, err
}
if cursor > curHeight {
continue
Expand All @@ -44,7 +47,7 @@ func (p BlockPoller) DoPoll(ctx context.Context, startHeight, maxHeight uint64)

return found, nil
}
return nil, pollErr
return zero, pollErr
}

// ChainAcker is a chain that can get its acknowledgements at a specified height
Expand All @@ -60,7 +63,7 @@ type ChainAcker interface {
func PollForAck(ctx context.Context, chain ChainAcker, startHeight, maxHeight uint64, packet ibc.Packet) (ibc.PacketAcknowledgement, error) {
var zero ibc.PacketAcknowledgement
pollError := &packetPollError{targetPacket: packet}
poll := func(ctx context.Context, height uint64) (any, error) {
poll := func(ctx context.Context, height uint64) (ibc.PacketAcknowledgement, error) {
acks, err := chain.Acknowledgements(ctx, height)
if err != nil {
return zero, err
Expand All @@ -74,13 +77,13 @@ func PollForAck(ctx context.Context, chain ChainAcker, startHeight, maxHeight ui
return zero, ErrNotFound
}

poller := BlockPoller{CurrentHeight: chain.Height, PollFunc: poll}
poller := BlockPoller[ibc.PacketAcknowledgement]{CurrentHeight: chain.Height, PollFunc: poll}
found, err := poller.DoPoll(ctx, startHeight, maxHeight)
if err != nil {
pollError.SetErr(err)
return zero, pollError
}
return found.(ibc.PacketAcknowledgement), nil
return found, nil
}

// ChainTimeouter is a chain that can get its timeouts at a specified height
Expand All @@ -94,7 +97,7 @@ type ChainTimeouter interface {
func PollForTimeout(ctx context.Context, chain ChainTimeouter, startHeight, maxHeight uint64, packet ibc.Packet) (ibc.PacketTimeout, error) {
pollError := &packetPollError{targetPacket: packet}
var zero ibc.PacketTimeout
poll := func(ctx context.Context, height uint64) (any, error) {
poll := func(ctx context.Context, height uint64) (ibc.PacketTimeout, error) {
timeouts, err := chain.Timeouts(ctx, height)
if err != nil {
return zero, err
Expand All @@ -108,13 +111,13 @@ func PollForTimeout(ctx context.Context, chain ChainTimeouter, startHeight, maxH
return zero, ErrNotFound
}

poller := BlockPoller{CurrentHeight: chain.Height, PollFunc: poll}
poller := BlockPoller[ibc.PacketTimeout]{CurrentHeight: chain.Height, PollFunc: poll}
found, err := poller.DoPoll(ctx, startHeight, maxHeight)
if err != nil {
pollError.SetErr(err)
return zero, pollError
}
return found.(ibc.PacketTimeout), nil
return found, nil
}

type packetPollError struct {
Expand Down

0 comments on commit f6d3fd9

Please sign in to comment.