From eb76ef480cf8bf6afe8c0d822104870e20e81e34 Mon Sep 17 00:00:00 2001 From: Borislav Stanimirov Date: Mon, 2 Dec 2024 15:14:47 +0200 Subject: [PATCH] refactor(schema): simplify dispatch, no recusion, ref #149 --- astl/include/astl/tuple_util.hpp | 16 ++++++ ...DummyAInterface.hpp => DummyInterface.hpp} | 8 +-- ...{DummySchema.hpp => DummyLoaderSchema.hpp} | 10 ++-- dummy-plugin/code/ac/dummy/LocalDummy.cpp | 25 +++++----- dummy-plugin/code/ac/dummy/dummy-schema.yml | 49 ------------------- dummy-plugin/test/t-dummy-schema.cpp | 3 +- dummy-plugin/test/t-dummy.inl | 16 +++--- schema/code/ac/schema/DispatchHelpers.hpp | 46 ++++++++++------- 8 files changed, 76 insertions(+), 97 deletions(-) rename dummy-plugin/code/ac/dummy/{DummyAInterface.hpp => DummyInterface.hpp} (83%) rename dummy-plugin/code/ac/dummy/{DummySchema.hpp => DummyLoaderSchema.hpp} (74%) delete mode 100644 dummy-plugin/code/ac/dummy/dummy-schema.yml diff --git a/astl/include/astl/tuple_util.hpp b/astl/include/astl/tuple_util.hpp index 4f961463..a9aa0df7 100644 --- a/astl/include/astl/tuple_util.hpp +++ b/astl/include/astl/tuple_util.hpp @@ -54,4 +54,20 @@ constexpr decltype(auto) switch_index(Tuple& tup, int i, VFunc vf, NFunc nf) { return impl::find_if<0>(tup, qfunc, vf, nf); } +namespace impl { +template +struct expand_for_each { + Func& f; + template + constexpr void operator()(Args&... args) { + (f(args), ...); + } +}; +} // namespace impl + +template +constexpr void for_each(Tuple& tup, Func f) { + std::apply(impl::expand_for_each{f}, tup); +} + } // namespace astl::tuple diff --git a/dummy-plugin/code/ac/dummy/DummyAInterface.hpp b/dummy-plugin/code/ac/dummy/DummyInterface.hpp similarity index 83% rename from dummy-plugin/code/ac/dummy/DummyAInterface.hpp rename to dummy-plugin/code/ac/dummy/DummyInterface.hpp index 95bc10cd..602b4b80 100644 --- a/dummy-plugin/code/ac/dummy/DummyAInterface.hpp +++ b/dummy-plugin/code/ac/dummy/DummyInterface.hpp @@ -9,8 +9,8 @@ namespace ac::local::schema { -struct DummyAInterface { - static constexpr auto id = "dummy-a/v1"; +struct DummyInterface { + static constexpr auto id = "dummy/v1"; struct OpRun { static constexpr auto id = "run"; @@ -18,14 +18,14 @@ struct DummyAInterface { struct Params { Field> input; - Field splice = Default(false); + Field splice = Default(true); Field throwOn = Default(-1); template void visitFields(Visitor& v) { v(input, "input", "Input items"); v(splice, "splice", "Splice input with model data (otherwise append model data to input)"); - v(throwOn, "throwOn", "Throw exception on n-th token (or don't throw if -1)"); + v(throwOn, "throw_on", "Throw exception on n-th token (or don't throw if -1)"); } }; struct Return { diff --git a/dummy-plugin/code/ac/dummy/DummySchema.hpp b/dummy-plugin/code/ac/dummy/DummyLoaderSchema.hpp similarity index 74% rename from dummy-plugin/code/ac/dummy/DummySchema.hpp rename to dummy-plugin/code/ac/dummy/DummyLoaderSchema.hpp index d880b3ef..1dfa637c 100644 --- a/dummy-plugin/code/ac/dummy/DummySchema.hpp +++ b/dummy-plugin/code/ac/dummy/DummyLoaderSchema.hpp @@ -2,12 +2,12 @@ // SPDX-License-Identifier: MIT // #pragma once -#include "DummyAInterface.hpp" +#include "DummyInterface.hpp" #include namespace ac::local::schema { -struct Dummy { +struct DummyLoader { static constexpr auto id = "dummy"; struct Params { Field spliceString = std::nullopt; @@ -22,15 +22,15 @@ struct Dummy { static constexpr auto id = "general"; struct Params { - Field cutoff = Default(0); + Field cutoff = Default(-1); template void visitFields(Visitor& v) { - v(cutoff, "cutoff", "Cutoff value"); + v(cutoff, "cutoff", "Cut off model data to n-th element (or don't cut if -1)"); } }; - using Interfaces = std::tuple; + using Interfaces = std::tuple; }; using Instances = std::tuple; diff --git a/dummy-plugin/code/ac/dummy/LocalDummy.cpp b/dummy-plugin/code/ac/dummy/LocalDummy.cpp index b3b484f1..772039a9 100644 --- a/dummy-plugin/code/ac/dummy/LocalDummy.cpp +++ b/dummy-plugin/code/ac/dummy/LocalDummy.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT // #include "LocalDummy.hpp" -#include "DummySchema.hpp" +#include "DummyLoaderSchema.hpp" #include "Instance.hpp" #include "Model.hpp" @@ -27,8 +27,9 @@ namespace { class DummyInstance final : public Instance { std::shared_ptr m_model; dummy::Instance m_instance; + schema::OpDispatcherData m_dispatcherData; public: - using Schema = schema::Dummy::InstanceGeneral; + using Schema = schema::DummyLoader::InstanceGeneral; static dummy::Instance::InitParams InitParams_fromDict(Dict&& d) { auto schemaParams = schema::Struct_fromDict(astl::move(d)); @@ -40,16 +41,18 @@ class DummyInstance final : public Instance { DummyInstance(std::shared_ptr model, Dict&& params) : m_model(astl::move(model)) , m_instance(*m_model, InitParams_fromDict(astl::move(params))) - {} + { + schema::registerHandlers(m_dispatcherData, *this); + } - schema::DummyAInterface::OpRun::Return on(schema::DummyAInterface::OpRun, schema::DummyAInterface::OpRun::Params params) { + schema::DummyInterface::OpRun::Return on(schema::DummyInterface::OpRun, schema::DummyInterface::OpRun::Params params) { dummy::Instance::SessionParams sparams; sparams.splice = params.splice; sparams.throwOn = params.throwOn; auto s = m_instance.newSession(std::move(params.input), sparams); - schema::DummyAInterface::OpRun::Return ret; + schema::DummyInterface::OpRun::Return ret; auto& res = ret.result.materialize(); for (auto& w : s) { res += w; @@ -63,19 +66,19 @@ class DummyInstance final : public Instance { return ret; } - Dict onNoOp(std::string_view op, Dict) { - throw_ex{} << "dummy: unknown operation: " << op; - } - virtual Dict runOp(std::string_view op, Dict params, ProgressCb) override { - return schema::dispatchOp(op, std::move(params), *this); + auto ret = m_dispatcherData.dispatch(op, astl::move(params)); + if (!ret) { + throw_ex{} << "dummy: unknown op: " << op; + } + return *ret; } }; class DummyModel final : public Model { std::shared_ptr m_model; public: - using Schema = schema::Dummy; + using Schema = schema::DummyLoader; static dummy::Model::Params ModelParams_fromDict(Dict& d) { auto schemaParams = schema::Struct_fromDict(std::move(d)); diff --git a/dummy-plugin/code/ac/dummy/dummy-schema.yml b/dummy-plugin/code/ac/dummy/dummy-schema.yml deleted file mode 100644 index 8f9b76da..00000000 --- a/dummy-plugin/code/ac/dummy/dummy-schema.yml +++ /dev/null @@ -1,49 +0,0 @@ ---- -id: dummy -description: Dummy inference for tests, examples, and experiments. -params: - type: object - properties: - splice_string: - description: String to splice model data with input - type: string -instances: - general: - description: General instance - params: - type: object - properties: - cutoff: - description: Cut off model data to n-th element (or don't cut if -1) - type: integer - default: -1 - ops: - run: - description: Run the dummy inference and produce some output - params: - type: object - properties: - input: - description: Input items - type: array - items: - type: string - splice: - description: Splice input with model data (otherwise append model data - to input) - type: boolean - default: true - throw_on: - description: Throw exception on n-th token (or don't throw if -1) - type: integer - default: -1 - required: - - input - return: - type: object - properties: - result: - description: Output text (tokens joined with space) - type: string - required: - - result diff --git a/dummy-plugin/test/t-dummy-schema.cpp b/dummy-plugin/test/t-dummy-schema.cpp index 49dc0fe8..64825e07 100644 --- a/dummy-plugin/test/t-dummy-schema.cpp +++ b/dummy-plugin/test/t-dummy-schema.cpp @@ -6,13 +6,12 @@ #include #include -#include +#include #include #include -#include #include #include diff --git a/dummy-plugin/test/t-dummy.inl b/dummy-plugin/test/t-dummy.inl index d689b247..c00bbd2b 100644 --- a/dummy-plugin/test/t-dummy.inl +++ b/dummy-plugin/test/t-dummy.inl @@ -39,7 +39,7 @@ TEST_CASE("bad instance") { auto model = f.loadModel(Model_Desc, {}); REQUIRE(model); CHECK_THROWS_WITH(model->createInstance("nope", {}), "dummy: unknown instance type: nope"); - CHECK_THROWS_WITH(model->createInstance("general", { {"cutoff", 40} }), + CHECK_THROWS_WITH(model->createInstance("general", {{"cutoff", 40}}), "Cutoff 40 greater than model size 3"); } @@ -53,20 +53,20 @@ TEST_CASE("general") { CHECK_THROWS_WITH(i->runOp("nope", {}), "dummy: unknown op: nope"); - CHECK_THROWS_WITH(i->runOp("run", { {"foo", "nope"} }), "Missing required field input"); + CHECK_THROWS_WITH(i->runOp("run", {{"foo", "nope"}}), "Required field input is not set"); - auto res = i->runOp("run", { {"input", {"a", "b"}} }); + auto res = i->runOp("run", {{"input", {"a", "b"}}}); CHECK(res.at("result").get() == "a soco b bate"); - res = i->runOp("run", { {"input", {"a", "b"}}, {"splice", false} }); + res = i->runOp("run", {{"input", {"a", "b"}}, {"splice", false}}); CHECK(res.at("result").get() == "a b soco bate vira"); - CHECK_THROWS_WITH(i->runOp("run", { {"input", {"a", "b"}}, {"throw_on", 3} }), "Throw on token 3"); + CHECK_THROWS_WITH(i->runOp("run", {{"input", {"a", "b"}}, {"throw_on", 3}}), "Throw on token 3"); - auto ci = model->createInstance("general", { {"cutoff", 2} }); + auto ci = model->createInstance("general", {{"cutoff", 2}}); REQUIRE(ci); - res = ci->runOp("run", { {"input", {"a", "b", "c"}} }); + res = ci->runOp("run", {{"input", {"a", "b", "c"}}}); CHECK(res.at("result").get() == "a soco b bate c soco"); } @@ -89,6 +89,6 @@ TEST_CASE("synthetic") { auto instance = model->createInstance("general", {}); - auto res = instance->runOp("run", { {"input", {"a", "b"}} }); + auto res = instance->runOp("run", {{"input", {"a", "b"}}}); CHECK(res.at("result").get() == "a one b two"); } diff --git a/schema/code/ac/schema/DispatchHelpers.hpp b/schema/code/ac/schema/DispatchHelpers.hpp index 56d7d8b1..3c6b705d 100644 --- a/schema/code/ac/schema/DispatchHelpers.hpp +++ b/schema/code/ac/schema/DispatchHelpers.hpp @@ -5,30 +5,40 @@ #include "IOVisitors.hpp" #include #include +#include namespace ac::local::schema { -namespace impl { -struct FindById { - std::string_view id; - template - bool operator()(int, const T& elem) const { - return elem.id == id; +struct OpDispatcherData { + struct HandlerData { + std::string_view id; + std::function handler; + }; + std::vector handlers; + + template + void registerHandler(Op, Handler& h) { + auto stronglyTypedCall = [&](Dict&& params) { + return Struct_toDict(h.on(Op{}, Struct_fromDict(std::move(params)))); + }; + handlers.push_back({Op::id, std::move(stronglyTypedCall)}); + } + + std::optional dispatch(std::string_view id, Dict&& params) { + auto h = astl::pfind_if(handlers, [&](const HandlerData& item) { + return item.id == id; + }); + if (!h) return {}; + return h->handler(std::move(params)); } }; -} // namespace impl -template -Dict dispatchOp(std::string_view opId, Dict&& opParams, Dispatcher& dispatcher) { - Ops ops; - return astl::tuple::find_if(ops, impl::FindById{opId}, - [&](Op& op) { - return Struct_toDict(dispatcher.on(op, Struct_fromDict(std::move(opParams)))); - }, - [&] { - return dispatcher.onNoOp(opId, opParams); - } - ); +template +void registerHandlers(OpDispatcherData& data, Handler& h) { + Ops ops{}; + astl::tuple::for_each(ops, [&](Op op) { + data.registerHandler(op, h); + }); } } // namespace ac::local::schema