-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Parametric DMD for heat conduction example (#6)
* StopWatch python class added in python_utils module. * PointwiseSnapshot python class. * templetized pyDatabase. * bindings for CSV/HDFDatabase. need to finish all member functions. * binding of some member functions needed for parametric dmd. * templatized getVectorPointer. putDoubleArray uses numpy array. * ParametricDMD bindings. * enforce noconvert() for DMD.train() arguments. * templatized getParametricDMD. * parametric dmd for heat conduction example. * separated StopWatch class. * updated example command lines.
- Loading branch information
1 parent
17052e9
commit b9581d4
Showing
14 changed files
with
1,498 additions
and
199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// | ||
// Created by barrow9 on 6/4/23. | ||
// | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/operators.h> | ||
#include <pybind11/stl.h> | ||
#include "algo/DMD.h" | ||
#include "algo/ParametricDMD.h" | ||
#include "linalg/Vector.h" | ||
#include "python_utils/cpp_utils.hpp" | ||
|
||
namespace py = pybind11; | ||
using namespace CAROM; | ||
using namespace std; | ||
|
||
template <class T> | ||
T* getParametricDMD_Type( | ||
std::vector<Vector*>& parameter_points, | ||
std::vector<T*>& dmds, | ||
Vector* desired_point, | ||
std::string rbf, | ||
std::string interp_method, | ||
double closest_rbf_val, | ||
bool reorthogonalize_W) | ||
{ | ||
T *parametric_dmd = NULL; | ||
getParametricDMD(parametric_dmd, parameter_points, dmds, desired_point, | ||
rbf, interp_method, closest_rbf_val, reorthogonalize_W); | ||
return parametric_dmd; | ||
} | ||
|
||
template <class T> | ||
T* getParametricDMD_Type( | ||
std::vector<Vector*>& parameter_points, | ||
std::vector<std::string>& dmd_paths, | ||
Vector* desired_point, | ||
std::string rbf = "G", | ||
std::string interp_method = "LS", | ||
double closest_rbf_val = 0.9, | ||
bool reorthogonalize_W = false) | ||
{ | ||
T *parametric_dmd = NULL; | ||
getParametricDMD(parametric_dmd, parameter_points, dmd_paths, desired_point, | ||
rbf, interp_method, closest_rbf_val, reorthogonalize_W); | ||
return parametric_dmd; | ||
} | ||
|
||
void init_ParametricDMD(pybind11::module_ &m) { | ||
|
||
// original getParametricDMD pass a template pointer reference T* ¶metric_dmd. | ||
// While it is impossible to bind a template function as itself, | ||
// this first argument T* ¶metric_dmd is mainly for determining the type T. | ||
// Here we introduce a dummy argument in place of parametric_dmd. | ||
// This will let python decide which function to use, based on the first argument type. | ||
// We will need variants of this as we bind more DMD classes, | ||
// where dmd_type is the corresponding type. | ||
m.def("getParametricDMD", []( | ||
py::object &dmd_type, | ||
std::vector<Vector*>& parameter_points, | ||
std::vector<DMD*>& dmds, | ||
Vector* desired_point, | ||
std::string rbf = "G", | ||
std::string interp_method = "LS", | ||
double closest_rbf_val = 0.9, | ||
bool reorthogonalize_W = false) | ||
{ | ||
std::string name = dmd_type.attr("__name__").cast<std::string>(); | ||
if (name == "DMD") | ||
return getParametricDMD_Type<DMD>(parameter_points, dmds, desired_point, | ||
rbf, interp_method, closest_rbf_val, reorthogonalize_W); | ||
else | ||
{ | ||
std::string msg = name + " is not a proper libROM DMD class!\n"; | ||
throw std::runtime_error(msg.c_str()); | ||
} | ||
}, | ||
py::arg("dmd_type"), | ||
py::arg("parameter_points"), | ||
py::arg("dmds"), | ||
py::arg("desired_point"), | ||
py::arg("rbf") = "G", | ||
py::arg("interp_method") = "LS", | ||
py::arg("closest_rbf_val") = 0.9, | ||
py::arg("reorthogonalize_W") = false); | ||
|
||
// original getParametricDMD pass a template pointer reference T* ¶metric_dmd. | ||
// While it is impossible to bind a template function as itself, | ||
// this first argument T* ¶metric_dmd is mainly for determining the type T. | ||
// Here we introduce a dummy argument in place of parametric_dmd. | ||
// This will let python decide which function to use, based on the first argument type. | ||
// We will need variants of this as we bind more DMD classes, | ||
// where dmd_type is the corresponding type. | ||
m.def("getParametricDMD", []( | ||
py::object &dmd_type, | ||
std::vector<Vector*>& parameter_points, | ||
std::vector<std::string>& dmd_paths, | ||
Vector* desired_point, | ||
std::string rbf = "G", | ||
std::string interp_method = "LS", | ||
double closest_rbf_val = 0.9, | ||
bool reorthogonalize_W = false) | ||
{ | ||
std::string name = dmd_type.attr("__name__").cast<std::string>(); | ||
if (name == "DMD") | ||
return getParametricDMD_Type<DMD>(parameter_points, dmd_paths, desired_point, | ||
rbf, interp_method, closest_rbf_val, reorthogonalize_W); | ||
else | ||
{ | ||
std::string msg = name + " is not a proper libROM DMD class!\n"; | ||
throw std::runtime_error(msg.c_str()); | ||
} | ||
}, | ||
py::arg("dmd_type"), | ||
py::arg("parameter_points"), | ||
py::arg("dmd_paths"), | ||
py::arg("desired_point"), | ||
py::arg("rbf") = "G", | ||
py::arg("interp_method") = "LS", | ||
py::arg("closest_rbf_val") = 0.9, | ||
py::arg("reorthogonalize_W") = false); | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import numpy as np | ||
import mfem.par as mfem | ||
from mpi4py import MPI | ||
|
||
class PointwiseSnapshot: | ||
finder = None | ||
npoints = -1 | ||
dims = [-1] * 3 | ||
spaceDim = -1 | ||
|
||
domainMin = mfem.Vector() | ||
domainMax = mfem.Vector() | ||
xyz = mfem.Vector() | ||
|
||
def __init__(self, sdim, dims_): | ||
self.finder = None | ||
self.spaceDim = sdim | ||
assert((1 < sdim) and (sdim < 4)) | ||
|
||
self.npoints = np.prod(dims_[:self.spaceDim]) | ||
self.dims = np.ones((3,), dtype=int) | ||
self.dims[:self.spaceDim] = dims_[:self.spaceDim] | ||
|
||
self.xyz.SetSize(self.npoints * self.spaceDim) | ||
self.xyz.Assign(0.) | ||
return | ||
|
||
def SetMesh(self, pmesh): | ||
if (self.finder is not None): | ||
self.finder.FreeData() # Free the internal gslib data. | ||
del self.finder | ||
|
||
assert(pmesh.Dimension() == self.spaceDim) | ||
assert(pmesh.SpaceDimension() == self.spaceDim) | ||
|
||
self.domainMin, self.domainMax = pmesh.GetBoundingBox(0) | ||
|
||
h = [0.] * 3 | ||
for i in range(self.spaceDim): | ||
h[i] = (self.domainMax[i] - self.domainMin[i]) / float(self.dims[i] - 1) | ||
|
||
rank = pmesh.GetMyRank() | ||
if (rank == 0): | ||
print("PointwiseSnapshot on bounding box from (", | ||
self.domainMin[:self.spaceDim], | ||
") to (", self.domainMax[:self.spaceDim], ")") | ||
|
||
# TODO(kevin): we might want to re-write this loop in python manner. | ||
xyzData = self.xyz.GetDataArray() | ||
for k in range(self.dims[2]): | ||
pz = self.domainMin[2] + k * h[2] if (self.spaceDim > 2) else 0.0 | ||
osk = k * self.dims[0] * self.dims[1] | ||
|
||
for j in range(self.dims[1]): | ||
py = self.domainMin[1] + j * h[1] | ||
osj = (j * self.dims[0]) + osk | ||
|
||
for i in range(self.dims[0]): | ||
px = self.domainMin[0] + i * h[0] | ||
xyzData[i + osj] = px | ||
if (self.spaceDim > 1): xyzData[self.npoints + i + osj] = py | ||
if (self.spaceDim > 2): xyzData[2 * self.npoints + i + osj] = pz | ||
|
||
self.finder = mfem.FindPointsGSLIB(MPI.COMM_WORLD) | ||
# mfem.FindPointsGSLIB() | ||
self.finder.Setup(pmesh) | ||
self.finder.SetL2AvgType(mfem.FindPointsGSLIB.NONE) | ||
return | ||
|
||
def GetSnapshot(self, f, s): | ||
vdim = f.FESpace().GetVDim() | ||
s.SetSize(self.npoints * vdim) | ||
|
||
self.finder.Interpolate(self.xyz, f, s) | ||
|
||
code_out = self.finder.GetCode() | ||
print(type(code_out)) | ||
print(code_out.__dir__()) | ||
|
||
assert(code_out.Size() == self.npoints) | ||
|
||
# Note that Min() and Max() are not defined for Array<unsigned int> | ||
#MFEM_VERIFY(code_out.Min() >= 0 && code_out.Max() < 2, ""); | ||
cmin = code_out[0] | ||
cmax = code_out[0] | ||
# TODO(kevin): does this work for mfem array? | ||
for c in code_out: | ||
if (c < cmin): | ||
cmin = c | ||
if (c > cmax): | ||
cmax = c | ||
|
||
assert((cmin >= 0) and (cmax < 2)) | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# To add pure python routines to this module, | ||
# either define/import the python routine in this file. | ||
# This will combine both c++ bindings/pure python routines into this module. | ||
from .Utilities import ComputeCtAB | ||
from .Utilities import ComputeCtAB | ||
from .PointwiseSnapshot import PointwiseSnapshot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.