Skip to content

Commit

Permalink
add factory function for blackboxfunc
Browse files Browse the repository at this point in the history
  • Loading branch information
baperry2 committed May 28, 2024
1 parent 0846e68 commit 28faff8
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions Source/Utility/BlackBoxFunc/BlackBoxFuncFactory.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#ifndef BLACK_BOX_FUNC_FACTORY_H
#define BLACK_BOX_FUNC_FACTORY_H

#include "Table.H"
#include "NeuralNetHomerolled.H"

namespace pele {
namespace physics {

template <typename FuncType>
struct BlackBoxFuncFactory
{
};

template <unsigned int Dimension>
struct BlackBoxFuncFactory<TabFunc<Dimension>>
{

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
BlackBoxFuncFactory() {}

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
BlackBoxFuncFactory(const ManFuncData* mf_data)
: func(static_cast<const TabFuncData*>(mf_data))
{
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(
mf_data->manmodel == ManifoldModel::TABLE,
"Runtime Table/Network must match what you compiled with");
}

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
TabFunc<Dimension>* get_func() { return &func; }

private:
TabFunc<Dimension> func;
};

template <>
struct BlackBoxFuncFactory<NNFunc>
{

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
BlackBoxFuncFactory() {}

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
BlackBoxFuncFactory(const ManFuncData* mf_data)
: func(static_cast<const NNFuncData*>(mf_data))
{
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(
mf_data->manmodel == ManifoldModel::NEURAL_NET,
"Runtime Table/Network must match what you compiled with");
}

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
NNFunc* get_func() { return &func; }

private:
NNFunc func;
};

template <>
struct BlackBoxFuncFactory<ManifoldFunc>
{

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
BlackBoxFuncFactory() {}

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
BlackBoxFuncFactory(const ManFuncData* mf_data)
{
if (mf_data->manmodel == ManifoldModel::TABLE) {
const TabFuncData* tf_data = static_cast<const TabFuncData*>(mf_data);
func = new pele::physics::TabFunc<>(tf_data);
} else if (mf_data->manmodel == ManifoldModel::NEURAL_NET) {
const NNFuncData* nnf_data = static_cast<const NNFuncData*>(mf_data);
func = new pele::physics::NNFunc(nnf_data);
} else {
amrex::Abort("invalid black box function type requested");
}
}

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
~BlackBoxFuncFactory() { delete func; }

AMREX_GPU_HOST_DEVICE
AMREX_FORCE_INLINE
ManifoldFunc* get_func() { return func; }

private:
ManifoldFunc* func;
};
} // namespace physics
} // namespace pele
#endif

0 comments on commit 28faff8

Please sign in to comment.