diff --git a/app/app.go b/app/app.go index c4a928674..2d411de33 100644 --- a/app/app.go +++ b/app/app.go @@ -297,6 +297,11 @@ func NewAxelarApp( setKeeper(keepers, initAxelarIBCKeeper(keepers)) + messageRouter := nexusTypes.NewMessageRouter(). + AddRoute(evmTypes.ModuleName, evmKeeper.NewMessageRoute()). + AddRoute(axelarnetTypes.ModuleName, axelarnetKeeper.NewMessageRoute(getKeeper[axelarnetKeeper.Keeper](keepers), getKeeper[axelarnetKeeper.IBCKeeper](keepers), getKeeper[feegrantkeeper.Keeper](keepers), axelarbankkeeper.NewBankKeeper(getKeeper[bankkeeper.BaseKeeper](keepers)), getKeeper[nexusKeeper.Keeper](keepers), getKeeper[authkeeper.AccountKeeper](keepers))) + getKeeperAsRef[nexusKeeper.Keeper](keepers).SetMessageRouter(messageRouter) + axelarnetModule := axelarnet.NewAppModule(getKeeper[axelarnetKeeper.Keeper](keepers), getKeeper[nexusKeeper.Keeper](keepers), axelarbankkeeper.NewBankKeeper(getKeeper[bankkeeper.BaseKeeper](keepers)), getKeeper[authkeeper.AccountKeeper](keepers), getKeeper[axelarnetKeeper.IBCKeeper](keepers), transferStack, rateLimiter, logger) // Create static IBC router, add axelarnet module as the IBC transfer route, and seal it diff --git a/app/keepers.go b/app/keepers.go index a4fd0927f..1ec63e03f 100644 --- a/app/keepers.go +++ b/app/keepers.go @@ -284,15 +284,15 @@ func initEvmKeeper(appCodec codec.Codec, keys map[string]*sdk.KVStoreKey, keeper } func initNexusKeeper(appCodec codec.Codec, keys map[string]*sdk.KVStoreKey, keepers *keeperCache) *nexusKeeper.Keeper { - // Setting Router will finalize all routes by sealing router - // No more routes can be added - nexusRouter := nexusTypes.NewRouter() - nexusRouter. + // setting validator will finalize all by sealing it + // no more validators can be added + addressValidator := nexusTypes.NewAddressValidator(). AddAddressValidator(evmTypes.ModuleName, evmKeeper.NewAddressValidator()). AddAddressValidator(axelarnetTypes.ModuleName, axelarnetKeeper.NewAddressValidator(getKeeper[axelarnetKeeper.Keeper](keepers))) nexusK := nexusKeeper.NewKeeper(appCodec, keys[nexusTypes.StoreKey], keepers.getSubspace(nexusTypes.ModuleName)) - nexusK.SetRouter(nexusRouter) + nexusK.SetAddressValidator(addressValidator) + return &nexusK } diff --git a/x/axelarnet/keeper/genesis_test.go b/x/axelarnet/keeper/genesis_test.go index 853951802..334ed7420 100644 --- a/x/axelarnet/keeper/genesis_test.go +++ b/x/axelarnet/keeper/genesis_test.go @@ -138,10 +138,6 @@ func randomChains() []types.CosmosChain { return chains } -func randomNormalizedStr(min, max int) string { - return strings.ReplaceAll(utils.NormalizeString(rand.StrBetween(min, max)), utils.DefaultDelimiter, "-") -} - // randomTransferQueue returns a random (valid) transfer queue state for testing func randomTransferQueue(cdc codec.Codec, transfers []types.IBCTransfer) utils.QueueState { qs := utils.QueueState{Items: make(map[string]utils.QueueState_Item)} diff --git a/x/axelarnet/keeper/message_route.go b/x/axelarnet/keeper/message_route.go new file mode 100644 index 000000000..7a9dea84e --- /dev/null +++ b/x/axelarnet/keeper/message_route.go @@ -0,0 +1,91 @@ +package keeper + +import ( + "fmt" + + storetypes "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + + "github.com/axelarnetwork/axelar-core/x/axelarnet/exported" + "github.com/axelarnetwork/axelar-core/x/axelarnet/types" + nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" +) + +// for IBC execution +const gasCost = storetypes.Gas(1000000) + +func NewMessageRoute( + keeper Keeper, + ibcK types.IBCKeeper, + feegrantK types.FeegrantKeeper, + bankK types.BankKeeper, + nexusK types.Nexus, + accountK types.AccountKeeper, +) nexus.MessageRoute { + return func(ctx sdk.Context, routingCtx nexus.RoutingContext, msg nexus.GeneralMessage) error { + if routingCtx.Payload == nil { + return fmt.Errorf("payload is required for routing messages to a cosmos chain") + } + + bz, err := types.TranslateMessage(msg, routingCtx.Payload) + if err != nil { + return sdkerrors.Wrap(err, "invalid payload") + } + + asset, err := escrowAssetToMessageSender(ctx, keeper, feegrantK, bankK, nexusK, accountK, routingCtx, msg) + if err != nil { + return err + } + + ctx.GasMeter().ConsumeGas(gasCost, "execute-message") + + return ibcK.SendMessage(ctx.Context(), msg.Recipient, asset, string(bz), msg.ID) + } +} + +// all general messages are sent from the Axelar general message sender, so receiver can use the packet sender to authenticate the message +// escrowAssetToMessageSender sends the asset to general msg sender account +func escrowAssetToMessageSender( + ctx sdk.Context, + keeper Keeper, + feegrantK types.FeegrantKeeper, + bankK types.BankKeeper, + nexusK types.Nexus, + accountK types.AccountKeeper, + routingCtx nexus.RoutingContext, + msg nexus.GeneralMessage, +) (sdk.Coin, error) { + switch msg.Type() { + case nexus.TypeGeneralMessage: + // pure general message, take dust amount from sender to satisfy ibc transfer requirements + asset := sdk.NewCoin(exported.NativeAsset, sdk.OneInt()) + sender := routingCtx.Sender + + if !routingCtx.FeeGranter.Empty() { + req := types.RouteMessageRequest{ + Sender: routingCtx.Sender, + ID: msg.ID, + Payload: routingCtx.Payload, + Feegranter: routingCtx.FeeGranter, + } + if err := feegrantK.UseGrantedFees(ctx, routingCtx.FeeGranter, routingCtx.Sender, sdk.NewCoins(asset), []sdk.Msg{&req}); err != nil { + return sdk.Coin{}, err + } + + sender = routingCtx.FeeGranter + } + + return asset, bankK.SendCoins(ctx, sender, types.AxelarGMPAccount, sdk.NewCoins(asset)) + case nexus.TypeGeneralMessageWithToken: + // general message with token, get token from corresponding account + asset, sender, err := prepareTransfer(ctx, keeper, nexusK, bankK, accountK, *msg.Asset) + if err != nil { + return sdk.Coin{}, err + } + + return asset, bankK.SendCoins(ctx, sender, types.AxelarGMPAccount, sdk.NewCoins(asset)) + default: + return sdk.Coin{}, fmt.Errorf("unrecognized message type") + } +} diff --git a/x/axelarnet/keeper/message_route_test.go b/x/axelarnet/keeper/message_route_test.go new file mode 100644 index 000000000..fc4f2486d --- /dev/null +++ b/x/axelarnet/keeper/message_route_test.go @@ -0,0 +1,218 @@ +package keeper_test + +import ( + "context" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/stretchr/testify/assert" + + "github.com/axelarnetwork/axelar-core/testutils/rand" + "github.com/axelarnetwork/axelar-core/x/axelarnet/exported" + "github.com/axelarnetwork/axelar-core/x/axelarnet/keeper" + "github.com/axelarnetwork/axelar-core/x/axelarnet/types" + "github.com/axelarnetwork/axelar-core/x/axelarnet/types/mock" + evmtestutils "github.com/axelarnetwork/axelar-core/x/evm/types/testutils" + nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" + nexustestutils "github.com/axelarnetwork/axelar-core/x/nexus/exported/testutils" + "github.com/axelarnetwork/utils/funcs" + "github.com/axelarnetwork/utils/slices" + . "github.com/axelarnetwork/utils/test" +) + +func randPayload() []byte { + bytesType := funcs.Must(abi.NewType("bytes", "bytes", nil)) + stringType := funcs.Must(abi.NewType("string", "string", nil)) + stringArrayType := funcs.Must(abi.NewType("string[]", "string[]", nil)) + + argNum := int(rand.I64Between(1, 10)) + + var args abi.Arguments + for i := 0; i < argNum; i += 1 { + args = append(args, abi.Argument{Type: stringType}) + } + + schema := abi.Arguments{{Type: stringType}, {Type: stringArrayType}, {Type: stringArrayType}, {Type: bytesType}} + payload := funcs.Must( + schema.Pack( + rand.StrBetween(5, 10), + slices.Expand2(func() string { return rand.Str(5) }, argNum), + slices.Expand2(func() string { return "string" }, argNum), + funcs.Must(args.Pack(slices.Expand2(func() interface{} { return "string" }, argNum)...)), + ), + ) + + return append(funcs.Must(hexutil.Decode(types.CosmWasmV1)), payload...) +} + +func randMsg(status nexus.GeneralMessage_Status, payload []byte, token ...*sdk.Coin) nexus.GeneralMessage { + var asset *sdk.Coin + if len(token) > 0 { + asset = token[0] + } + + return nexus.GeneralMessage{ + ID: rand.NormalizedStr(10), + Sender: nexus.CrossChainAddress{ + Chain: nexustestutils.RandomChain(), + Address: rand.NormalizedStr(42), + }, + Recipient: nexus.CrossChainAddress{ + Chain: nexustestutils.RandomChain(), + Address: rand.NormalizedStr(42), + }, + PayloadHash: evmtestutils.RandomHash().Bytes(), + Status: status, + Asset: asset, + SourceTxID: evmtestutils.RandomHash().Bytes(), + SourceTxIndex: uint64(rand.I64Between(0, 100)), + } +} + +func TestNewMessageRoute(t *testing.T) { + var ( + ctx sdk.Context + routingCtx nexus.RoutingContext + msg nexus.GeneralMessage + route nexus.MessageRoute + + k keeper.Keeper + feegrantK *mock.FeegrantKeeperMock + ibcK *mock.IBCKeeperMock + bankK *mock.BankKeeperMock + nexusK *mock.NexusMock + accountK *mock.AccountKeeperMock + ) + + givenMessageRoute := Given("the message route", func() { + ctx, k, _, feegrantK = setup() + + ibcK = &mock.IBCKeeperMock{} + bankK = &mock.BankKeeperMock{} + nexusK = &mock.NexusMock{} + accountK = &mock.AccountKeeperMock{} + + route = keeper.NewMessageRoute(k, ibcK, feegrantK, bankK, nexusK, accountK) + }) + + givenMessageRoute. + When("payload is nil", func() { + routingCtx = nexus.RoutingContext{Payload: nil} + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, route(ctx, routingCtx, msg), "payload is required") + }). + Run(t) + + givenMessageRoute. + When("the message cannot be translated", func() { + routingCtx = nexus.RoutingContext{ + Sender: rand.AccAddr(), + FeeGranter: nil, + Payload: rand.Bytes(100), + } + msg = randMsg(nexus.Processing, routingCtx.Payload) + }). + Then("should return error", func(t *testing.T) { + assert.ErrorContains(t, route(ctx, routingCtx, msg), "invalid payload") + }). + Run(t) + + whenTheMessageCanBeTranslated := When("the message can be translated", func() { + routingCtx = nexus.RoutingContext{ + Sender: rand.AccAddr(), + Payload: randPayload(), + } + }) + + givenMessageRoute. + When2(whenTheMessageCanBeTranslated). + When("the message has no token transfer", func() { + msg = randMsg(nexus.Processing, routingCtx.Payload) + }). + Branch( + When("the fee granter is not set", func() { + routingCtx.FeeGranter = nil + }). + Then("should deduct the fee from the sender", func(t *testing.T) { + bankK.SendCoinsFunc = func(_ sdk.Context, _, _ sdk.AccAddress, _ sdk.Coins) error { return nil } + ibcK.SendMessageFunc = func(_ context.Context, _ nexus.CrossChainAddress, _ sdk.Coin, _, _ string) error { + return nil + } + + assert.NoError(t, route(ctx, routingCtx, msg)) + + assert.Len(t, bankK.SendCoinsCalls(), 1) + assert.Equal(t, routingCtx.Sender, bankK.SendCoinsCalls()[0].FromAddr) + assert.Equal(t, types.AxelarGMPAccount, bankK.SendCoinsCalls()[0].ToAddr) + assert.Equal(t, sdk.NewCoins(sdk.NewCoin(exported.NativeAsset, sdk.OneInt())), bankK.SendCoinsCalls()[0].Amt) + + assert.Len(t, ibcK.SendMessageCalls(), 1) + assert.Equal(t, msg.Recipient, ibcK.SendMessageCalls()[0].Recipient) + assert.Equal(t, sdk.NewCoin(exported.NativeAsset, sdk.OneInt()), ibcK.SendMessageCalls()[0].Asset) + assert.Equal(t, msg.ID, ibcK.SendMessageCalls()[0].ID) + }), + + When("the fee granter is set", func() { + routingCtx.FeeGranter = rand.AccAddr() + }). + Then("should deduct the fee from the fee granter", func(t *testing.T) { + feegrantK.UseGrantedFeesFunc = func(_ sdk.Context, granter, _ sdk.AccAddress, _ sdk.Coins, _ []sdk.Msg) error { + return nil + } + bankK.SendCoinsFunc = func(_ sdk.Context, _, _ sdk.AccAddress, _ sdk.Coins) error { return nil } + ibcK.SendMessageFunc = func(_ context.Context, _ nexus.CrossChainAddress, _ sdk.Coin, _, _ string) error { + return nil + } + + assert.NoError(t, route(ctx, routingCtx, msg)) + + assert.Len(t, feegrantK.UseGrantedFeesCalls(), 1) + assert.Equal(t, routingCtx.FeeGranter, feegrantK.UseGrantedFeesCalls()[0].Granter) + assert.Equal(t, routingCtx.Sender, feegrantK.UseGrantedFeesCalls()[0].Grantee) + assert.Equal(t, sdk.NewCoins(sdk.NewCoin(exported.NativeAsset, sdk.OneInt())), feegrantK.UseGrantedFeesCalls()[0].Fee) + + assert.Len(t, bankK.SendCoinsCalls(), 1) + assert.Equal(t, routingCtx.FeeGranter, bankK.SendCoinsCalls()[0].FromAddr) + assert.Equal(t, types.AxelarGMPAccount, bankK.SendCoinsCalls()[0].ToAddr) + assert.Equal(t, sdk.NewCoins(sdk.NewCoin(exported.NativeAsset, sdk.OneInt())), bankK.SendCoinsCalls()[0].Amt) + + assert.Len(t, ibcK.SendMessageCalls(), 1) + assert.Equal(t, msg.Recipient, ibcK.SendMessageCalls()[0].Recipient) + assert.Equal(t, sdk.NewCoin(exported.NativeAsset, sdk.OneInt()), ibcK.SendMessageCalls()[0].Asset) + assert.Equal(t, msg.ID, ibcK.SendMessageCalls()[0].ID) + }), + ). + Run(t) + + givenMessageRoute. + When2(whenTheMessageCanBeTranslated). + When("the message has token transfer", func() { + coin := rand.Coin() + msg = randMsg(nexus.Processing, routingCtx.Payload, &coin) + }). + Then("should deduct from the corresponding account", func(t *testing.T) { + nexusK.GetChainByNativeAssetFunc = func(_ sdk.Context, _ string) (nexus.Chain, bool) { + return exported.Axelarnet, true + } + bankK.SendCoinsFunc = func(_ sdk.Context, _, _ sdk.AccAddress, _ sdk.Coins) error { return nil } + ibcK.SendMessageFunc = func(_ context.Context, _ nexus.CrossChainAddress, _ sdk.Coin, _, _ string) error { + return nil + } + + assert.NoError(t, route(ctx, routingCtx, msg)) + + assert.Len(t, bankK.SendCoinsCalls(), 1) + assert.Equal(t, types.GetEscrowAddress(msg.Asset.Denom), bankK.SendCoinsCalls()[0].FromAddr) + assert.Equal(t, types.AxelarGMPAccount, bankK.SendCoinsCalls()[0].ToAddr) + assert.Equal(t, sdk.NewCoins(*msg.Asset), bankK.SendCoinsCalls()[0].Amt) + + assert.Len(t, ibcK.SendMessageCalls(), 1) + assert.Equal(t, msg.Recipient, ibcK.SendMessageCalls()[0].Recipient) + assert.Equal(t, *msg.Asset, ibcK.SendMessageCalls()[0].Asset) + assert.Equal(t, msg.ID, ibcK.SendMessageCalls()[0].ID) + }). + Run(t) +} diff --git a/x/axelarnet/keeper/msg_server.go b/x/axelarnet/keeper/msg_server.go index 651584261..99bb342b0 100644 --- a/x/axelarnet/keeper/msg_server.go +++ b/x/axelarnet/keeper/msg_server.go @@ -6,7 +6,6 @@ import ( "fmt" "strings" - storetypes "github.com/cosmos/cosmos-sdk/store/types" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/query" @@ -17,7 +16,6 @@ import ( "github.com/axelarnetwork/axelar-core/utils/events" "github.com/axelarnetwork/axelar-core/x/axelarnet/exported" "github.com/axelarnetwork/axelar-core/x/axelarnet/types" - evmtypes "github.com/axelarnetwork/axelar-core/x/evm/types" nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" tss "github.com/axelarnetwork/axelar-core/x/tss/exported" "github.com/axelarnetwork/utils/funcs" @@ -25,11 +23,6 @@ import ( var _ types.MsgServiceServer = msgServer{} -const ( - evmCallContractGasCost = storetypes.Gas(10000000) - cosmosCallContractGasCost = storetypes.Gas(1000000) -) - type msgServer struct { Keeper nexus types.Nexus @@ -491,46 +484,15 @@ func (s msgServer) RetryIBCTransfer(c context.Context, req *types.RetryIBCTransf func (s msgServer) RouteMessage(c context.Context, req *types.RouteMessageRequest) (*types.RouteMessageResponse, error) { ctx := sdk.UnwrapSDKContext(c) - msg, ok := s.nexus.GetMessage(ctx, req.ID) - if !ok { - return nil, fmt.Errorf("message %s not found", req.ID) - } - - if !msg.Match(req.Payload) { - return nil, fmt.Errorf("payload hash does not match") + routingCtx := nexus.RoutingContext{ + Sender: req.Sender, + FeeGranter: req.Feegranter, + Payload: req.Payload, } - - // send ibc message if destination is cosmos - if msg.Recipient.Chain.IsFrom(exported.ModuleName) { - bz, err := types.TranslateMessage(msg, req.Payload) - if err != nil { - return nil, sdkerrors.Wrap(err, "invalid payload") - } - - asset, err := s.escrowAssetToMessageSender(ctx, req, msg) - if err != nil { - return nil, err - } - - err = s.ibcK.SendMessage(c, msg.Recipient, asset, string(bz), msg.ID) - if err != nil { - return nil, err - } - } - - err := s.nexus.SetMessageProcessing(ctx, msg.ID) - if err != nil { + if err := s.nexus.RouteMessage(ctx, req.ID, routingCtx); err != nil { return nil, err } - if msg.Recipient.Chain.IsFrom(evmtypes.ModuleName) { - ctx.GasMeter().ConsumeGas(evmCallContractGasCost, "execute-message") - } else { - ctx.GasMeter().ConsumeGas(cosmosCallContractGasCost, "execute-message") - } - - s.Logger(ctx).Debug("set general message status to processing", "messageID", msg.ID) - return &types.RouteMessageResponse{}, nil } @@ -602,42 +564,3 @@ func prepareTransfer(ctx sdk.Context, k Keeper, n types.Nexus, b types.BankKeepe return coin, sender, nil } - -// all general messages are sent from the Axelar general message sender, so receiver can use the packet sender to authenticate the message -// escrowAssetToMessageSender sends the asset to general msg sender account -func (s msgServer) escrowAssetToMessageSender(ctx sdk.Context, req *types.RouteMessageRequest, msg nexus.GeneralMessage) (sdk.Coin, error) { - var asset sdk.Coin - var sender sdk.AccAddress - var err error - - switch msg.Type() { - case nexus.TypeGeneralMessage: - // pure general message, take dust amount from sender to satisfy ibc transfer requirements - asset = sdk.NewCoin(exported.NativeAsset, sdk.OneInt()) - sender = req.Sender - - if req.Feegranter != nil { - if err := s.feegrantK.UseGrantedFees(ctx, req.Feegranter, req.Sender, sdk.NewCoins(asset), []sdk.Msg{req}); err != nil { - return sdk.Coin{}, err - } - - sender = req.Feegranter - } - case nexus.TypeGeneralMessageWithToken: - // general message with token, get token from corresponding account - asset, sender, err = prepareTransfer(ctx, s.Keeper, s.nexus, s.bank, s.account, *msg.Asset) - if err != nil { - return sdk.Coin{}, err - } - default: - return sdk.Coin{}, fmt.Errorf("unrecognized message type") - } - - // use GeneralMessageSender account as the canonical general message sender - err = s.bank.SendCoins(ctx, sender, types.AxelarGMPAccount, sdk.NewCoins(asset)) - if err != nil { - return sdk.Coin{}, err - } - - return asset, nil -} diff --git a/x/axelarnet/keeper/msg_server_test.go b/x/axelarnet/keeper/msg_server_test.go index edf2e57c8..fec023e0f 100644 --- a/x/axelarnet/keeper/msg_server_test.go +++ b/x/axelarnet/keeper/msg_server_test.go @@ -1,11 +1,8 @@ package keeper_test import ( - "bytes" - "context" "crypto/sha256" "encoding/hex" - "encoding/json" "fmt" mathRand "math/rand" "strings" @@ -17,8 +14,6 @@ import ( ibctypes "github.com/cosmos/ibc-go/v4/modules/apps/transfer/types" clienttypes "github.com/cosmos/ibc-go/v4/modules/core/02-client/types" ibcclient "github.com/cosmos/ibc-go/v4/modules/core/exported" - "github.com/ethereum/go-ethereum/accounts/abi" - "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/assert" tmbytes "github.com/tendermint/tendermint/libs/bytes" @@ -37,7 +32,6 @@ import ( "github.com/axelarnetwork/utils/funcs" "github.com/axelarnetwork/utils/slices" . "github.com/axelarnetwork/utils/test" - rand2 "github.com/axelarnetwork/utils/test/rand" ) func TestHandleMsgLink(t *testing.T) { @@ -970,249 +964,44 @@ func TestAddCosmosBasedChain(t *testing.T) { func TestRouteMessage(t *testing.T) { var ( - server types.MsgServiceServer - k keeper.Keeper - nexusK *mock.NexusMock - feegrantK *mock.FeegrantKeeperMock - ctx sdk.Context - req *types.RouteMessageRequest - msg nexus.GeneralMessage + server types.MsgServiceServer + nexusK *mock.NexusMock + ctx sdk.Context ) - chain := nexustestutils.RandomChain() - chain.Module = evmtypes.ModuleName - id := rand.StrBetween(5, 100) - payload := randPayload() - coin := rand.Coin() - - msg = nexus.GeneralMessage{ - Sender: nexus.CrossChainAddress{ - Chain: nexustestutils.RandomChain(), - Address: evmtestutils.RandomAddress().Hex(), - }, - Recipient: nexus.CrossChainAddress{ - Chain: chain, - Address: rand.AccAddr().String(), - }, - PayloadHash: crypto.Keccak256Hash(payload).Bytes(), - Asset: &coin, + req := types.RouteMessageRequest{ + ID: rand.Str(10), + Sender: rand.AccAddr(), + Feegranter: rand.AccAddr(), + Payload: rand.BytesBetween(5, 100), } givenMsgServer := Given("an axelarnet msg server", func() { - ctx, k, _, feegrantK = setup() - k.InitGenesis(ctx, types.DefaultGenesisState()) - funcs.MustNoErr(k.SetCosmosChain(ctx, types.CosmosChain{ - Name: chain.Name, - AddrPrefix: rand.StrBetween(1, 10), - IBCPath: axelartestutils.RandomIBCPath(), - })) - nexusK = &mock.NexusMock{ - GetChainByNativeAssetFunc: func(sdk.Context, string) (nexus.Chain, bool) { - return chain, true - }, - SetMessageProcessingFunc: func(sdk.Context, string) error { - return nil - }, - } - ibcK := keeper.NewIBCKeeper(k, &mock.IBCTransferKeeperMock{ - TransferFunc: func(context.Context, *ibctypes.MsgTransfer) (*ibctypes.MsgTransferResponse, error) { - return &ibctypes.MsgTransferResponse{Sequence: uint64(rand2.I64Between(1, 100000))}, nil - }, - }, &mock.ChannelKeeperMock{ - GetChannelClientStateFunc: func(sdk.Context, string, string) (string, ibcclient.ClientState, error) { - return "07-tendermint-0", axelartestutils.ClientState(), nil - }, - }) - bankK := &mock.BankKeeperMock{ - MintCoinsFunc: func(sdk.Context, string, sdk.Coins) error { return nil }, - SendCoinsFunc: func(sdk.Context, sdk.AccAddress, sdk.AccAddress, sdk.Coins) error { return nil }, - } - accountK := &mock.AccountKeeperMock{ - GetModuleAddressFunc: func(moduleName string) sdk.AccAddress { - return rand.AccAddr() - }, - } - server = keeper.NewMsgServerImpl(k, nexusK, bankK, accountK, ibcK) - }) - - isMessageFound := func(isFound bool, status nexus.GeneralMessage_Status) func() { - return func() { - nexusK.GetMessageFunc = func(ctx sdk.Context, messageID string) (nexus.GeneralMessage, bool) { - if !isFound { - return nexus.GeneralMessage{}, false - } - msg.Status = status - return msg, true + c, k, _, _ := setup() + ctx = c - } - } - } - - whenMessageIsFromEVM := When("message is from evm", func() { - isMessageFound(true, nexus.Approved)() - msg.Sender.Chain.Module = evmtypes.ModuleName - }) - whenMessageIsFromCosmos := When("message is from cosmos", func() { - isMessageFound(true, nexus.Approved)() - msg.Sender.Chain.Module = exported.ModuleName - }) - whenMessageIsToEVM := When("message is to evm", func() { - msg.Recipient.Chain.Module = evmtypes.ModuleName - }) - whenMessageIsToCosmos := When("message is to cosmos", func() { - msg.Recipient.Chain.Module = exported.ModuleName - }) - - requestIsMade := When("an execute message request is made", func() { - req = types.NewRouteMessage( - rand.AccAddr(), - nil, - id, - payload, - ) + nexusK = &mock.NexusMock{} + ibcK := keeper.NewIBCKeeper(k, &mock.IBCTransferKeeperMock{}, &mock.ChannelKeeperMock{}) + bankK := &mock.BankKeeperMock{} + accountK := &mock.AccountKeeperMock{} + server = keeper.NewMsgServerImpl(k, nexusK, bankK, accountK, ibcK) }) - routeFailsWithError := func(msg string) func(t *testing.T) { - return func(t *testing.T) { - _, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req) - assert.ErrorContains(t, err, msg) - } - } - - t.Run("route message", func(t *testing.T) { - givenMsgServer. - Branch( - When("general message is not found", isMessageFound(false, nexus.NonExistent)). - When2(requestIsMade). - Then("should fail", routeFailsWithError("not found")), - - whenMessageIsFromEVM. - When2(whenMessageIsToCosmos). - When("payload does not match", func() { - req = types.NewRouteMessage( - rand.AccAddr(), - nil, - id, - rand.BytesBetween(100, 500), - ) - }). - Then("should fail", routeFailsWithError("payload hash does not match")), - - whenMessageIsFromEVM. - When2(whenMessageIsToCosmos). - When("payload with version is invalid", func() { - payload = rand.Bytes(4) - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - Then("should fail", routeFailsWithError("invalid versioned payload")), - - whenMessageIsFromEVM. - When2(whenMessageIsToCosmos). - When("payload is invalid", func() { - payload = axelartestutils.PackPayloadWithVersion(types.CosmWasmV1, rand.BytesBetween(100, 500)) - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - Then("should fail", routeFailsWithError("invalid payload")), - - whenMessageIsFromCosmos. - When2(whenMessageIsToCosmos). - When("payload is invalid", func() { - payload = rand.BytesBetween(100, 500) - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - Then("should fail", routeFailsWithError("invalid payload")), - - whenMessageIsFromEVM. - When2(whenMessageIsToCosmos). - When("payload is valid", func() { - payload = randWasmPayload() - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - When("feegranter is set with no allowance", func() { - req.Feegranter = rand.AccAddr() - msg.Asset = nil - feegrantK.UseGrantedFeesFunc = func(ctx sdk.Context, granter sdk.AccAddress, addr sdk.AccAddress, amt sdk.Coins, msgs []sdk.Msg) error { - return fmt.Errorf("feegrant error") - } - }). - Then("should fail", routeFailsWithError("feegrant error")), - - whenMessageIsFromEVM. - When2(whenMessageIsToCosmos). - When("payload is valid", func() { - payload = randPayload() - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - Then("should success", func(t *testing.T) { - _, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req) - assert.NoError(t, err) - }), - - whenMessageIsFromCosmos. - When2(whenMessageIsToEVM). - When("payload is valid", func() { - payload = rand.BytesBetween(100, 500) - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - Then("should success", func(t *testing.T) { - _, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req) - assert.NoError(t, err) - }), - - whenMessageIsFromCosmos. - When2(whenMessageIsToCosmos). - When("payload is valid", func() { - payload = randPayload() - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - Then("should success", func(t *testing.T) { - _, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req) - assert.NoError(t, err) - }), - - whenMessageIsFromCosmos. - When2(whenMessageIsToCosmos). - When("payload is valid", func() { - payload = randWasmPayload() - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - Then("should success", func(t *testing.T) { - _, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req) - assert.NoError(t, err) - }), - - whenMessageIsFromEVM. - When2(whenMessageIsToCosmos). - When("payload is valid", func() { - payload = randWasmPayload() - msg.PayloadHash = crypto.Keccak256Hash(payload).Bytes() - }). - When2(requestIsMade). - When("feegranter is set with no allowance", func() { - req.Feegranter = rand.AccAddr() - msg.Asset = nil - feegrantK.UseGrantedFeesFunc = func(ctx sdk.Context, granter sdk.AccAddress, addr sdk.AccAddress, amt sdk.Coins, msgs []sdk.Msg) error { - if !bytes.Equal(req.Sender, addr) || !amt[0].Equal(sdk.NewCoin("uaxl", sdk.NewInt(1))) { - return fmt.Errorf("invalid %s %s", addr, amt) - } + givenMsgServer. + When("route message successfully", func() { + nexusK.RouteMessageFunc = func(_ sdk.Context, _ string, _ ...nexus.RoutingContext) error { return nil } + }). + Then("should route the correct message", func(t *testing.T) { + _, err := server.RouteMessage(sdk.WrapSDKContext(ctx), &req) - return nil - } - }). - Then("should success", func(t *testing.T) { - _, err := server.RouteMessage(sdk.WrapSDKContext(ctx), req) - assert.NoError(t, err) - }), - ).Run(t) - }) + assert.NoError(t, err) + assert.Len(t, nexusK.RouteMessageCalls(), 1) + assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx[0].Sender, req.Sender) + assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx[0].FeeGranter, req.Feegranter) + assert.Equal(t, nexusK.RouteMessageCalls()[0].RoutingCtx[0].Payload, req.Payload) + assert.Equal(t, nexusK.RouteMessageCalls()[0].ID, req.ID) + }). + Run(t) } func TestHandleCallContract(t *testing.T) { @@ -1430,45 +1219,3 @@ func randomTransfer(asset string, chain nexus.ChainName) nexus.CrossChainTransfe sdk.NewInt64Coin(asset, rand.I64Between(1, 10000000000)), ) } - -func randPayload() []byte { - bytesType := funcs.Must(abi.NewType("bytes", "bytes", nil)) - stringType := funcs.Must(abi.NewType("string", "string", nil)) - stringArrayType := funcs.Must(abi.NewType("string[]", "string[]", nil)) - - argNum := int(rand2.I64Between(1, 10)) - - var args abi.Arguments - for i := 0; i < argNum; i += 1 { - args = append(args, abi.Argument{Type: stringType}) - } - - schema := abi.Arguments{{Type: stringType}, {Type: stringArrayType}, {Type: stringArrayType}, {Type: bytesType}} - payload := funcs.Must( - schema.Pack( - rand.StrBetween(5, 10), - slices.Expand2(func() string { return rand.Str(5) }, argNum), - slices.Expand2(func() string { return "string" }, argNum), - funcs.Must(args.Pack(slices.Expand2(func() interface{} { return "string" }, argNum)...)), - ), - ) - - return append(funcs.Must(hexutil.Decode(types.CosmWasmV1)), payload...) -} - -func randWasmPayload() []byte { - args := make(map[string]string) - - randStr := func() string { return rand.Str(int(rand.I64Between(1, 32))) } - - argNum := int(rand2.I64Between(1, 10)) - - for i := 0; i < argNum; i += 1 { - args[randStr()] = randStr() - } - msg := make(map[string]map[string]string) - msg[randStr()] = args - payload := funcs.Must(json.Marshal(msg)) - - return axelartestutils.PackPayloadWithVersion(types.CosmWasmV2, payload) -} diff --git a/x/axelarnet/types/expected_keepers.go b/x/axelarnet/types/expected_keepers.go index b09d0e771..18867fd71 100644 --- a/x/axelarnet/types/expected_keepers.go +++ b/x/axelarnet/types/expected_keepers.go @@ -20,7 +20,7 @@ import ( nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" ) -//go:generate moq -out ./mock/expected_keepers.go -pkg mock . BaseKeeper Nexus BankKeeper IBCTransferKeeper ChannelKeeper AccountKeeper PortKeeper GovKeeper FeegrantKeeper +//go:generate moq -out ./mock/expected_keepers.go -pkg mock . BaseKeeper Nexus BankKeeper IBCTransferKeeper ChannelKeeper AccountKeeper PortKeeper GovKeeper FeegrantKeeper IBCKeeper // BaseKeeper is implemented by this module's base keeper type BaseKeeper interface { @@ -62,6 +62,7 @@ type Nexus interface { SetMessageFailed(ctx sdk.Context, id string) error GenerateMessageID(ctx sdk.Context) (string, []byte, uint64) ValidateAddress(ctx sdk.Context, address nexus.CrossChainAddress) error + RouteMessage(ctx sdk.Context, id string, routingCtx ...nexus.RoutingContext) error } // BankKeeper defines the expected interface contract the vesting module requires @@ -126,3 +127,8 @@ type GovKeeper interface { type FeegrantKeeper interface { UseGrantedFees(ctx sdk.Context, granter, grantee sdk.AccAddress, fee sdk.Coins, msgs []sdk.Msg) error } + +// IBCKeeper defines the expected IBC keeper +type IBCKeeper interface { + SendMessage(c context.Context, recipient nexus.CrossChainAddress, asset sdk.Coin, payload string, id string) error +} diff --git a/x/axelarnet/types/mock/expected_keepers.go b/x/axelarnet/types/mock/expected_keepers.go index 8c8bcac98..fffe7519b 100644 --- a/x/axelarnet/types/mock/expected_keepers.go +++ b/x/axelarnet/types/mock/expected_keepers.go @@ -611,6 +611,9 @@ var _ axelarnettypes.Nexus = &NexusMock{} // RegisterAssetFunc: func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain, asset github_com_axelarnetwork_axelar_core_x_nexus_exported.Asset, limit cosmossdktypes.Uint, window time.Duration) error { // panic("mock out the RegisterAsset method") // }, +// RouteMessageFunc: func(ctx cosmossdktypes.Context, id string, routingCtx ...github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext) error { +// panic("mock out the RouteMessage method") +// }, // SetChainFunc: func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain) { // panic("mock out the SetChain method") // }, @@ -687,6 +690,9 @@ type NexusMock struct { // RegisterAssetFunc mocks the RegisterAsset method. RegisterAssetFunc func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain, asset github_com_axelarnetwork_axelar_core_x_nexus_exported.Asset, limit cosmossdktypes.Uint, window time.Duration) error + // RouteMessageFunc mocks the RouteMessage method. + RouteMessageFunc func(ctx cosmossdktypes.Context, id string, routingCtx ...github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext) error + // SetChainFunc mocks the SetChain method. SetChainFunc func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain) @@ -842,6 +848,15 @@ type NexusMock struct { // Window is the window argument value. Window time.Duration } + // RouteMessage holds details about calls to the RouteMessage method. + RouteMessage []struct { + // Ctx is the ctx argument value. + Ctx cosmossdktypes.Context + // ID is the id argument value. + ID string + // RoutingCtx is the routingCtx argument value. + RoutingCtx []github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext + } // SetChain holds details about calls to the SetChain method. SetChain []struct { // Ctx is the ctx argument value. @@ -908,6 +923,7 @@ type NexusMock struct { lockLinkAddresses sync.RWMutex lockRateLimitTransfer sync.RWMutex lockRegisterAsset sync.RWMutex + lockRouteMessage sync.RWMutex lockSetChain sync.RWMutex lockSetMessageExecuted sync.RWMutex lockSetMessageFailed sync.RWMutex @@ -1533,6 +1549,46 @@ func (mock *NexusMock) RegisterAssetCalls() []struct { return calls } +// RouteMessage calls RouteMessageFunc. +func (mock *NexusMock) RouteMessage(ctx cosmossdktypes.Context, id string, routingCtx ...github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext) error { + if mock.RouteMessageFunc == nil { + panic("NexusMock.RouteMessageFunc: method is nil but Nexus.RouteMessage was just called") + } + callInfo := struct { + Ctx cosmossdktypes.Context + ID string + RoutingCtx []github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext + }{ + Ctx: ctx, + ID: id, + RoutingCtx: routingCtx, + } + mock.lockRouteMessage.Lock() + mock.calls.RouteMessage = append(mock.calls.RouteMessage, callInfo) + mock.lockRouteMessage.Unlock() + return mock.RouteMessageFunc(ctx, id, routingCtx...) +} + +// RouteMessageCalls gets all the calls that were made to RouteMessage. +// Check the length with: +// +// len(mockedNexus.RouteMessageCalls()) +func (mock *NexusMock) RouteMessageCalls() []struct { + Ctx cosmossdktypes.Context + ID string + RoutingCtx []github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext +} { + var calls []struct { + Ctx cosmossdktypes.Context + ID string + RoutingCtx []github_com_axelarnetwork_axelar_core_x_nexus_exported.RoutingContext + } + mock.lockRouteMessage.RLock() + calls = mock.calls.RouteMessage + mock.lockRouteMessage.RUnlock() + return calls +} + // SetChain calls SetChainFunc. func (mock *NexusMock) SetChain(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain) { if mock.SetChainFunc == nil { @@ -3226,3 +3282,93 @@ func (mock *FeegrantKeeperMock) UseGrantedFeesCalls() []struct { mock.lockUseGrantedFees.RUnlock() return calls } + +// Ensure, that IBCKeeperMock does implement axelarnettypes.IBCKeeper. +// If this is not the case, regenerate this file with moq. +var _ axelarnettypes.IBCKeeper = &IBCKeeperMock{} + +// IBCKeeperMock is a mock implementation of axelarnettypes.IBCKeeper. +// +// func TestSomethingThatUsesIBCKeeper(t *testing.T) { +// +// // make and configure a mocked axelarnettypes.IBCKeeper +// mockedIBCKeeper := &IBCKeeperMock{ +// SendMessageFunc: func(c context.Context, recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress, asset cosmossdktypes.Coin, payload string, id string) error { +// panic("mock out the SendMessage method") +// }, +// } +// +// // use mockedIBCKeeper in code that requires axelarnettypes.IBCKeeper +// // and then make assertions. +// +// } +type IBCKeeperMock struct { + // SendMessageFunc mocks the SendMessage method. + SendMessageFunc func(c context.Context, recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress, asset cosmossdktypes.Coin, payload string, id string) error + + // calls tracks calls to the methods. + calls struct { + // SendMessage holds details about calls to the SendMessage method. + SendMessage []struct { + // C is the c argument value. + C context.Context + // Recipient is the recipient argument value. + Recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress + // Asset is the asset argument value. + Asset cosmossdktypes.Coin + // Payload is the payload argument value. + Payload string + // ID is the id argument value. + ID string + } + } + lockSendMessage sync.RWMutex +} + +// SendMessage calls SendMessageFunc. +func (mock *IBCKeeperMock) SendMessage(c context.Context, recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress, asset cosmossdktypes.Coin, payload string, id string) error { + if mock.SendMessageFunc == nil { + panic("IBCKeeperMock.SendMessageFunc: method is nil but IBCKeeper.SendMessage was just called") + } + callInfo := struct { + C context.Context + Recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress + Asset cosmossdktypes.Coin + Payload string + ID string + }{ + C: c, + Recipient: recipient, + Asset: asset, + Payload: payload, + ID: id, + } + mock.lockSendMessage.Lock() + mock.calls.SendMessage = append(mock.calls.SendMessage, callInfo) + mock.lockSendMessage.Unlock() + return mock.SendMessageFunc(c, recipient, asset, payload, id) +} + +// SendMessageCalls gets all the calls that were made to SendMessage. +// Check the length with: +// +// len(mockedIBCKeeper.SendMessageCalls()) +func (mock *IBCKeeperMock) SendMessageCalls() []struct { + C context.Context + Recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress + Asset cosmossdktypes.Coin + Payload string + ID string +} { + var calls []struct { + C context.Context + Recipient github_com_axelarnetwork_axelar_core_x_nexus_exported.CrossChainAddress + Asset cosmossdktypes.Coin + Payload string + ID string + } + mock.lockSendMessage.RLock() + calls = mock.calls.SendMessage + mock.lockSendMessage.RUnlock() + return calls +} diff --git a/x/evm/keeper/message_route.go b/x/evm/keeper/message_route.go new file mode 100644 index 000000000..02a252988 --- /dev/null +++ b/x/evm/keeper/message_route.go @@ -0,0 +1,19 @@ +package keeper + +import ( + storetypes "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" + + nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" +) + +// for commands approval +const gasCost = storetypes.Gas(10000000) + +func NewMessageRoute() nexus.MessageRoute { + return func(ctx sdk.Context, _ nexus.RoutingContext, _ nexus.GeneralMessage) error { + ctx.GasMeter().ConsumeGas(gasCost, "execute-message") + + return nil + } +} diff --git a/x/evm/keeper/message_route_test.go b/x/evm/keeper/message_route_test.go new file mode 100644 index 000000000..74446d81e --- /dev/null +++ b/x/evm/keeper/message_route_test.go @@ -0,0 +1,25 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/assert" + "github.com/tendermint/tendermint/libs/log" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + + "github.com/axelarnetwork/axelar-core/testutils/fake" + "github.com/axelarnetwork/axelar-core/x/evm/keeper" + nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" +) + +func TestNewMessageRoute(t *testing.T) { + route := keeper.NewMessageRoute() + + t.Run("should increment the gas meter", func(t *testing.T) { + ctx := sdk.NewContext(fake.NewMultiStore(), tmproto.Header{}, false, log.TestingLogger()) + + assert.NoError(t, route(ctx, nexus.RoutingContext{}, nexus.GeneralMessage{})) + assert.Positive(t, ctx.GasMeter().GasConsumed()) + }) +} diff --git a/x/nexus/exported/types.go b/x/nexus/exported/types.go index cb5a89015..2aec97b50 100644 --- a/x/nexus/exported/types.go +++ b/x/nexus/exported/types.go @@ -20,6 +20,15 @@ import ( // AddressValidator defines a function that implements address verification upon a request to link addresses type AddressValidator func(ctx sdk.Context, address CrossChainAddress) error +type RoutingContext struct { + Sender sdk.AccAddress + FeeGranter sdk.AccAddress + Payload []byte +} + +// MessageRoute defines a function that implements message routing +type MessageRoute func(ctx sdk.Context, routingCtx RoutingContext, msg GeneralMessage) error + // TransferStateFromString converts a describing state string to the corresponding TransferState func TransferStateFromString(s string) TransferState { state, ok := TransferState_value["TRANSFER_STATE_"+strings.ToUpper(s)] diff --git a/x/nexus/keeper/address.go b/x/nexus/keeper/address.go index 7eae6fb21..1049a7c98 100644 --- a/x/nexus/keeper/address.go +++ b/x/nexus/keeper/address.go @@ -90,7 +90,7 @@ func (k Keeper) GetRecipient(ctx sdk.Context, depositAddress exported.CrossChain // ValidateAddress validates the given cross chain address func (k Keeper) ValidateAddress(ctx sdk.Context, address exported.CrossChainAddress) error { - validator := k.GetRouter().GetAddressValidator(address.Chain.Module) + validator := k.getAddressValidator().GetAddressValidator(address.Chain.Module) if validator == nil { return fmt.Errorf("unknown module for chain %s", address.Chain.String()) } diff --git a/x/nexus/keeper/general_message.go b/x/nexus/keeper/general_message.go index 1132cf654..4fb486a83 100644 --- a/x/nexus/keeper/general_message.go +++ b/x/nexus/keeper/general_message.go @@ -230,3 +230,19 @@ func (k Keeper) validateAddressAndAsset(ctx sdk.Context, address exported.CrossC return k.validateAsset(ctx, address.Chain, asset.Denom) } + +// RouteMessage routes the given general message to the corresponding module and +// set the message status to processing +func (k Keeper) RouteMessage(ctx sdk.Context, id string, routingCtx ...exported.RoutingContext) error { + err := k.SetMessageProcessing(ctx, id) + if err != nil { + return err + } + + k.Logger(ctx).Debug("set general message status to processing", "messageID", id) + + if len(routingCtx) == 0 { + routingCtx = []exported.RoutingContext{{}} + } + return k.getMessageRouter().Route(ctx, routingCtx[0], funcs.MustOk(k.GetMessage(ctx, id))) +} diff --git a/x/nexus/keeper/genesis_test.go b/x/nexus/keeper/genesis_test.go index 37d986ef0..a905bbf17 100644 --- a/x/nexus/keeper/genesis_test.go +++ b/x/nexus/keeper/genesis_test.go @@ -48,10 +48,10 @@ func setup() (sdk.Context, Keeper) { }, } - router := types.NewRouter() + router := types.NewAddressValidator() router.AddAddressValidator(evmTypes.ModuleName, evmkeeper.NewAddressValidator()). AddAddressValidator(axelarnetTypes.ModuleName, axelarnetkeeper.NewAddressValidator(axelarnetK)) - keeper.SetRouter(router) + keeper.SetAddressValidator(router) return ctx, keeper } diff --git a/x/nexus/keeper/grpc_query_test.go b/x/nexus/keeper/grpc_query_test.go index 01dd80bfb..1f5aea5e0 100644 --- a/x/nexus/keeper/grpc_query_test.go +++ b/x/nexus/keeper/grpc_query_test.go @@ -62,13 +62,13 @@ func TestKeeper_TransfersForChain(t *testing.T) { funcs.MustNoErr(k.RegisterAsset(ctx, evm.Ethereum, exported.NewAsset(axelarnet.NativeAsset, false), utils.MaxUint, time.Hour)) funcs.MustNoErr(k.RegisterAsset(ctx, axelarnet.Axelarnet, exported.NewAsset(axelarnet.NativeAsset, true), utils.MaxUint, time.Hour)) - nexusRouter := types.NewRouter(). + nexusRouter := types.NewAddressValidator(). AddAddressValidator("evm", func(sdk.Context, exported.CrossChainAddress) error { return nil }).AddAddressValidator("axelarnet", func(sdk.Context, exported.CrossChainAddress) error { return nil }) - k.SetRouter(nexusRouter) + k.SetAddressValidator(nexusRouter) }). When("there are some pending transfers", func() { diff --git a/x/nexus/keeper/keeper.go b/x/nexus/keeper/keeper.go index 6ddb1a8d2..ab4160eb7 100644 --- a/x/nexus/keeper/keeper.go +++ b/x/nexus/keeper/keeper.go @@ -41,7 +41,9 @@ type Keeper struct { storeKey sdk.StoreKey cdc codec.BinaryCodec params params.Subspace - router types.Router + + addressValidator types.AddressValidator + messageRouter types.MessageRouter } // NewKeeper returns a new nexus keeper @@ -66,26 +68,45 @@ func (k Keeper) GetParams(ctx sdk.Context) types.Params { return p } -// SetRouter sets the nexus router. It will panic if called more than once -func (k *Keeper) SetRouter(router types.Router) { - if k.router != nil { - panic("router already set") +// SetAddressValidator sets the nexus address validator. It will panic if called more than once +func (k *Keeper) SetAddressValidator(validator types.AddressValidator) { + if k.addressValidator != nil { + panic("validator already set") + } + + k.addressValidator = validator + + // In order to avoid invalid or non-deterministic behavior, we seal the validator immediately + // to prevent additionals handlers from being registered after the keeper is initialized. + k.addressValidator.Seal() +} + +// getAddressValidator returns the nexus address validator. If not set, it returns a sealed empty validator +func (k Keeper) getAddressValidator() types.AddressValidator { + if k.addressValidator == nil { + k.SetAddressValidator(types.NewAddressValidator()) } - k.router = router + return k.addressValidator +} + +func (k *Keeper) SetMessageRouter(router types.MessageRouter) { + if k.messageRouter != nil { + panic("router already set") + } + k.messageRouter = router // In order to avoid invalid or non-deterministic behavior, we seal the router immediately // to prevent additionals handlers from being registered after the keeper is initialized. - k.router.Seal() + k.messageRouter.Seal() } -// GetRouter returns the nexus router. If no router was set, it returns a (sealed) router with no handlers -func (k Keeper) GetRouter() types.Router { - if k.router == nil { - k.SetRouter(types.NewRouter()) +func (k Keeper) getMessageRouter() types.MessageRouter { + if k.messageRouter == nil { + k.SetMessageRouter(types.NewMessageRouter()) } - return k.router + return k.messageRouter } func (k Keeper) getStore(ctx sdk.Context) utils.KVStore { diff --git a/x/nexus/keeper/keeper_test.go b/x/nexus/keeper/keeper_test.go index 2a0c740b5..60e0e3d9a 100644 --- a/x/nexus/keeper/keeper_test.go +++ b/x/nexus/keeper/keeper_test.go @@ -31,7 +31,7 @@ const maxAmount int64 = 100000000000 var k keeper.Keeper -func addressValidator() types.Router { +func addressValidator() types.AddressValidator { axelarnetK := &axelarnetmock.BaseKeeperMock{ GetCosmosChainByNameFunc: func(ctx sdk.Context, chain exported.ChainName) (axelarnetTypes.CosmosChain, bool) { var prefix string @@ -47,7 +47,7 @@ func addressValidator() types.Router { }, } - router := types.NewRouter() + router := types.NewAddressValidator() router.AddAddressValidator(evmTypes.ModuleName, evmkeeper.NewAddressValidator()). AddAddressValidator(axelarnetTypes.ModuleName, axelarnetkeeper.NewAddressValidator(axelarnetK)) @@ -58,7 +58,7 @@ func init() { encCfg := app.MakeEncodingConfig() subspace := params.NewSubspace(encCfg.Codec, encCfg.Amino, sdk.NewKVStoreKey("nexusKey"), sdk.NewKVStoreKey("tNexusKey"), "nexus") k = keeper.NewKeeper(encCfg.Codec, sdk.NewKVStoreKey("nexus"), subspace) - k.SetRouter(addressValidator()) + k.SetAddressValidator(addressValidator()) } func TestLinkAddress(t *testing.T) { diff --git a/x/nexus/keeper/transfer_test.go b/x/nexus/keeper/transfer_test.go index 5aec7567a..c76ebbc43 100644 --- a/x/nexus/keeper/transfer_test.go +++ b/x/nexus/keeper/transfer_test.go @@ -442,7 +442,7 @@ func setup(cfg params.EncodingConfig) (nexusKeeper.Keeper, sdk.Context) { ctx := sdk.NewContext(fake.NewMultiStore(), tmproto.Header{}, false, log.TestingLogger()) k.SetParams(ctx, types.DefaultParams()) - k.SetRouter(addressValidator()) + k.SetAddressValidator(addressValidator()) // register asset in ChainState for _, chain := range chains { diff --git a/x/nexus/types/address_validator.go b/x/nexus/types/address_validator.go new file mode 100644 index 000000000..40467c25b --- /dev/null +++ b/x/nexus/types/address_validator.go @@ -0,0 +1,67 @@ +package types + +import ( + "fmt" + + "github.com/axelarnetwork/axelar-core/x/nexus/exported" +) + +// AddressValidator implements a AddressValidator based on module name. +type AddressValidator interface { + AddAddressValidator(module string, validator exported.AddressValidator) AddressValidator + HasAddressValidator(module string) bool + GetAddressValidator(module string) exported.AddressValidator + Seal() +} + +var _ AddressValidator = (*addressValidator)(nil) + +type addressValidator struct { + validators map[string]exported.AddressValidator + sealed bool +} + +// NewAddressValidator creates a new AddressValidator interface instance +func NewAddressValidator() AddressValidator { + return &addressValidator{ + validators: make(map[string]exported.AddressValidator), + } +} + +// Seal prevents additional validators from being added +func (r *addressValidator) Seal() { + r.sealed = true +} + +// AddAddressValidator registers a validator for a given path +// panics if the validator is sealed, module is an empty string, or if the module has been registered already +func (r *addressValidator) AddAddressValidator(module string, validator exported.AddressValidator) AddressValidator { + if r.sealed { + panic("cannot add validator (validator sealed)") + } + + if module == "" { + panic("module name cannot be an empty string") + } + + if r.HasAddressValidator(module) { + panic(fmt.Sprintf("validator for module %s has already been registered", module)) + } + + r.validators[module] = validator + return r +} + +// HasAddressValidator returns true if a validator is registered for the given module +func (r *addressValidator) HasAddressValidator(module string) bool { + return r.validators[module] != nil +} + +// GetAddressValidator returns a validator for a given module +func (r *addressValidator) GetAddressValidator(module string) exported.AddressValidator { + if !r.HasAddressValidator(module) { + panic(fmt.Sprintf("validator for module \"%s\" not registered", module)) + } + + return r.validators[module] +} diff --git a/x/nexus/types/message_router.go b/x/nexus/types/message_router.go new file mode 100644 index 000000000..0659ff8b0 --- /dev/null +++ b/x/nexus/types/message_router.go @@ -0,0 +1,71 @@ +package types + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + + exported "github.com/axelarnetwork/axelar-core/x/nexus/exported" +) + +// MessageRouter implements a message router based on the message's destination +// chain's module name +type MessageRouter interface { + AddRoute(module string, route exported.MessageRoute) MessageRouter + Route(ctx sdk.Context, routingCtx exported.RoutingContext, msg exported.GeneralMessage) error + Seal() +} + +var _ MessageRouter = (*messageRouter)(nil) + +type messageRouter struct { + routes map[string]exported.MessageRoute + sealed bool +} + +// NewMessageRouter creates a new MessageRouter interface instance +func NewMessageRouter() MessageRouter { + return &messageRouter{ + routes: make(map[string]exported.MessageRoute), + sealed: false, + } +} + +func (r *messageRouter) AddRoute(module string, route exported.MessageRoute) MessageRouter { + if r.sealed { + panic("cannot add route (router sealed)") + } + + if module == "" { + panic("module name cannot be an empty string") + } + + if _, ok := r.routes[module]; ok { + panic(fmt.Sprintf("route for module %s has already been registered", module)) + } + + r.routes[module] = route + + return r +} + +func (r messageRouter) Route(ctx sdk.Context, routingCtx exported.RoutingContext, msg exported.GeneralMessage) error { + if !r.sealed { + panic("cannot route message (router not sealed)") + } + + route, ok := r.routes[msg.Recipient.Chain.Module] + if !ok { + return fmt.Errorf("no route found for module %s", msg.Recipient.Chain.Module) + } + + if routingCtx.Payload != nil && !msg.Match(routingCtx.Payload) { + return fmt.Errorf("payload hash does not match") + } + + return route(ctx, routingCtx, msg) +} + +func (r *messageRouter) Seal() { + r.sealed = true +} diff --git a/x/nexus/types/message_router_test.go b/x/nexus/types/message_router_test.go new file mode 100644 index 000000000..71306085a --- /dev/null +++ b/x/nexus/types/message_router_test.go @@ -0,0 +1,151 @@ +package types_test + +import ( + "fmt" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/assert" + "github.com/tendermint/tendermint/libs/log" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + + "github.com/axelarnetwork/axelar-core/testutils/fake" + "github.com/axelarnetwork/axelar-core/testutils/rand" + "github.com/axelarnetwork/axelar-core/x/nexus/exported" + "github.com/axelarnetwork/axelar-core/x/nexus/types" + . "github.com/axelarnetwork/utils/test" +) + +func TestAddRoute(t *testing.T) { + var ( + router types.MessageRouter + module string + ) + + givenRouter := Given("a message router", func() { + router = types.NewMessageRouter() + }) + + givenRouter. + When("it is sealed", func() { + router.Seal() + module = "module" + }). + Then("it panics when adding a route", func(t *testing.T) { + assert.PanicsWithValue(t, "cannot add route (router sealed)", func() { + router.AddRoute(module, nil) + }) + }). + Run(t) + + givenRouter. + When("module is empty", func() { + module = "" + }). + Then("it panics when adding a route", func(t *testing.T) { + assert.PanicsWithValue(t, "module name cannot be an empty string", func() { + router.AddRoute(module, nil) + }) + }). + Run(t) + + givenRouter. + When("module route is added already", func() { + module = "module" + router.AddRoute(module, func(_ sdk.Context, _ exported.RoutingContext, _ exported.GeneralMessage) error { + return nil + }) + }). + Then("it panics when adding a route again", func(t *testing.T) { + assert.PanicsWithValue(t, fmt.Sprintf("route for module %s has already been registered", module), func() { + router.AddRoute(module, func(_ sdk.Context, _ exported.RoutingContext, _ exported.GeneralMessage) error { + return nil + }) + }) + }). + Run(t) +} + +func TestRoute(t *testing.T) { + var ( + ctx sdk.Context + routingCtx exported.RoutingContext + msg exported.GeneralMessage + router types.MessageRouter + module string + routeCount uint + route exported.MessageRoute + ) + + givenRouter := Given("a message router", func() { + ctx = sdk.NewContext(fake.NewMultiStore(), tmproto.Header{}, false, log.TestingLogger()) + router = types.NewMessageRouter() + }) + + givenRouter. + When("it is not sealed", func() {}). + Then("it panics when routing a message", func(t *testing.T) { + assert.PanicsWithValue(t, "cannot route message (router not sealed)", func() { + router.Route(ctx, exported.RoutingContext{}, exported.GeneralMessage{}) + }) + }). + Run(t) + + whenIsSealed := When("it is sealed", func() { + router.Seal() + }) + + givenRouter. + When2(whenIsSealed). + When("module is not found", func() { + msg = exported.GeneralMessage{Recipient: exported.CrossChainAddress{Chain: exported.Chain{Module: "unknown"}}} + }). + Then("it should return error", func(t *testing.T) { + assert.ErrorContains(t, router.Route(ctx, routingCtx, msg), "no route found") + }). + Run(t) + + givenRouter. + When("route is added", func() { + module = "module" + routeCount = 0 + route = func(_ sdk.Context, _ exported.RoutingContext, msg exported.GeneralMessage) error { + routeCount++ + return nil + } + + router.AddRoute(module, route) + }). + When2(whenIsSealed). + Branch( + When("payload is provided but does not match the payload hash", func() { + routingCtx = exported.RoutingContext{Payload: []byte("payload")} + msg = exported.GeneralMessage{PayloadHash: rand.Bytes(common.HashLength), Recipient: exported.CrossChainAddress{Chain: exported.Chain{Module: module}}} + }). + Then("it should return error", func(t *testing.T) { + assert.ErrorContains(t, router.Route(ctx, routingCtx, msg), "payload hash does not match") + }), + + When("payload is provided and matches the payload hash", func() { + payload := rand.Bytes(100) + routingCtx = exported.RoutingContext{Payload: payload} + msg = exported.GeneralMessage{PayloadHash: crypto.Keccak256Hash(payload).Bytes(), Recipient: exported.CrossChainAddress{Chain: exported.Chain{Module: module}}} + }). + Then("it should succeed", func(t *testing.T) { + assert.NoError(t, router.Route(ctx, routingCtx, msg), "payload hash does not match") + assert.Equal(t, uint(1), routeCount) + }), + + When("payload is not provided", func() { + routingCtx = exported.RoutingContext{Payload: nil} + msg = exported.GeneralMessage{Recipient: exported.CrossChainAddress{Chain: exported.Chain{Module: module}}} + }). + Then("it should succeed", func(t *testing.T) { + assert.NoError(t, router.Route(ctx, routingCtx, msg), "payload hash does not match") + assert.Equal(t, uint(1), routeCount) + }), + ). + Run(t) +} diff --git a/x/nexus/types/router.go b/x/nexus/types/router.go deleted file mode 100644 index 6409e05e8..000000000 --- a/x/nexus/types/router.go +++ /dev/null @@ -1,67 +0,0 @@ -package types - -import ( - "fmt" - - "github.com/axelarnetwork/axelar-core/x/nexus/exported" -) - -// Router implements a AddressValidator router based on module name. -type Router interface { - AddAddressValidator(module string, handler exported.AddressValidator) Router - HasAddressValidator(module string) bool - GetAddressValidator(module string) exported.AddressValidator - Seal() -} - -var _ Router = (*router)(nil) - -type router struct { - routes map[string]exported.AddressValidator - sealed bool -} - -// NewRouter creates a new Router interface instance -func NewRouter() Router { - return &router{ - routes: make(map[string]exported.AddressValidator), - } -} - -// Seal prevents additional route handlers from being added to the router. -func (r *router) Seal() { - r.sealed = true -} - -// AddAddressValidator registers a nexus handler for a given path and returns the handler. -// Panics if the router is sealed, module is an empty string, or if the module has been registered already. -func (r *router) AddAddressValidator(module string, handler exported.AddressValidator) Router { - if r.sealed { - panic("cannot add handler (router sealed)") - } - - if module == "" { - panic("module name cannot be an empty string") - } - - if r.HasAddressValidator(module) { - panic(fmt.Sprintf("handler for module %s has already been registered", module)) - } - - r.routes[module] = handler - return r -} - -// HasAddressValidator returns true if the router has an handler registered for the given module -func (r *router) HasAddressValidator(module string) bool { - return r.routes[module] != nil -} - -// GetAddressValidator returns a Handler for a given module. -func (r *router) GetAddressValidator(module string) exported.AddressValidator { - if !r.HasAddressValidator(module) { - panic(fmt.Sprintf("handler for module \"%s\" not registered", module)) - } - - return r.routes[module] -}