diff --git a/app/wasm.go b/app/wasm.go index dd72239bb..e0a909d80 100644 --- a/app/wasm.go +++ b/app/wasm.go @@ -156,8 +156,8 @@ type QueryRequest struct { } // NewQueryPlugins returns a new instance of the custom query plugins -func NewQueryPlugins(msgIDGenerator nexustypes.MsgIDGenerator) *wasmkeeper.QueryPlugins { - nexusWasmQuerier := nexusKeeper.NewWasmQuerier(msgIDGenerator) +func NewQueryPlugins(nexus nexustypes.Nexus) *wasmkeeper.QueryPlugins { + nexusWasmQuerier := nexusKeeper.NewWasmQuerier(nexus) return &wasmkeeper.QueryPlugins{ Custom: func(ctx sdk.Context, request json.RawMessage) ([]byte, error) { diff --git a/app/wasm_test.go b/app/wasm_test.go index 4477f3142..b18edbd16 100644 --- a/app/wasm_test.go +++ b/app/wasm_test.go @@ -21,6 +21,7 @@ import ( "github.com/axelarnetwork/axelar-core/cmd/axelard/cmd" "github.com/axelarnetwork/axelar-core/testutils/fake" "github.com/axelarnetwork/axelar-core/testutils/rand" + nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" nexusmock "github.com/axelarnetwork/axelar-core/x/nexus/types/mock" "github.com/axelarnetwork/utils/funcs" . "github.com/axelarnetwork/utils/test" @@ -334,21 +335,21 @@ func TestMaxSizeOverrideForClient(t *testing.T) { func TestQueryPlugins(t *testing.T) { var ( - msgIDGenerator *nexusmock.MsgIDGeneratorMock - req json.RawMessage - ctx sdk.Context + nexusK *nexusmock.NexusMock + req json.RawMessage + ctx sdk.Context ) - Given("the tx id generator", func() { + Given("the nexus keeper", func() { ctx = sdk.NewContext(nil, tmproto.Header{}, false, log.TestingLogger()) - msgIDGenerator = &nexusmock.MsgIDGeneratorMock{} + nexusK = &nexusmock.NexusMock{} }). Branch( When("request is invalid", func() { req = []byte("{\"invalid\"}") }). Then("it should return an error", func(t *testing.T) { - _, err := app.NewQueryPlugins(msgIDGenerator).Custom(ctx, req) + _, err := app.NewQueryPlugins(nexusK).Custom(ctx, req) assert.ErrorContains(t, err, "invalid Custom query request") }), @@ -357,7 +358,7 @@ func TestQueryPlugins(t *testing.T) { req = []byte("{\"unknown\":{}}") }). Then("it should return an error", func(t *testing.T) { - _, err := app.NewQueryPlugins(msgIDGenerator).Custom(ctx, req) + _, err := app.NewQueryPlugins(nexusK).Custom(ctx, req) assert.ErrorContains(t, err, "unknown Custom query request") }), @@ -366,7 +367,7 @@ func TestQueryPlugins(t *testing.T) { req = []byte("{\"nexus\":{}}") }). Then("it should return an error", func(t *testing.T) { - _, err := app.NewQueryPlugins(msgIDGenerator).Custom(ctx, req) + _, err := app.NewQueryPlugins(nexusK).Custom(ctx, req) assert.ErrorContains(t, err, "unknown Nexus query request") }), @@ -374,18 +375,30 @@ func TestQueryPlugins(t *testing.T) { When("request is a nexus wasm TxID query", func() { req = []byte("{\"nexus\":{\"tx_hash_and_nonce\":{}}}") }). - Then("it should return an error", func(t *testing.T) { + Then("it should return a TxHashAndNonce response", func(t *testing.T) { txHash := [32]byte(rand.Bytes(32)) index := uint64(rand.PosI64()) - msgIDGenerator.CurrIDFunc = func(ctx sdk.Context) ([32]byte, uint64) { + nexusK.CurrIDFunc = func(ctx sdk.Context) ([32]byte, uint64) { return txHash, index } - actual, err := app.NewQueryPlugins(msgIDGenerator).Custom(ctx, req) + actual, err := app.NewQueryPlugins(nexusK).Custom(ctx, req) assert.NoError(t, err) assert.Equal(t, fmt.Sprintf("{\"tx_hash\":%s,\"nonce\":%d}", funcs.Must(json.Marshal(txHash)), index), string(actual)) }), + When("request is a nexus wasm IsChainRegistered query", func() { + req = []byte("{\"nexus\":{\"is_chain_registered\":{\"chain\": \"chain-0\"}}}") + }). + Then("it should return a chain registered response", func(t *testing.T) { + nexusK.GetChainFunc = func(ctx sdk.Context, chain nexus.ChainName) (nexus.Chain, bool) { + return nexus.Chain{}, true + } + actual, err := app.NewQueryPlugins(nexusK).Custom(ctx, req) + + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("{\"registered\":true}"), string(actual)) + }), ). Run(t) diff --git a/x/nexus/exported/types.go b/x/nexus/exported/types.go index 6a3fa9dc2..6b188ca68 100644 --- a/x/nexus/exported/types.go +++ b/x/nexus/exported/types.go @@ -417,7 +417,8 @@ func (bz *WasmBytes) UnmarshalJSON(data []byte) error { // WasmQueryRequest is the request for wasm contracts to query type WasmQueryRequest struct { - TxHashAndNonce *struct{} `json:"tx_hash_and_nonce,omitempty"` + TxHashAndNonce *struct{} `json:"tx_hash_and_nonce,omitempty"` + IsChainRegistered *IsChainRegistered `json:"is_chain_registered,omitempty"` } // WasmQueryTxHashAndNonceResponse is the response for the TxHashAndNonce query @@ -425,3 +426,11 @@ type WasmQueryTxHashAndNonceResponse struct { TxHash [32]byte `json:"tx_hash,omitempty"` // the hash of the current transaction Nonce uint64 `json:"nonce,omitempty"` // the nonce of the current execution, which increments with each entry of any wasm execution } + +type IsChainRegistered struct { + Chain string `json:"chain"` +} + +type WasmQueryIsChainRegisteredResponse struct { + Registered bool `json:"registered"` +} diff --git a/x/nexus/keeper/wasm_querier.go b/x/nexus/keeper/wasm_querier.go index 9fb785f33..206343c54 100644 --- a/x/nexus/keeper/wasm_querier.go +++ b/x/nexus/keeper/wasm_querier.go @@ -13,24 +13,36 @@ import ( // WasmQuerier is a querier for the wasm contracts type WasmQuerier struct { - msgIDGenerator types.MsgIDGenerator + nexus types.Nexus } // NewWasmQuerier creates a new WasmQuerier -func NewWasmQuerier(msgIDGenerator types.MsgIDGenerator) *WasmQuerier { - return &WasmQuerier{msgIDGenerator} +func NewWasmQuerier(nexus types.Nexus) *WasmQuerier { + return &WasmQuerier{nexus} } // Query handles the wasm queries for the nexus module func (q WasmQuerier) Query(ctx sdk.Context, req exported.WasmQueryRequest) ([]byte, error) { - if req.TxHashAndNonce != nil { - txHash, nonce := q.msgIDGenerator.CurrID(ctx) + switch { + case req.TxHashAndNonce != nil: + txHash, nonce := q.nexus.CurrID(ctx) return funcs.Must(json.Marshal(exported.WasmQueryTxHashAndNonceResponse{ TxHash: txHash, Nonce: nonce, })), nil - } + case req.IsChainRegistered != nil: + chainName := exported.ChainName(req.IsChainRegistered.Chain) + if err := chainName.Validate(); err != nil { + return nil, err + } + + _, registered := q.nexus.GetChain(ctx, chainName) + return funcs.Must(json.Marshal(exported.WasmQueryIsChainRegisteredResponse{ + Registered: registered, + })), nil - return nil, wasmvmtypes.UnsupportedRequest{Kind: "unknown Nexus query request"} + default: + return nil, wasmvmtypes.UnsupportedRequest{Kind: "unknown Nexus query request"} + } } diff --git a/x/nexus/types/expected_keepers.go b/x/nexus/types/expected_keepers.go index b29e3432f..d3bd97825 100644 --- a/x/nexus/types/expected_keepers.go +++ b/x/nexus/types/expected_keepers.go @@ -50,6 +50,7 @@ type Nexus interface { RouteMessage(ctx sdk.Context, id string, routingCtx ...exported.RoutingContext) error DequeueRouteMessage(ctx sdk.Context) (exported.GeneralMessage, bool) IsAssetRegistered(ctx sdk.Context, chain exported.Chain, denom string) bool + CurrID(ctx sdk.Context) ([32]byte, uint64) } // MsgIDGenerator provides functionality to generate msg IDs diff --git a/x/nexus/types/mock/expected_keepers.go b/x/nexus/types/mock/expected_keepers.go index 16fed77aa..141a110db 100644 --- a/x/nexus/types/mock/expected_keepers.go +++ b/x/nexus/types/mock/expected_keepers.go @@ -36,6 +36,9 @@ var _ nexustypes.Nexus = &NexusMock{} // AddChainMaintainerFunc: func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain, validator cosmossdktypes.ValAddress) error { // panic("mock out the AddChainMaintainer method") // }, +// CurrIDFunc: func(ctx cosmossdktypes.Context) ([32]byte, uint64) { +// panic("mock out the CurrID method") +// }, // DeactivateChainFunc: func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain) { // panic("mock out the DeactivateChain method") // }, @@ -130,6 +133,9 @@ type NexusMock struct { // AddChainMaintainerFunc mocks the AddChainMaintainer method. AddChainMaintainerFunc func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain, validator cosmossdktypes.ValAddress) error + // CurrIDFunc mocks the CurrID method. + CurrIDFunc func(ctx cosmossdktypes.Context) ([32]byte, uint64) + // DeactivateChainFunc mocks the DeactivateChain method. DeactivateChainFunc func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain) @@ -231,6 +237,11 @@ type NexusMock struct { // Validator is the validator argument value. Validator cosmossdktypes.ValAddress } + // CurrID holds details about calls to the CurrID method. + CurrID []struct { + // Ctx is the ctx argument value. + Ctx cosmossdktypes.Context + } // DeactivateChain holds details about calls to the DeactivateChain method. DeactivateChain []struct { // Ctx is the ctx argument value. @@ -425,6 +436,7 @@ type NexusMock struct { lockActivateChain sync.RWMutex lockActivateWasmConnection sync.RWMutex lockAddChainMaintainer sync.RWMutex + lockCurrID sync.RWMutex lockDeactivateChain sync.RWMutex lockDeactivateWasmConnection sync.RWMutex lockDequeueRouteMessage sync.RWMutex @@ -561,6 +573,38 @@ func (mock *NexusMock) AddChainMaintainerCalls() []struct { return calls } +// CurrID calls CurrIDFunc. +func (mock *NexusMock) CurrID(ctx cosmossdktypes.Context) ([32]byte, uint64) { + if mock.CurrIDFunc == nil { + panic("NexusMock.CurrIDFunc: method is nil but Nexus.CurrID was just called") + } + callInfo := struct { + Ctx cosmossdktypes.Context + }{ + Ctx: ctx, + } + mock.lockCurrID.Lock() + mock.calls.CurrID = append(mock.calls.CurrID, callInfo) + mock.lockCurrID.Unlock() + return mock.CurrIDFunc(ctx) +} + +// CurrIDCalls gets all the calls that were made to CurrID. +// Check the length with: +// +// len(mockedNexus.CurrIDCalls()) +func (mock *NexusMock) CurrIDCalls() []struct { + Ctx cosmossdktypes.Context +} { + var calls []struct { + Ctx cosmossdktypes.Context + } + mock.lockCurrID.RLock() + calls = mock.calls.CurrID + mock.lockCurrID.RUnlock() + return calls +} + // DeactivateChain calls DeactivateChainFunc. func (mock *NexusMock) DeactivateChain(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.Chain) { if mock.DeactivateChainFunc == nil {