From d892b2d5a5bb313392e1276358c36c87c7728fea Mon Sep 17 00:00:00 2001 From: mossid Date: Sun, 14 Jul 2019 05:03:31 +0900 Subject: [PATCH] manually pass proofs --- x/ibc/04-channel/handshake.go | 23 +++++++++++++++++++++-- x/ibc/04-channel/manager.go | 13 +++++++++++-- x/ibc/04-channel/tests/types.go | 14 +++++++------- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/x/ibc/04-channel/handshake.go b/x/ibc/04-channel/handshake.go index eb38c63f5475..2889aa8183fc 100644 --- a/x/ibc/04-channel/handshake.go +++ b/x/ibc/04-channel/handshake.go @@ -159,8 +159,9 @@ func (man Handshaker) OpenInit(ctx sdk.Context, return obj, nil } -// Using proofs: counterparty.{handshake,state,nextTimeout,clientid,client} +// Using proofs: counterparty.{channel,state,nextTimeout} func (man Handshaker) OpenTry(ctx sdk.Context, + pchannel, pstate, ptimeout commitment.Proof, connid, chanid string, channel Channel, timeoutHeight, nextTimeoutHeight uint64, ) (obj HandshakeObject, err error) { obj, err = man.create(ctx, connid, chanid, channel) @@ -168,6 +169,11 @@ func (man Handshaker) OpenTry(ctx sdk.Context, return } + ctx, err = obj.Context(ctx, pchannel, pstate, ptimeout) + if err != nil { + return + } + err = assertTimeout(ctx, timeoutHeight) if err != nil { return @@ -216,6 +222,7 @@ func (man Handshaker) OpenTry(ctx sdk.Context, // Using proofs: counterparty.{handshake,state,nextTimeout,clientid,client} func (man Handshaker) OpenAck(ctx sdk.Context, + pchannel, pstate, ptimeout commitment.Proof, connid, chanid string, timeoutHeight, nextTimeoutHeight uint64, ) (obj HandshakeObject, err error) { obj, err = man.query(ctx, connid, chanid) @@ -223,6 +230,11 @@ func (man Handshaker) OpenAck(ctx sdk.Context, return } + ctx, err = obj.Context(ctx, pchannel, pstate, ptimeout) + if err != nil { + return + } + if !obj.state.Transit(ctx, Init, Open) { err = errors.New("ack on non-init connection") return @@ -269,12 +281,19 @@ func (man Handshaker) OpenAck(ctx sdk.Context, } // Using proofs: counterparty.{connection,state, nextTimeout} -func (man Handshaker) OpenConfirm(ctx sdk.Context, connid, chanid string, timeoutHeight uint64) (obj HandshakeObject, err error) { +func (man Handshaker) OpenConfirm(ctx sdk.Context, + pstate, ptimeout commitment.Proof, + connid, chanid string, timeoutHeight uint64) (obj HandshakeObject, err error) { obj, err = man.query(ctx, connid, chanid) if err != nil { return } + ctx, err = obj.Context(ctx, pstate, ptimeout) + if err != nil { + return + } + if !obj.state.Transit(ctx, OpenTry, Open) { err = errors.New("confirm on non-try connection") return diff --git a/x/ibc/04-channel/manager.go b/x/ibc/04-channel/manager.go index 8d8287a1f3d8..fe095ec8dfe4 100644 --- a/x/ibc/04-channel/manager.go +++ b/x/ibc/04-channel/manager.go @@ -200,6 +200,10 @@ type CounterObject struct { connection connection.CounterObject } +func (obj Object) Context(ctx sdk.Context, proofs ...commitment.Proof) (sdk.Context, error) { + return obj.connection.Context(ctx, nil, proofs...) +} + func (obj Object) ChanID() string { return obj.chanid } @@ -257,12 +261,17 @@ func (obj Object) Send(ctx sdk.Context, packet Packet) error { return nil } -func (obj Object) Receive(ctx sdk.Context, packet Packet) error { +func (obj Object) Receive(ctx sdk.Context, ppacket commitment.Proof, packet Packet) error { if !obj.Receivable(ctx) { return errors.New("cannot receive packets on this channel") } - err := assertTimeout(ctx, packet.Timeout()) + ctx, err := obj.Context(ctx, ppacket) + if err != nil { + return err + } + + err = assertTimeout(ctx, packet.Timeout()) if err != nil { return err } diff --git a/x/ibc/04-channel/tests/types.go b/x/ibc/04-channel/tests/types.go index fc62aa251510..80f94e7e248c 100644 --- a/x/ibc/04-channel/tests/types.go +++ b/x/ibc/04-channel/tests/types.go @@ -54,7 +54,7 @@ func NewNode(self, counter tendermint.MockValidators, cdc *codec.Codec) *Node { } func (node *Node) Handshaker(t *testing.T, proofs []commitment.Proof) (sdk.Context, channel.Handshaker) { - ctx := node.Context(t, proofs) + ctx := node.Context() store, err := commitment.NewStore(node.Counterparty.Root(), node.Counterparty.Path, proofs) require.NoError(t, err) ctx = commitment.WithStore(ctx, store) @@ -90,7 +90,7 @@ func (node *Node) OpenInit(t *testing.T, proofs ...commitment.Proof) { func (node *Node) OpenTry(t *testing.T, proofs ...commitment.Proof) { ctx, man := node.Handshaker(t, proofs) - obj, err := man.OpenTry(ctx, node.Name, node.Name, node.Channel, 100 /*TODO*/, 100 /*TODO*/) + obj, err := man.OpenTry(ctx, proofs[0], proofs[1], proofs[2], node.Name, node.Name, node.Channel, 100 /*TODO*/, 100 /*TODO*/) require.NoError(t, err) require.Equal(t, channel.OpenTry, obj.State(ctx)) require.Equal(t, node.Channel, obj.Channel(ctx)) @@ -100,7 +100,7 @@ func (node *Node) OpenTry(t *testing.T, proofs ...commitment.Proof) { func (node *Node) OpenAck(t *testing.T, proofs ...commitment.Proof) { ctx, man := node.Handshaker(t, proofs) - obj, err := man.OpenAck(ctx, node.Name, node.Name, 100 /*TODO*/, 100 /*TODO*/) + obj, err := man.OpenAck(ctx, proofs[0], proofs[1], proofs[2], node.Name, node.Name, 100 /*TODO*/, 100 /*TODO*/) require.NoError(t, err) require.Equal(t, channel.Open, obj.State(ctx)) require.Equal(t, node.Channel, obj.Channel(ctx)) @@ -110,7 +110,7 @@ func (node *Node) OpenAck(t *testing.T, proofs ...commitment.Proof) { func (node *Node) OpenConfirm(t *testing.T, proofs ...commitment.Proof) { ctx, man := node.Handshaker(t, proofs) - obj, err := man.OpenConfirm(ctx, node.Name, node.Name, 100 /*TODO*/) + obj, err := man.OpenConfirm(ctx, proofs[0], proofs[1], node.Name, node.Name, 100 /*TODO*/) require.NoError(t, err) require.Equal(t, channel.Open, obj.State(ctx)) require.Equal(t, node.Channel, obj.Channel(ctx)) @@ -152,7 +152,7 @@ func (node *Node) Handshake(t *testing.T) { } func (node *Node) Send(t *testing.T, packet channel.Packet) { - ctx, man := node.Context(t, nil), node.Manager() + ctx, man := node.Context(), node.Manager() obj, err := man.Query(ctx, node.Name, node.Name) require.NoError(t, err) seq := obj.SeqSend(ctx) @@ -163,11 +163,11 @@ func (node *Node) Send(t *testing.T, packet channel.Packet) { } func (node *Node) Receive(t *testing.T, packet channel.Packet, proofs ...commitment.Proof) { - ctx, man := node.Context(t, proofs), node.Manager() + ctx, man := node.Context(), node.Manager() obj, err := man.Query(ctx, node.Name, node.Name) require.NoError(t, err) seq := obj.SeqRecv(ctx) - err = obj.Receive(ctx, packet) + err = obj.Receive(ctx, proofs[0], packet) require.NoError(t, err) require.Equal(t, seq+1, obj.SeqRecv(ctx)) }