Skip to content

Commit

Permalink
Ensure MathsProvider template arguments are forwarded through ModelT::parseJson
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Nov 4, 2024
1 parent 32b8664 commit e0e2037
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 6 deletions.
8 changes: 4 additions & 4 deletions RTNeural/ModelT.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ namespace modelt_detail
}
}

template <typename T, int in_size, int out_size, SampleRateCorrectionMode mode>
void loadLayer(GRULayerT<T, in_size, out_size, mode>& gru, int& json_stream_idx, const nlohmann::json& l,
template <typename T, int in_size, int out_size, SampleRateCorrectionMode mode, typename MathsProvider>
void loadLayer(GRULayerT<T, in_size, out_size, mode, MathsProvider>& gru, int& json_stream_idx, const nlohmann::json& l,
const std::string& type, int layerDims, bool debug)
{
using namespace json_parser;
Expand All @@ -179,8 +179,8 @@ namespace modelt_detail
json_stream_idx++;
}

template <typename T, int in_size, int out_size, SampleRateCorrectionMode mode>
void loadLayer(LSTMLayerT<T, in_size, out_size, mode>& lstm, int& json_stream_idx, const nlohmann::json& l,
template <typename T, int in_size, int out_size, SampleRateCorrectionMode mode, typename MathsProvider>
void loadLayer(LSTMLayerT<T, in_size, out_size, mode, MathsProvider>& lstm, int& json_stream_idx, const nlohmann::json& l,
const std::string& type, int layerDims, bool debug)
{
using namespace json_parser;
Expand Down
34 changes: 33 additions & 1 deletion tests/functional/model_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <gmock/gmock.h>

#include "load_csv.hpp"
#include <RTNeural/RTNeural.h>
#include "test_maths_provider.hpp"

using namespace testing;

Expand Down Expand Up @@ -49,6 +49,21 @@ auto loadTemplatedModel()
modelT.parseJson(jsonStream, true);
return modelT;
}

auto loadTemplatedModelWithMathsProvider()
{
auto modelT = RTNeural::ModelT<TestType, 1, 1,
RTNeural::DenseT<TestType, 1, 8>,
RTNeural::TanhActivationT<TestType, 8, TestMathsProvider>,
RTNeural::Conv1DT<TestType, 8, 4, 3, 2>,
RTNeural::TanhActivationT<TestType, 4, TestMathsProvider>,
RTNeural::GRULayerT<TestType, 4, 8, RTNeural::SampleRateCorrectionMode::None, TestMathsProvider>,
RTNeural::DenseT<TestType, 8, 1>> {};

std::ifstream jsonStream(model_file, std::ifstream::binary);
modelT.parseJson(jsonStream, true);
return modelT;
}
}

TEST(TestModel, templateModelOutputMatchesDynamicModel)
Expand All @@ -67,3 +82,20 @@ TEST(TestModel, templateModelOutputMatchesDynamicModel)

EXPECT_THAT(yData, Pointwise(DoubleNear(threshold), yRefData));
}

TEST(TestModel, templateModelWithMathsProviderOutputMatchesDynamicModel)
{
    // Maximum allowed per-sample deviation between the two model outputs.
    constexpr double tolerance = 1.0e-12;

    auto input = loadInputData();
    auto expected = std::vector<TestType>(input.size(), TestType { 0 });
    auto actual = std::vector<TestType>(input.size(), TestType { 0 });

    // Reference output: the dynamically-constructed model.
    auto dynamicModel = loadDynamicModel();
    processModel(*dynamicModel, input, expected);

    // Output under test: the templated model with a custom maths provider.
    auto templatedModel = loadTemplatedModelWithMathsProvider();
    processModel(templatedModel, input, actual);

    EXPECT_THAT(actual, Pointwise(DoubleNear(tolerance), expected));
}
63 changes: 62 additions & 1 deletion tests/functional/templated_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include "load_csv.hpp"
#include "test_configs.hpp"
#include <RTNeural/RTNeural.h>
#include "test_maths_provider.hpp"

namespace
{
Expand Down Expand Up @@ -58,6 +58,22 @@ TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForDense)
runTestTemplated<TestType, ModelType>(tests.at("dense"));
}

TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForDenseWithMathsProvider)
{
    // Dense network with TestMathsProvider supplied to the activation
    // layers that take a maths-provider template argument.
    using DenseModel = ModelT<TestType, 1, 1,
        DenseT<TestType, 1, 8>,
        TanhActivationT<TestType, 8, TestMathsProvider>,
        DenseT<TestType, 8, 8>,
        ReLuActivationT<TestType, 8>,
        DenseT<TestType, 8, 8>,
        ELuActivationT<TestType, 8, 1, 1, TestMathsProvider>,
        DenseT<TestType, 8, 8>,
        SoftmaxActivationT<TestType, 8, TestMathsProvider>,
        DenseT<TestType, 8, 1>>;

    runTestTemplated<TestType, DenseModel>(tests.at("dense"));
}

TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForConv1D)
{
using ModelType = ModelT<TestType, 1, 1,
Expand Down Expand Up @@ -92,6 +108,19 @@ TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForGRU)
runTestTemplated<TestType, ModelType>(tests.at("gru"));
}

TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForGRUWithMathsProvider)
{
    // GRU network with TestMathsProvider passed to the recurrent layer
    // and to the maths-provider-aware activations.
    using GRUModel = ModelT<TestType, 1, 1,
        DenseT<TestType, 1, 8>,
        TanhActivationT<TestType, 8, TestMathsProvider>,
        GRULayerT<TestType, 8, 8, RTNeural::SampleRateCorrectionMode::None, TestMathsProvider>,
        DenseT<TestType, 8, 8>,
        SigmoidActivationT<TestType, 8, TestMathsProvider>,
        DenseT<TestType, 8, 1>>;

    runTestTemplated<TestType, GRUModel>(tests.at("gru"));
}

TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForGRU1D)
{
using ModelType = ModelT<TestType, 1, 1,
Expand All @@ -103,6 +132,18 @@ TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForGRU1D)
runTestTemplated<TestType, ModelType>(tests.at("gru_1d"));
}


TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForGRU1DWithMathsProvider)
{
    // Single-input GRU network using TestMathsProvider throughout.
    using GRU1DModel = ModelT<TestType, 1, 1,
        GRULayerT<TestType, 1, 8, RTNeural::SampleRateCorrectionMode::None, TestMathsProvider>,
        DenseT<TestType, 8, 8>,
        SigmoidActivationT<TestType, 8, TestMathsProvider>,
        DenseT<TestType, 8, 1>>;

    runTestTemplated<TestType, GRU1DModel>(tests.at("gru_1d"));
}

TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForLSTM)
{
using ModelType = ModelT<TestType, 1, 1,
Expand All @@ -114,6 +155,17 @@ TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForLSTM)
runTestTemplated<TestType, ModelType>(tests.at("lstm"));
}

TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForLSTMWithMathsProvider)
{
    // LSTM network with TestMathsProvider passed to the recurrent layer
    // and the tanh activation.
    using LSTMModel = ModelT<TestType, 1, 1,
        DenseT<TestType, 1, 8>,
        TanhActivationT<TestType, 8, TestMathsProvider>,
        LSTMLayerT<TestType, 8, 8, RTNeural::SampleRateCorrectionMode::None, TestMathsProvider>,
        DenseT<TestType, 8, 1>>;

    runTestTemplated<TestType, LSTMModel>(tests.at("lstm"));
}

TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForLSTM1D)
{
using ModelType = ModelT<TestType, 1, 1,
Expand All @@ -122,3 +174,12 @@ TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForLSTM1D)

runTestTemplated<TestType, ModelType>(tests.at("lstm_1d"));
}

TEST(TestTemplatedModels, modelOutputMatchesPythonImplementationForLSTM1DWithMathsProvider)
{
    // Minimal single-input LSTM network with a custom maths provider.
    using LSTM1DModel = ModelT<TestType, 1, 1,
        LSTMLayerT<TestType, 1, 8, RTNeural::SampleRateCorrectionMode::None, TestMathsProvider>,
        DenseT<TestType, 8, 1>>;

    runTestTemplated<TestType, LSTM1DModel>(tests.at("lstm_1d"));
}
35 changes: 35 additions & 0 deletions tests/functional/test_maths_provider.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include <RTNeural/RTNeural.h>

// Backend-specific maths provider used by the functional tests to check
// that custom MathsProvider template arguments are honoured by the
// templated layers. One implementation is selected per RTNeural backend.
#if RTNEURAL_USE_XSIMD
struct TestMathsProvider
{
    // The unqualified calls below resolve through the block-scope
    // using-declarations: std:: for scalar types, xsimd:: for SIMD batches.
    template <typename T>
    static T tanh(T x)
    {
        using std::tanh;
        using xsimd::tanh;
        return tanh(x);
    }

    template <typename T>
    static T sigmoid(T x)
    {
        using std::exp;
        using xsimd::exp;
        return (T)1 / ((T)1 + exp(-x));
    }

    template <typename T>
    static T exp(T x)
    {
        using std::exp;
        using xsimd::exp;
        return exp(x);
    }
};
#elif RTNEURAL_USE_EIGEN
struct TestMathsProvider
{
    // Eigen backend: operate element-wise on array expressions.
    template <typename Matrix>
    static auto tanh(const Matrix& x)
    {
        return x.array().tanh();
    }

    template <typename Matrix>
    static auto sigmoid(const Matrix& x)
    {
        using T = typename Matrix::Scalar;
        return (T)1 / (((T)-1 * x.array()).array().exp() + (T)1);
    }

    template <typename Matrix>
    static auto exp(const Matrix& x)
    {
        return x.array().exp();
    }
};
#else // scalar (STL) backend
struct TestMathsProvider
{
    template <typename T>
    static T tanh(T x)
    {
        return std::tanh(x);
    }

    template <typename T>
    static T sigmoid(T x)
    {
        return (T)1 / ((T)1 + std::exp(-x));
    }

    template <typename T>
    static T exp(T x)
    {
        return std::exp(x);
    }
};
#endif

0 comments on commit e0e2037

Please sign in to comment.