Skip to content

Commit

Permalink
Update GMRES/MLMG for nodal solver
Browse files Browse the repository at this point in the history
We need to set RHS on Dirichlet nodes to zero for nodal solvers because the
Dirichlet nodes are not unknowns and GMRES does not have the knowledge.
  • Loading branch information
WeiqunZhang committed Mar 7, 2024
1 parent 944d4b4 commit 8ddc4ef
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 37 deletions.
10 changes: 7 additions & 3 deletions Src/LinearSolvers/AMReX_GMRES_MLMG.H
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private:
GM m_gmres;
MG& m_mlmg;
MLLinOpT<MF>& m_linop;
bool m_use_precond = false;
bool m_use_precond = true;
bool m_prop_zero = false;
};

Expand All @@ -109,7 +109,7 @@ GMRESMLMGT<MF>::GMRESMLMGT (MG& mlmg)
"Only support single level solve");
m_mlmg.setVerbose(0);
m_mlmg.setBottomVerbose(0);
m_mlmg.prepareLinOp();
m_mlmg.prepareForGMRES();
m_gmres.define(*this);
}

Expand Down Expand Up @@ -199,12 +199,16 @@ template <typename MF>
void GMRESMLMGT<MF>::solve (MF& a_sol, MF const& a_rhs, RT a_tol_rel, RT a_tol_abs)
{
if (m_prop_zero) {
m_gmres.solve(a_sol, a_rhs, a_tol_rel, a_tol_abs);
auto rhs = makeVecRHS();
assign(rhs, a_rhs);
m_linop.setDirichletNodesToZero(0,0,rhs);
m_gmres.solve(a_sol, rhs, a_tol_rel, a_tol_abs);
} else {
auto res = makeVecRHS();
m_mlmg.apply({&res}, {&a_sol}); // res = L(sol)
increment(res, a_rhs, RT(-1)); // res = L(sol) - rhs
auto cor = makeVecLHS();
m_linop.setDirichletNodesToZero(0,0,res);
m_gmres.solve(cor, res, a_tol_rel, a_tol_abs); // L(cor) = res
increment(a_sol, cor, RT(-1)); // sol = sol - cor
}
Expand Down
20 changes: 20 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLCellABecLap.H
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ public:

void prepareForSolve () override;

void setDirichletNodesToZero (int amrlev, int mglev, MF& mf) const override;

void getFluxes (const Vector<Array<MF*,AMREX_SPACEDIM> >& a_flux,
const Vector<MF*>& a_sol,
Location a_loc) const final;
Expand Down Expand Up @@ -247,6 +249,24 @@ MLCellABecLapT<MF>::prepareForSolve ()
MLCellLinOpT<MF>::prepareForSolve();
}

template <typename MF>
void
MLCellABecLapT<MF>::setDirichletNodesToZero (int amrlev, int mglev, MF& mf) const
{
auto const* omask = this->getOversetMask(amrlev, mglev);
if (omask) {
const int ncomp = this->getNComp();
auto const& mskma = omask->const_arrays();
auto const& ma = mf.arrays();
ParallelFor(mf, IntVect(0), ncomp,
[=] AMREX_GPU_DEVICE (int bno, int i, int j, int k, int n)
{
if (mskma[bno](i,j,k) == 0) { ma[bno](i,j,k,n) = RT(0.0); }
});
Gpu::streamSynchronize();
}
}

template <typename MF>
void
MLCellABecLapT<MF>::getFluxes (const Vector<Array<MF*,AMREX_SPACEDIM> >& a_flux,
Expand Down
8 changes: 4 additions & 4 deletions Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.H
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ public:

void setScalars (RT a_alpha, RT a_beta) noexcept;

//! Synchronize RHS on nodal points and set to zero on Dirichlet
//! boundaries. If the user can guarantee these requirements on RHS,
//! this function does not need to be called. If this is called, it
//! should only be called after setDomainBC is called.
//! Synchronize RHS on nodal points. If the user can guarantee it, this
//! function does not need to be called.
void prepareRHS (Vector<MF*> const& rhs) const;

void setDirichletNodesToZero (int amrlev, int mglev, MF& a_mf) const override;

[[nodiscard]] std::string name () const override {
return std::string("curl of curl");
}
Expand Down
58 changes: 32 additions & 26 deletions Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,46 +36,52 @@ void MLCurlCurl::setScalars (RT a_alpha, RT a_beta) noexcept
}

void MLCurlCurl::prepareRHS (Vector<MF*> const& rhs) const
{
for (int amrlev = 0; amrlev < m_num_amr_levels; ++amrlev) {
for (auto& mf : *rhs[amrlev]) {
mf.OverrideSync(m_geom[amrlev][0].periodicity());
}
}
}

