diff --git a/modules/apps/transfer/keeper/export_test.go b/modules/apps/transfer/keeper/export_test.go index c2a30839dc6..8c6f3961c79 100644 --- a/modules/apps/transfer/keeper/export_test.go +++ b/modules/apps/transfer/keeper/export_test.go @@ -31,3 +31,8 @@ func (k Keeper) GetAllDenomTraces(ctx sdk.Context) []types.DenomTrace { func (k Keeper) TokenFromCoin(ctx sdk.Context, coin sdk.Coin) (types.Token, error) { return k.tokenFromCoin(ctx, coin) } + +// CreatePacketDataBytesFromVersion is a wrapper around createPacketDataBytesFromVersion for testing purposes +func CreatePacketDataBytesFromVersion(appVersion, sender, receiver, memo string, tokens types.Tokens) []byte { + return createPacketDataBytesFromVersion(appVersion, sender, receiver, memo, tokens) +} diff --git a/modules/apps/transfer/keeper/relay.go b/modules/apps/transfer/keeper/relay.go index e1a0a7e7185..287ad5b5445 100644 --- a/modules/apps/transfer/keeper/relay.go +++ b/modules/apps/transfer/keeper/relay.go @@ -138,19 +138,7 @@ func (k Keeper) sendTransfer( tokens = append(tokens, token) } - var packetDataBytes []byte - switch appVersion { - case types.V1: - // Length of coins has been checked earlier to be 1 if version is V1. - token := tokens[0] - packetData := types.NewFungibleTokenPacketData(token.Denom.Path(), token.Amount, sender.String(), receiver, memo) - packetDataBytes = packetData.GetBytes() - case types.V2: - packetData := types.NewFungibleTokenPacketDataV2(tokens, sender.String(), receiver, memo) - packetDataBytes = packetData.GetBytes() - default: - panic(fmt.Errorf("app version must be one of %s", types.SupportedVersions)) - } + packetDataBytes := createPacketDataBytesFromVersion(appVersion, sender.String(), receiver, memo, tokens) sequence, err := k.ics4Wrapper.SendPacket(ctx, channelCap, sourcePort, sourceChannel, timeoutHeight, timeoutTimestamp, packetDataBytes) if err != nil { @@ -450,3 +438,26 @@ func (k Keeper) tokenFromCoin(ctx sdk.Context, coin sdk.Coin) (types.Token, erro Amount: coin.Amount.String(), }, nil } + +// createPacketDataBytesFromVersion creates the packet data bytes to be sent based on the application version. +func createPacketDataBytesFromVersion(appVersion, sender, receiver, memo string, tokens types.Tokens) []byte { + var packetDataBytes []byte + switch appVersion { + case types.V1: + // Sanity check, tokens must always be of length 1 if using app version V1. + if len(tokens) != 1 { + panic(fmt.Errorf("length of tokens must be equal to 1 if using %s version", types.V1)) + } + + token := tokens[0] + packetData := types.NewFungibleTokenPacketData(token.Denom.Path(), token.Amount, sender, receiver, memo) + packetDataBytes = packetData.GetBytes() + case types.V2: + packetData := types.NewFungibleTokenPacketDataV2(tokens, sender, receiver, memo) + packetDataBytes = packetData.GetBytes() + default: + panic(fmt.Errorf("app version must be one of %s", types.SupportedVersions)) + } + + return packetDataBytes +} diff --git a/modules/apps/transfer/keeper/relay_test.go b/modules/apps/transfer/keeper/relay_test.go index 5873107bd13..630bfac7a24 100644 --- a/modules/apps/transfer/keeper/relay_test.go +++ b/modules/apps/transfer/keeper/relay_test.go @@ -12,11 +12,13 @@ import ( banktestutil "github.com/cosmos/cosmos-sdk/x/bank/testutil" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + transferkeeper "github.com/cosmos/ibc-go/v8/modules/apps/transfer/keeper" "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types" clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" ibcerrors "github.com/cosmos/ibc-go/v8/modules/core/errors" ibctesting "github.com/cosmos/ibc-go/v8/testing" + ibcmock "github.com/cosmos/ibc-go/v8/testing/mock" ) // TestSendTransfer tests sending from chainA to chainB using both coin @@ -1243,3 +1245,75 @@ func (suite *KeeperTestSuite) TestPacketForwardsCompatibility() { }) } } + +func (suite *KeeperTestSuite) TestCreatePacketDataBytesFromVersion() { + var ( + bz []byte + tokens types.Tokens + ) + + testCases := []struct { + name string + appVersion string + malleate func() + expResult func(bz []byte) + expPanicErr error + }{ + { + "success", + types.V1, + func() {}, + func(bz []byte) { + expPacketData := types.NewFungibleTokenPacketData("", "", "", "", "") + suite.Require().Equal(bz, expPacketData.GetBytes()) + }, + nil, + }, + { + "success: version 2", + types.V2, + func() {}, + func(bz []byte) { + expPacketData := types.NewFungibleTokenPacketDataV2(types.Tokens{types.Token{}}, "", "", "") + suite.Require().Equal(bz, expPacketData.GetBytes()) + }, + nil, + }, + { + "failure: must have single coin if using version 1.", + types.V1, + func() { + tokens = types.Tokens{} + }, + nil, + fmt.Errorf("length of tokens must be equal to 1 if using %s version", types.V1), + }, + { + "failure: invalid version", + ibcmock.Version, + func() {}, + nil, + fmt.Errorf("app version must be one of %s", types.SupportedVersions), + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + tokens = types.Tokens{types.Token{}} + + tc.malleate() + + createFunc := func() { + bz = transferkeeper.CreatePacketDataBytesFromVersion(tc.appVersion, "", "", "", tokens) + } + + expPanic := tc.expPanicErr != nil + if expPanic { + suite.Require().PanicsWithError(tc.expPanicErr.Error(), createFunc) + } else { + createFunc() + tc.expResult(bz) + } + }) + } +}