Skip to content

Commit

Permalink
refactor(schema): simplify dispatch, no recusion, ref #149
Browse files Browse the repository at this point in the history
  • Loading branch information
iboB committed Dec 2, 2024
1 parent 88e1680 commit eb76ef4
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 97 deletions.
16 changes: 16 additions & 0 deletions astl/include/astl/tuple_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,20 @@ constexpr decltype(auto) switch_index(Tuple& tup, int i, VFunc vf, NFunc nf) {
return impl::find_if<0>(tup, qfunc, vf, nf);
}

namespace impl {
template <typename Func>
struct expand_for_each {
Func& f;
template <typename... Args>
constexpr void operator()(Args&... args) {
(f(args), ...);
}
};
} // namespace impl

template <typename Tuple, typename Func>
constexpr void for_each(Tuple& tup, Func f) {
std::apply(impl::expand_for_each{f}, tup);
}

} // namespace astl::tuple
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,23 @@

namespace ac::local::schema {

struct DummyAInterface {
static constexpr auto id = "dummy-a/v1";
struct DummyInterface {
static constexpr auto id = "dummy/v1";

struct OpRun {
static constexpr auto id = "run";
static constexpr auto desc = "Run the dummy inference and produce some output";

struct Params {
Field<std::vector<std::string>> input;
Field<bool> splice = Default(false);
Field<bool> splice = Default(true);
Field<int> throwOn = Default(-1);

template <typename Visitor>
void visitFields(Visitor& v) {
v(input, "input", "Input items");
v(splice, "splice", "Splice input with model data (otherwise append model data to input)");
v(throwOn, "throwOn", "Throw exception on n-th token (or don't throw if -1)");
v(throwOn, "throw_on", "Throw exception on n-th token (or don't throw if -1)");
}
};
struct Return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
// SPDX-License-Identifier: MIT
//
#pragma once
#include "DummyAInterface.hpp"
#include "DummyInterface.hpp"
#include <tuple>

namespace ac::local::schema {

struct Dummy {
struct DummyLoader {
static constexpr auto id = "dummy";
struct Params {
Field<std::string> spliceString = std::nullopt;
Expand All @@ -22,15 +22,15 @@ struct Dummy {
static constexpr auto id = "general";

struct Params {
Field<int> cutoff = Default(0);
Field<int> cutoff = Default(-1);

template <typename Visitor>
void visitFields(Visitor& v) {
v(cutoff, "cutoff", "Cutoff value");
v(cutoff, "cutoff", "Cut off model data to n-th element (or don't cut if -1)");
}
};

using Interfaces = std::tuple<DummyAInterface>;
using Interfaces = std::tuple<DummyInterface>;
};

using Instances = std::tuple<InstanceGeneral>;
Expand Down
25 changes: 14 additions & 11 deletions dummy-plugin/code/ac/dummy/LocalDummy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: MIT
//
#include "LocalDummy.hpp"
#include "DummySchema.hpp"
#include "DummyLoaderSchema.hpp"

#include "Instance.hpp"
#include "Model.hpp"
Expand All @@ -27,8 +27,9 @@ namespace {
class DummyInstance final : public Instance {
std::shared_ptr<dummy::Model> m_model;
dummy::Instance m_instance;
schema::OpDispatcherData m_dispatcherData;
public:
using Schema = schema::Dummy::InstanceGeneral;
using Schema = schema::DummyLoader::InstanceGeneral;

static dummy::Instance::InitParams InitParams_fromDict(Dict&& d) {
auto schemaParams = schema::Struct_fromDict<Schema::Params>(astl::move(d));
Expand All @@ -40,16 +41,18 @@ class DummyInstance final : public Instance {
DummyInstance(std::shared_ptr<dummy::Model> model, Dict&& params)
: m_model(astl::move(model))
, m_instance(*m_model, InitParams_fromDict(astl::move(params)))
{}
{
schema::registerHandlers<schema::DummyInterface::Ops>(m_dispatcherData, *this);
}

schema::DummyAInterface::OpRun::Return on(schema::DummyAInterface::OpRun, schema::DummyAInterface::OpRun::Params params) {
schema::DummyInterface::OpRun::Return on(schema::DummyInterface::OpRun, schema::DummyInterface::OpRun::Params params) {
dummy::Instance::SessionParams sparams;
sparams.splice = params.splice;
sparams.throwOn = params.throwOn;

auto s = m_instance.newSession(std::move(params.input), sparams);

schema::DummyAInterface::OpRun::Return ret;
schema::DummyInterface::OpRun::Return ret;
auto& res = ret.result.materialize();
for (auto& w : s) {
res += w;
Expand All @@ -63,19 +66,19 @@ class DummyInstance final : public Instance {
return ret;
}

Dict onNoOp(std::string_view op, Dict) {
throw_ex{} << "dummy: unknown operation: " << op;
}

virtual Dict runOp(std::string_view op, Dict params, ProgressCb) override {
return schema::dispatchOp<schema::DummyAInterface::Ops>(op, std::move(params), *this);
auto ret = m_dispatcherData.dispatch(op, astl::move(params));
if (!ret) {
throw_ex{} << "dummy: unknown op: " << op;
}
return *ret;
}
};

class DummyModel final : public Model {
std::shared_ptr<dummy::Model> m_model;
public:
using Schema = schema::Dummy;
using Schema = schema::DummyLoader;

static dummy::Model::Params ModelParams_fromDict(Dict& d) {
auto schemaParams = schema::Struct_fromDict<Schema::Params>(std::move(d));
Expand Down
49 changes: 0 additions & 49 deletions dummy-plugin/code/ac/dummy/dummy-schema.yml

This file was deleted.

3 changes: 1 addition & 2 deletions dummy-plugin/test/t-dummy-schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
#include <ac/local/Instance.hpp>
#include <ac/local/ModelLoaderRegistry.hpp>

#include <ac/schema/Helpers.hpp>
#include <ac/schema/CallHelpers.hpp>

#include <ac-test-util/JalogFixture.inl>

#include <doctest/doctest.h>

#include <dummy-schema.hpp>
#include <aclp-dummy-plib.hpp>
#include <ac-test-data-dummy-models.h>

Expand Down
16 changes: 8 additions & 8 deletions dummy-plugin/test/t-dummy.inl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TEST_CASE("bad instance") {
auto model = f.loadModel(Model_Desc, {});
REQUIRE(model);
CHECK_THROWS_WITH(model->createInstance("nope", {}), "dummy: unknown instance type: nope");
CHECK_THROWS_WITH(model->createInstance("general", { {"cutoff", 40} }),
CHECK_THROWS_WITH(model->createInstance("general", {{"cutoff", 40}}),
"Cutoff 40 greater than model size 3");
}

Expand All @@ -53,20 +53,20 @@ TEST_CASE("general") {

CHECK_THROWS_WITH(i->runOp("nope", {}), "dummy: unknown op: nope");

CHECK_THROWS_WITH(i->runOp("run", { {"foo", "nope"} }), "Missing required field input");
CHECK_THROWS_WITH(i->runOp("run", {{"foo", "nope"}}), "Required field input is not set");

auto res = i->runOp("run", { {"input", {"a", "b"}} });
auto res = i->runOp("run", {{"input", {"a", "b"}}});
CHECK(res.at("result").get<std::string>() == "a soco b bate");

res = i->runOp("run", { {"input", {"a", "b"}}, {"splice", false} });
res = i->runOp("run", {{"input", {"a", "b"}}, {"splice", false}});
CHECK(res.at("result").get<std::string>() == "a b soco bate vira");

CHECK_THROWS_WITH(i->runOp("run", { {"input", {"a", "b"}}, {"throw_on", 3} }), "Throw on token 3");
CHECK_THROWS_WITH(i->runOp("run", {{"input", {"a", "b"}}, {"throw_on", 3}}), "Throw on token 3");

auto ci = model->createInstance("general", { {"cutoff", 2} });
auto ci = model->createInstance("general", {{"cutoff", 2}});
REQUIRE(ci);

res = ci->runOp("run", { {"input", {"a", "b", "c"}} });
res = ci->runOp("run", {{"input", {"a", "b", "c"}}});
CHECK(res.at("result").get<std::string>() == "a soco b bate c soco");
}

Expand All @@ -89,6 +89,6 @@ TEST_CASE("synthetic") {

auto instance = model->createInstance("general", {});

auto res = instance->runOp("run", { {"input", {"a", "b"}} });
auto res = instance->runOp("run", {{"input", {"a", "b"}}});
CHECK(res.at("result").get<std::string>() == "a one b two");
}
46 changes: 28 additions & 18 deletions schema/code/ac/schema/DispatchHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,40 @@
#include "IOVisitors.hpp"
#include <ac/DictFwd.hpp>
#include <astl/tuple_util.hpp>
#include <astl/qalgorithm.hpp>

namespace ac::local::schema {

namespace impl {
struct FindById {
std::string_view id;
template <typename T>
bool operator()(int, const T& elem) const {
return elem.id == id;
struct OpDispatcherData {
struct HandlerData {
std::string_view id;
std::function<Dict(Dict&&)> handler;
};
std::vector<HandlerData> handlers;

template <typename Op, typename Handler>
void registerHandler(Op, Handler& h) {
auto stronglyTypedCall = [&](Dict&& params) {
return Struct_toDict(h.on(Op{}, Struct_fromDict<typename Op::Params>(std::move(params))));
};
handlers.push_back({Op::id, std::move(stronglyTypedCall)});
}

std::optional<Dict> dispatch(std::string_view id, Dict&& params) {
auto h = astl::pfind_if(handlers, [&](const HandlerData& item) {
return item.id == id;
});
if (!h) return {};
return h->handler(std::move(params));
}
};
} // namespace impl

template <typename Ops, typename Dispatcher>
Dict dispatchOp(std::string_view opId, Dict&& opParams, Dispatcher& dispatcher) {
Ops ops;
return astl::tuple::find_if(ops, impl::FindById{opId},
[&]<typename Op>(Op& op) {
return Struct_toDict(dispatcher.on(op, Struct_fromDict<typename Op::Params>(std::move(opParams))));
},
[&] {
return dispatcher.onNoOp(opId, opParams);
}
);
template <typename Ops, typename Handler>
void registerHandlers(OpDispatcherData& data, Handler& h) {
Ops ops{};
astl::tuple::for_each(ops, [&]<typename Op>(Op op) {
data.registerHandler(op, h);
});
}

} // namespace ac::local::schema

0 comments on commit eb76ef4

Please sign in to comment.