void MLCurlCurl::setDirichletNodesToZero (int amrlev, int mglev, MF& a_mf) const
{
MFItInfo mfi_info{};
#ifdef AMREX_USE_GPU
Vector<Array4BoxTag<RT>> tags;
mfi_info.DisableDeviceSync();
#endif

for (int amrlev = 0; amrlev < m_num_amr_levels; ++amrlev) {
for (auto& mf : *rhs[amrlev]) {
mf.OverrideSync(m_geom[amrlev][0].periodicity());

auto const idxtype = mf.ixType();
Box const domain = amrex::convert(m_geom[amrlev][0].Domain(), idxtype);
for (auto& mf : a_mf)
{
auto const idxtype = mf.ixType();
Box const domain = amrex::convert(m_geom[amrlev][mglev].Domain(), idxtype);

#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
for (MFIter mfi(mf,mfi_info); mfi.isValid(); ++mfi) {
auto const& vbx = mfi.validbox();
auto const& a = mf.array(mfi);
for (OrientationIter oit; oit; ++oit) {
Orientation const face = oit();
int const idim = face.coordDir();
bool is_dirichlet = face.isLow()
? m_lobc[0][idim] == LinOpBCType::Dirichlet
: m_hibc[0][idim] == LinOpBCType::Dirichlet;
if (is_dirichlet && domain[face] == vbx[face] &&
idxtype.nodeCentered(idim))
{
Box b = vbx;
b.setRange(idim, vbx[face], 1);
for (MFIter mfi(mf,mfi_info); mfi.isValid(); ++mfi) {
auto const& vbx = mfi.validbox();
auto const& a = mf.array(mfi);
for (OrientationIter oit; oit; ++oit) {
Orientation const face = oit();
int const idim = face.coordDir();
bool is_dirichlet = face.isLow()
? m_lobc[0][idim] == LinOpBCType::Dirichlet
: m_hibc[0][idim] == LinOpBCType::Dirichlet;
if (is_dirichlet && domain[face] == vbx[face] &&
idxtype.nodeCentered(idim))
{
Box b = vbx;
b.setRange(idim, vbx[face], 1);
#ifdef AMREX_USE_GPU
tags.emplace_back(Array4BoxTag<RT>{a,b});
tags.emplace_back(Array4BoxTag<RT>{a,b});
#else
amrex::LoopOnCpu(b, [&] (int i, int j, int k)
{
a(i,j,k) = RT(0.0);
});
amrex::LoopOnCpu(b, [&] (int i, int j, int k)
{
a(i,j,k) = RT(0.0);
});
#endif
}
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLLinOp.H
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,16 @@ public:

virtual void prepareForSolve () = 0;

virtual void prepareForGMRES () {}

//! For GMRES to work, this might need to be implemented to mask out
//! Dirichlet nodes or cells (e.g., EB covered cells or overset cells)
virtual void setDirichletNodesToZero (int /*amrlev*/, int /*mglev*/,
MF& /*mf*/) const
{
amrex::Warning("This function might need to be implemented for GMRES to work with this LinOp.");
}

//! Is it singular on given AMR level?
[[nodiscard]] virtual bool isSingular (int amrlev) const = 0;
//! Is the bottom of MG singular?
Expand Down
10 changes: 10 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLMG.H
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ public:

void prepareMGcycle ();

void prepareForGMRES ();

void oneIter (int iter);

void miniCycle (int amrlev);
Expand Down Expand Up @@ -1114,6 +1116,14 @@ MLMGT<MF>::prepareLinOp ()
}
}

template <typename MF>
void
MLMGT<MF>::prepareForGMRES ()
{
prepareLinOp();
linop.prepareForGMRES();
}

template <typename MF>
void
MLMGT<MF>::prepareMGcycle ()
Expand Down
4 changes: 4 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLNodeLinOp.H
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ public:

void prepareForSolve () override;

void prepareForGMRES () override;

void setDirichletNodesToZero (int amrlev, int mglev, MultiFab& mf) const override;

bool isSingular (int amrlev) const override
{ return (amrlev == 0) ? m_is_bottom_singular : false; }
bool isBottomSingular () const override { return m_is_bottom_singular; }
Expand Down
36 changes: 36 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLNodeLinOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
#include <AMReX_MLMG_K.H>
#include <AMReX_MultiFabUtil.H>

#ifdef AMREX_USE_EB
#include <AMReX_EBMultiFabUtil.H>
#endif

#ifdef AMREX_USE_OMP
#include <omp.h>
#endif
Expand Down Expand Up @@ -379,6 +383,38 @@ MLNodeLinOp::buildMasks ()
}
}

void
MLNodeLinOp::prepareForGMRES ()
{
if (m_coarse_dot_mask.empty()) {
int amrlev = 0;
int mglev = 0;
const Geometry& geom = m_geom[amrlev][mglev];
const iMultiFab& omask = *m_owner_mask_top;
m_coarse_dot_mask.define(omask.boxArray(), omask.DistributionMap(), 1, 0);
const auto lobc = LoBC();
const auto hibc = HiBC();
MLNodeLinOp_set_dot_mask(m_coarse_dot_mask, omask, geom, lobc, hibc, m_coarsening_strategy);
}
}

void
MLNodeLinOp::setDirichletNodesToZero (int amrlev, int mglev, MultiFab& mf) const
{
auto const& maskma = m_dirichlet_mask[amrlev][mglev]->const_arrays();
auto const& ma = mf.arrays();
const int ncomp = getNComp();
ParallelFor(mf, IntVect(0), ncomp,
[=] AMREX_GPU_DEVICE (int bno, int i, int j, int k, int n)
{
if (maskma[bno](i,j,k)) { ma[bno](i,j,k,n) = RT(0.0); }
});
Gpu::streamSynchronize();
#ifdef AMREX_USE_EB
EB_set_covered(mf, 0, ncomp, 0, RT(0.0));
#endif
}

void
MLNodeLinOp::setOversetMask (int amrlev, const iMultiFab& a_dmask)
{
Expand Down
2 changes: 1 addition & 1 deletion Tests/LinearSolvers/NodalPoisson/GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include $(AMREX_HOME)/Tools/GNUMake/Make.defs

include ./Make.package

Pdirs := Base Boundary AmrCore LinearSolvers/MLMG
Pdirs := Base Boundary AmrCore LinearSolvers

Ppack += $(foreach dir, $(Pdirs), $(AMREX_HOME)/Src/$(dir)/Make.package)

Expand Down
2 changes: 2 additions & 0 deletions Tests/LinearSolvers/NodalPoisson/MyTest.H
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ private:
int ref_ratio = 2;
int n_cell = 128;
int max_grid_size = 64;
amrex::Real domain_ratio = 1.0;

bool composite_solve = true;
bool use_gmres = false;

// For MLMG solver
int verbose = 2;
Expand Down
17 changes: 14 additions & 3 deletions Tests/LinearSolvers/NodalPoisson/MyTest.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "MyTest.H"

#include <AMReX_GMRES_MLMG.H>
#include <AMReX_MLNodeLaplacian.H>
#include <AMReX_ParmParse.H>
#include <AMReX_FillPatchUtil.H>
Expand Down Expand Up @@ -112,7 +113,13 @@ MyTest::solve ()
bcrec, 0);
}

mlmg.solve({&solution[ilev]}, {&rhs[ilev]}, reltol, 0.0);
if (use_gmres) {
GMRESMLMG gmsolver(mlmg);
gmsolver.setVerbose(verbose);
gmsolver.solve(solution[ilev], rhs[ilev], reltol, 0.0);
} else {
mlmg.solve({&solution[ilev]}, {&rhs[ilev]}, reltol, 0.0);
}
}
}
}
Expand Down Expand Up @@ -143,8 +150,13 @@ MyTest::readParameters ()
pp.query("ref_ratio", ref_ratio);
pp.query("n_cell", n_cell);
pp.query("max_grid_size", max_grid_size);
pp.query("domain_ratio", domain_ratio);

pp.query("composite_solve", composite_solve);
pp.query("use_gmres", use_gmres);
if (use_gmres) {
composite_solve = false;
}

pp.query("verbose", verbose);
pp.query("bottom_verbose", bottom_verbose);
Expand Down Expand Up @@ -190,7 +202,7 @@ MyTest::initData ()
exact_solution.resize(nlevels);
sigma.resize(nlevels);

RealBox rb({AMREX_D_DECL(0.,0.,0.)}, {AMREX_D_DECL(1.,1.,1.)});
RealBox rb({AMREX_D_DECL(0.,0.,0.)}, {AMREX_D_DECL(1.,domain_ratio,1.)});
Array<int,AMREX_SPACEDIM> is_periodic{AMREX_D_DECL(0,0,0)};
Geometry::Setup(&rb, 0, is_periodic.data());
Box domain0(IntVect{AMREX_D_DECL(0,0,0)}, IntVect{AMREX_D_DECL(n_cell-1,n_cell-1,n_cell-1)});
Expand Down Expand Up @@ -253,4 +265,3 @@ MyTest::initData ()
sigma[ilev].setVal(1.0);
}
}

16 changes: 16 additions & 0 deletions Tests/LinearSolvers/NodalPoisson/inputs-gmres
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

max_level = 1
ref_ratio = 2
n_cell = 128
max_grid_size = 64

composite_solve = 0 # composite solve or level by level?
use_gmres = 1
domain_ratio = 1.0

# For MLMG
verbose = 2
bottom_verbose = 0
max_iter = 100
max_fmg_iter = 0 # # of F-cycles before switching to V. To do pure V-cycle, set to 0
reltol = 1.e-11

0 comments on commit 8ddc4ef

Please sign in to comment.