Skip to content

Commit

Permalink
Merge commit from fork
Browse files Browse the repository at this point in the history
* add fix and test

remove packet data check

add check to unmarshal function for packet data

lint

rm test case

advisory fix main

* fix merge
  • Loading branch information
AdityaSripal authored Feb 27, 2025
1 parent 3aa30a2 commit ea60498
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 1 deletion.
6 changes: 6 additions & 0 deletions modules/apps/transfer/ibc_module.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package transfer

import (
"bytes"
"fmt"
"math"
"slices"
Expand Down Expand Up @@ -231,6 +232,11 @@ func (im IBCModule) OnAcknowledgementPacket(
return err
}

bz := types.ModuleCdc.MustMarshalJSON(&ack)
if !bytes.Equal(bz, acknowledgement) {
return errorsmod.Wrapf(ibcerrors.ErrInvalidType, "acknowledgement did not marshal to expected bytes: %X ≠ %X", bz, acknowledgement)
}

if err := im.keeper.OnAcknowledgementPacket(ctx, packet.SourcePort, packet.SourceChannel, data, ack); err != nil {
return err
}
Expand Down
118 changes: 118 additions & 0 deletions modules/apps/transfer/ibc_module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,124 @@ func (suite *TransferTestSuite) TestOnRecvPacket() {
}
}

func (suite *TransferTestSuite) TestOnAcknowledgePacket() {
var (
path *ibctesting.Path
packet channeltypes.Packet
ack []byte
)

testCases := []struct {
name string
malleate func()
expError error
expRefund bool
}{
{
"success",
func() {},
nil,
false,
},
{
"success: refund coins",
func() {
ack = channeltypes.NewErrorAcknowledgement(ibcerrors.ErrInsufficientFunds).Acknowledgement()
},
nil,
true,
},
{
"cannot refund ack on non-existent channel",
func() {
ack = channeltypes.NewErrorAcknowledgement(ibcerrors.ErrInsufficientFunds).Acknowledgement()

packet.SourceChannel = "channel-100"
},
errors.New("unable to unescrow tokens"),
false,
},
{
"invalid packet data",
func() {
packet.Data = []byte("invalid data")
},
ibcerrors.ErrInvalidType,
false,
},
{
"invalid acknowledgement",
func() {
ack = []byte("invalid ack")
},
ibcerrors.ErrUnknownRequest,
false,
},
{
"cannot refund already acknowledged packet",
func() {
ack = channeltypes.NewErrorAcknowledgement(ibcerrors.ErrInsufficientFunds).Acknowledgement()

cbs, ok := suite.chainA.App.GetIBCKeeper().PortKeeper.Route(ibctesting.TransferPort)
suite.Require().True(ok)

suite.Require().NoError(cbs.OnAcknowledgementPacket(suite.chainA.GetContext(), path.EndpointA.GetChannel().Version, packet, ack, suite.chainA.SenderAccount.GetAddress()))
},
errors.New("unable to unescrow tokens"),
false,
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest() // reset

path = ibctesting.NewTransferPath(suite.chainA, suite.chainB)
path.Setup()

timeoutHeight := suite.chainA.GetTimeoutHeight()
msg := types.NewMsgTransfer(
path.EndpointA.ChannelConfig.PortID,
path.EndpointA.ChannelID,
ibctesting.TestCoin,
suite.chainA.SenderAccount.GetAddress().String(),
suite.chainB.SenderAccount.GetAddress().String(),
timeoutHeight,
0,
"",
)
res, err := suite.chainA.SendMsgs(msg)
suite.Require().NoError(err) // message committed

packet, err = ibctesting.ParsePacketFromEvents(res.Events)
suite.Require().NoError(err)

cbs, ok := suite.chainA.App.GetIBCKeeper().PortKeeper.Route(ibctesting.TransferPort)
suite.Require().True(ok)

ack = channeltypes.NewResultAcknowledgement([]byte{byte(1)}).Acknowledgement()

tc.malleate() // change fields in packet

err = cbs.OnAcknowledgementPacket(suite.chainA.GetContext(), path.EndpointA.GetChannel().Version, packet, ack, suite.chainA.SenderAccount.GetAddress())

if tc.expError == nil {
suite.Require().NoError(err)

if tc.expRefund {
escrowAddress := types.GetEscrowAddress(packet.GetSourcePort(), packet.GetSourceChannel())
escrowBalanceAfter := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddress, sdk.DefaultBondDenom)
suite.Require().Equal(sdkmath.NewInt(0), escrowBalanceAfter.Amount)
}
} else {
suite.Require().Error(err)
suite.Require().Contains(err.Error(), tc.expError.Error())
}
})
}
}

func (suite *TransferTestSuite) TestOnTimeoutPacket() {
var path *ibctesting.Path
var packet channeltypes.Packet
Expand Down
10 changes: 9 additions & 1 deletion modules/apps/transfer/types/packet.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package types

import (
"bytes"
"encoding/json"
"errors"
"strings"
Expand Down Expand Up @@ -252,7 +253,7 @@ func UnmarshalPacketData(bz []byte, ics20Version string, encoding string) (Inter
return InternalTransferRepresentation{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, failedUnmarshalingErrorMsg, errorMsgVersion, err.Error())
}
default:
return InternalTransferRepresentation{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "invalid encoding provided, must be either empty or one of [%q, %q], got %s", EncodingJSON, EncodingProtobuf, encoding)
return InternalTransferRepresentation{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "invalid encoding provided, must be either empty or one of [%q, %q, %q], got %s", EncodingJSON, EncodingProtobuf, EncodingABI, encoding)
}

// When the unmarshaling is done, we want to retrieve the underlying data type based on the value of ics20Version
Expand All @@ -263,6 +264,13 @@ func UnmarshalPacketData(bz []byte, ics20Version string, encoding string) (Inter
// We should never get here, as we manually constructed the type at the beginning of the file
return InternalTransferRepresentation{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot convert proto message into FungibleTokenPacketData")
}
bz2, err := MarshalPacketData(*datav1, ics20Version, encoding)
if err != nil {
return InternalTransferRepresentation{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot marshal transfer packet data: %s", err.Error())
}
if !bytes.Equal(bz, bz2) {
return InternalTransferRepresentation{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "marshaled bytes are not equal: got %X, expected %X", bz2, bz)
}
// The call to ValidateBasic for V1 is done inside PacketDataV1toV2.
return PacketDataV1ToV2(*datav1)
}
Expand Down

0 comments on commit ea60498

Please sign in to comment.