Skip to content

Commit

Permalink
feat: Add cosmos message poller (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidNix authored and jtieri committed Jan 6, 2023
1 parent c7de889 commit bc0a203
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 20 deletions.
43 changes: 37 additions & 6 deletions chain/cosmos/poll.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@ package cosmos

import (
"context"
"errors"
"fmt"

codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/strangelove-ventures/ibctest/v3/test"
)

// PollForProposalStatus attempts to find a proposal with matching ID and status.
func PollForProposalStatus(ctx context.Context, chain *CosmosChain, startHeight, maxHeight uint64, proposalID string, status string) (ProposalResponse, error) {
var zero ProposalResponse
doPoll := func(ctx context.Context, height uint64) (any, error) {
doPoll := func(ctx context.Context, height uint64) (ProposalResponse, error) {
p, err := chain.QueryProposal(ctx, proposalID)
if err != nil {
return zero, err
Expand All @@ -20,10 +22,39 @@ func PollForProposalStatus(ctx context.Context, chain *CosmosChain, startHeight,
}
return *p, nil
}
bp := test.BlockPoller{CurrentHeight: chain.Height, PollFunc: doPoll}
p, err := bp.DoPoll(ctx, startHeight, maxHeight)
if err != nil {
return zero, err
bp := test.BlockPoller[ProposalResponse]{CurrentHeight: chain.Height, PollFunc: doPoll}
return bp.DoPoll(ctx, startHeight, maxHeight)
}

// PollForMessage searches every transaction for a message. Must pass a coded registry capable of decoding the cosmos transaction.
// fn is optional. Return true from the fn to stop polling and return the found message. If fn is nil, returns the first message to match type T.
func PollForMessage[T any](ctx context.Context, chain *CosmosChain, registry codectypes.InterfaceRegistry, startHeight, maxHeight uint64, fn func(found T) bool) (T, error) {
var zero T
if fn == nil {
fn = func(T) bool { return true }
}
doPoll := func(ctx context.Context, height uint64) (T, error) {
h := int64(height)
block, err := chain.getFullNode().Client.Block(ctx, &h)
if err != nil {
return zero, err
}
for _, tx := range block.Block.Txs {
sdkTx, err := decodeTX(registry, tx)
if err != nil {
return zero, err
}
for _, msg := range sdkTx.GetMsgs() {
if found, ok := msg.(T); ok {
if fn(found) {
return found, nil
}
}
}
}
return zero, errors.New("not found")
}
return p.(ProposalResponse), nil

bp := test.BlockPoller[T]{CurrentHeight: chain.Height, PollFunc: doPoll}
return bp.DoPoll(ctx, startHeight, maxHeight)
}
101 changes: 101 additions & 0 deletions examples/cosmos/light_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package cosmos_test

import (
"context"
"testing"

clienttypes "github.com/cosmos/ibc-go/v6/modules/core/02-client/types"
"github.com/strangelove-ventures/ibctest/v6"
"github.com/strangelove-ventures/ibctest/v6/chain/cosmos"
"github.com/strangelove-ventures/ibctest/v6/ibc"
"github.com/strangelove-ventures/ibctest/v6/testreporter"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
)

func TestUpdateLightClients(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}

t.Parallel()

ctx := context.Background()

// Chains
cf := ibctest.NewBuiltinChainFactory(zaptest.NewLogger(t), []*ibctest.ChainSpec{
{Name: "gaia", Version: gaiaVersion},
{Name: "osmosis", Version: osmosisVersion},
})

chains, err := cf.Chains(t.Name())
require.NoError(t, err)
gaia, osmosis := chains[0], chains[1]

// Relayer
client, network := ibctest.DockerSetup(t)
r := ibctest.NewBuiltinRelayerFactory(ibc.CosmosRly, zaptest.NewLogger(t)).Build(
t, client, network)

ic := ibctest.NewInterchain().
AddChain(gaia).
AddChain(osmosis).
AddRelayer(r, "relayer").
AddLink(ibctest.InterchainLink{
Chain1: gaia,
Chain2: osmosis,
Relayer: r,
Path: "client-test-path",
})

// Build interchain
rep := testreporter.NewNopReporter()
eRep := rep.RelayerExecReporter(t)
require.NoError(t, ic.Build(ctx, eRep, ibctest.InterchainBuildOptions{
TestName: t.Name(),
Client: client,
NetworkID: network,
}))
t.Cleanup(func() {
_ = ic.Close()
})

require.NoError(t, r.StartRelayer(ctx, eRep))
t.Cleanup(func() {
_ = r.StopRelayer(ctx, eRep)
})

// Create and Fund User Wallets
fundAmount := int64(10_000_000)
users := ibctest.GetAndFundTestUsers(t, ctx, "default", fundAmount, gaia, osmosis)
gaiaUser, osmoUser := users[0], users[1]

// Get Channel ID
gaiaChannelInfo, err := r.GetChannels(ctx, eRep, gaia.Config().ChainID)
require.NoError(t, err)
chanID := gaiaChannelInfo[0].ChannelID

height, err := osmosis.Height(ctx)
require.NoError(t, err)

amountToSend := int64(553255) // Unique amount to make log searching easier.
dstAddress := osmoUser.Bech32Address(osmosis.Config().Bech32Prefix)
tx, err := gaia.SendIBCTransfer(ctx, chanID, gaiaUser.KeyName, ibc.WalletAmount{
Address: dstAddress,
Denom: gaia.Config().Denom,
Amount: amountToSend,
},
nil,
)
require.NoError(t, err)
require.NoError(t, tx.Validate())

chain := osmosis.(*cosmos.CosmosChain)
reg := chain.Config().EncodingConfig.InterfaceRegistry
msg, err := cosmos.PollForMessage[*clienttypes.MsgUpdateClient](ctx, chain, reg, height, height+10, nil)
require.NoError(t, err)

require.Equal(t, "07-tendermint-0", msg.ClientId)
require.NotEmpty(t, msg.Signer)
// TODO: Assert header information
}
2 changes: 1 addition & 1 deletion examples/cosmos/state_sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

func TestCosmosHubStateSync(t *testing.T) {
CosmosChainStateSyncTest(t, "gaia", "v7.0.3")
CosmosChainStateSyncTest(t, "gaia", gaiaVersion)
}

const stateSyncSnapshotInterval = 10
Expand Down
6 changes: 6 additions & 0 deletions examples/cosmos/versions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package cosmos_test

const (
gaiaVersion = "v7.1.0"
osmosisVersion = "v12.2.0"
)
4 changes: 3 additions & 1 deletion examples/ibc/learn_ibc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ func TestLearn(t *testing.T) {
t.Skip("skipping in short mode")
}

t.Parallel()

ctx := context.Background()

// Chain Factory
Expand Down Expand Up @@ -73,7 +75,7 @@ func TestLearn(t *testing.T) {

// Create and Fund User Wallets
fundAmount := int64(10_000_000)
users := ibctest.GetAndFundTestUsers(t, ctx, "default", int64(fundAmount), gaia, osmosis)
users := ibctest.GetAndFundTestUsers(t, ctx, "default", fundAmount, gaia, osmosis)
gaiaUser := users[0]
osmosisUser := users[1]

Expand Down
27 changes: 15 additions & 12 deletions test/poll_for_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,26 @@ import (

var ErrNotFound = errors.New("not found")

type BlockPoller struct {
type BlockPoller[T any] struct {
CurrentHeight func(ctx context.Context) (uint64, error)
PollFunc func(ctx context.Context, height uint64) (any, error)
PollFunc func(ctx context.Context, height uint64) (T, error)
}

func (p BlockPoller) DoPoll(ctx context.Context, startHeight, maxHeight uint64) (any, error) {
func (p BlockPoller[T]) DoPoll(ctx context.Context, startHeight, maxHeight uint64) (T, error) {
if maxHeight < startHeight {
panic("maxHeight must be greater than or equal to startHeight")
}

var pollErr error
var (
pollErr error
zero T
)

cursor := startHeight
for cursor <= maxHeight {
curHeight, err := p.CurrentHeight(ctx)
if err != nil {
return nil, err
return zero, err
}
if cursor > curHeight {
continue
Expand All @@ -44,7 +47,7 @@ func (p BlockPoller) DoPoll(ctx context.Context, startHeight, maxHeight uint64)

return found, nil
}
return nil, pollErr
return zero, pollErr
}

// ChainAcker is a chain that can get its acknowledgements at a specified height
Expand All @@ -60,7 +63,7 @@ type ChainAcker interface {
func PollForAck(ctx context.Context, chain ChainAcker, startHeight, maxHeight uint64, packet ibc.Packet) (ibc.PacketAcknowledgement, error) {
var zero ibc.PacketAcknowledgement
pollError := &packetPollError{targetPacket: packet}
poll := func(ctx context.Context, height uint64) (any, error) {
poll := func(ctx context.Context, height uint64) (ibc.PacketAcknowledgement, error) {
acks, err := chain.Acknowledgements(ctx, height)
if err != nil {
return zero, err
Expand All @@ -74,13 +77,13 @@ func PollForAck(ctx context.Context, chain ChainAcker, startHeight, maxHeight ui
return zero, ErrNotFound
}

poller := BlockPoller{CurrentHeight: chain.Height, PollFunc: poll}
poller := BlockPoller[ibc.PacketAcknowledgement]{CurrentHeight: chain.Height, PollFunc: poll}
found, err := poller.DoPoll(ctx, startHeight, maxHeight)
if err != nil {
pollError.SetErr(err)
return zero, pollError
}
return found.(ibc.PacketAcknowledgement), nil
return found, nil
}

// ChainTimeouter is a chain that can get its timeouts at a specified height
Expand All @@ -94,7 +97,7 @@ type ChainTimeouter interface {
func PollForTimeout(ctx context.Context, chain ChainTimeouter, startHeight, maxHeight uint64, packet ibc.Packet) (ibc.PacketTimeout, error) {
pollError := &packetPollError{targetPacket: packet}
var zero ibc.PacketTimeout
poll := func(ctx context.Context, height uint64) (any, error) {
poll := func(ctx context.Context, height uint64) (ibc.PacketTimeout, error) {
timeouts, err := chain.Timeouts(ctx, height)
if err != nil {
return zero, err
Expand All @@ -108,13 +111,13 @@ func PollForTimeout(ctx context.Context, chain ChainTimeouter, startHeight, maxH
return zero, ErrNotFound
}

poller := BlockPoller{CurrentHeight: chain.Height, PollFunc: poll}
poller := BlockPoller[ibc.PacketTimeout]{CurrentHeight: chain.Height, PollFunc: poll}
found, err := poller.DoPoll(ctx, startHeight, maxHeight)
if err != nil {
pollError.SetErr(err)
return zero, pollError
}
return found.(ibc.PacketTimeout), nil
return found, nil
}

type packetPollError struct {
Expand Down

0 comments on commit bc0a203

Please sign in to comment.