Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update GMRES/MLMG interface #3779

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions Src/Base/AMReX_FabArrayUtility.H
Original file line number Diff line number Diff line change
Expand Up @@ -1616,6 +1616,13 @@ void setBndry (MF& dst, typename MF::value_type val, int scomp, int ncomp)
dst.setBndry(val, scomp, ncomp);
}

//! dst *= val
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void Scale (MF& dst, typename MF::value_type val, int scomp, int ncomp, int nghost)
{
dst.mult(val, scomp, ncomp, nghost);
}

//! dst = src
template <class DMF, class SMF,
std::enable_if_t<IsMultiFabLike_v<DMF> &&
Expand Down Expand Up @@ -1650,6 +1657,16 @@ void Xpay (MF& dst, typename MF::value_type a, MF const& src, int scomp, int dco
MF::Xpay(dst, a, src, scomp, dcomp, ncomp, nghost);
}

//! dst = a*src_a + b*src_b
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void LinComb (MF& dst,
typename MF::value_type a, MF const& src_a, int acomp,
typename MF::value_type b, MF const& src_b, int bcomp,
int dcomp, int ncomp, IntVect const& nghost)
{
MF::LinComb(dst, a, src_a, acomp, b, src_b, bcomp, dcomp, ncomp, nghost);
}

//! dst = src w/ MPI communication
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>, int> = 0>
void ParallelCopy (MF& dst, MF const& src, int scomp, int dcomp, int ncomp,
Expand Down Expand Up @@ -1686,6 +1703,16 @@ void setBndry (Array<MF,N>& dst, typename MF::value_type val, int scomp, int nco
}
}

//! dst *= val
template <class MF, std::size_t N, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void Scale (Array<MF,N>& dst, typename MF::value_type val, int scomp, int ncomp,
int nghost)
{
for (auto& mf : dst) {
mf.mult(val, scomp, ncomp, nghost);
}
}

//! dst = src
template <class DMF, class SMF, std::size_t N,
std::enable_if_t<IsMultiFabLike_v<DMF> &&
Expand Down Expand Up @@ -1730,6 +1757,18 @@ void Xpay (Array<MF,N>& dst, typename MF::value_type a,
}
}

//! dst = a*src_a + b*src_b
template <class MF, std::size_t N, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void LinComb (Array<MF,N>& dst,
typename MF::value_type a, Array<MF,N> const& src_a, int acomp,
typename MF::value_type b, Array<MF,N> const& src_b, int bcomp,
int dcomp, int ncomp, IntVect const& nghost)
{
for (std::size_t i = 0; i < N; ++i) {
MF::LinComb(dst[i], a, src_a[i], acomp, b, src_b[i], bcomp, dcomp, ncomp, nghost);
}
}

//! dst = src w/ MPI communication
template <class MF, std::size_t N, std::enable_if_t<IsMultiFabLike_v<MF>, int> = 0>
void ParallelCopy (Array<MF,N>& dst, Array<MF,N> const& src,
Expand Down
23 changes: 15 additions & 8 deletions Src/LinearSolvers/AMReX_GMRES.H
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ namespace amrex {
* - void precond(V& lhs, V const& rhs)\n
* applies preconditioner to rhs. If there is no preconditioner,
* this function should do lhs = rhs.
* - void setVal(V& v, RT value)\n
* v = value.
* - void setToZero(V& v)\n
* v = 0.
*/
template <typename V, typename M>
class GMRES
{
public:

using RT = typename V::value_type; // double or float
using RT = typename M::RT; // double or float

GMRES ();

Expand All @@ -87,6 +87,9 @@ public:
//! Sets restart length. The default is 30.
void setRestartLength (int rl);

//! Sets the number of iterations
void setNumIters (int niters) { m_maxiter = niters; }

//! Gets the number of iterations.
[[nodiscard]] int getNumIters () const { return m_its; }

Expand Down Expand Up @@ -202,9 +205,9 @@ void GMRES<V,M>::solve (V& a_sol, V const& a_rhs, RT a_tol_rel, RT a_tol_abs, in
m_v_tmp_lhs = std::make_unique<V>(m_linop->makeVecLHS());
}
if (m_vv.empty()) {
m_vv.resize(m_restrtlen+1);
for (auto& v : m_vv) {
v = m_linop->makeVecRHS();
m_vv.reserve(m_restrtlen+1);
for (int i = 0; i < 2; ++i) { // to save space, start with just 2
m_vv.emplace_back(m_linop->makeVecRHS());
}
}

Expand All @@ -216,7 +219,7 @@ void GMRES<V,M>::solve (V& a_sol, V const& a_rhs, RT a_tol_rel, RT a_tol_abs, in
auto rnorm0 = RT(0);

m_linop->assign(m_vv[0], a_rhs);
m_linop->setVal(a_sol, RT(0.0));
m_linop->setToZero(a_sol);

m_its = 0;
m_status = -1;
Expand Down Expand Up @@ -269,6 +272,10 @@ void GMRES<V,M>::cycle (V& a_xx, int& a_status, int& a_itcount, RT& a_rnorm0)

if (a_status == 0) { break; }

while (m_vv.size() < it+2) {
m_vv.emplace_back(m_linop->makeVecRHS());
}

auto const& vv_it = m_vv[it ];
auto & vv_it1 = m_vv[it+1];

Expand Down Expand Up @@ -384,7 +391,7 @@ void GMRES<V,M>::build_solution (V& a_xx, int const it)
m_grs[k] = tt / m_hh(k,k);
}

m_linop->setVal(*m_v_tmp_rhs, RT(0.0));
m_linop->setToZero(*m_v_tmp_rhs);
for (int ii = 0; ii < it+1; ++ii) {
m_linop->increment(*m_v_tmp_rhs, m_vv[ii], m_grs[ii]);
}
Expand Down
39 changes: 23 additions & 16 deletions Src/LinearSolvers/AMReX_GMRES_MLMG.H
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class GMRESMLMGT
{
public:
using MF = typename M::MFType; // typically MultiFab
using RT = typename MF::value_type; // double or float
using RT = typename M::RT; // double or float

explicit GMRESMLMGT (M& mlmg);

Expand All @@ -29,8 +29,8 @@ public:

RT dotProduct (MF const& mf1, MF const& mf2) const;

//! lhs = value
static void setVal (MF& lhs, RT value);
//! lhs = 0
static void setToZero (MF& lhs);

//! lhs = rhs
static void assign (MF& lhs, MF const& rhs);
Expand Down Expand Up @@ -58,6 +58,8 @@ template <typename M>
GMRESMLMGT<M>::GMRESMLMGT (M& mlmg)
: m_mlmg(mlmg), m_linop(mlmg.getLinOp())
{
m_mlmg.setVerbose(0);
m_mlmg.setBottomVerbose(0);
m_mlmg.prepareLinOp();
}

Expand All @@ -71,7 +73,7 @@ template <typename M>
auto GMRESMLMGT<M>::makeVecLHS () const -> MF
{
auto mf = m_linop.make(0, 0, IntVect(1));
mf.setBndry(0);
setBndry(mf, RT(0), 0, nComp(mf));
return mf;
}

Expand All @@ -85,7 +87,7 @@ auto GMRESMLMGT<M>::norm2 (MF const& mf) const -> RT
template <typename M>
void GMRESMLMGT<M>::scale (MF& mf, RT scale_factor)
{
mf.mult(scale_factor, 0, mf.nComp());
Scale(mf, scale_factor, 0, nComp(mf), 0);
}

template <typename M>
Expand All @@ -95,27 +97,27 @@ auto GMRESMLMGT<M>::dotProduct (MF const& mf1, MF const& mf2) const -> RT
}

template <typename M>
void GMRESMLMGT<M>::setVal (MF& lhs, RT value)
void GMRESMLMGT<M>::setToZero (MF& lhs)
{
lhs.setVal(value);
setVal(lhs, RT(0.0));
}

template <typename M>
void GMRESMLMGT<M>::assign (MF& lhs, MF const& rhs)
{
MF::Copy(lhs, rhs, 0, 0, lhs.nComp(), IntVect(0));
LocalCopy(lhs, rhs, 0, 0, nComp(lhs), IntVect(0));
}

template <typename M>
void GMRESMLMGT<M>::increment (MF& lhs, MF const& rhs, RT a)
{
MF::Saxpy(lhs, a, rhs, 0, 0, lhs.nComp(), IntVect(0));
Saxpy(lhs, a, rhs, 0, 0, nComp(lhs), IntVect(0));
}

template <typename M>
void GMRESMLMGT<M>::linComb (MF& lhs, RT a, MF const& rhs_a, RT b, MF const& rhs_b)
{
MF::LinComb(lhs, a, rhs_a, 0, b, rhs_b, 0, 0, lhs.nComp(), IntVect(0));
LinComb(lhs, a, rhs_a, 0, b, rhs_b, 0, 0, nComp(lhs), IntVect(0));
}

template <typename M>
Expand All @@ -130,13 +132,18 @@ template <typename M>
void GMRESMLMGT<M>::precond (MF& lhs, MF const& rhs) const
{
if (m_use_precond) {
// for now, let's just do some smoothing
lhs.setVal(RT(0.0));
for (int m = 0; m < 4; ++m) {
m_linop.smooth(0, 0, lhs, rhs, (m==0) ? true : false);
}
AMREX_ALWAYS_ASSERT(m_linop.NAMRLevels() == 1);

m_mlmg.prepareMGcycle();

LocalCopy(m_mlmg.res[0][0], rhs, 0, 0, nComp(rhs), IntVect(0));

m_mlmg.mgVcycle(0,0);

LocalCopy(lhs, m_mlmg.cor[0][0], 0, 0, nComp(rhs), IntVect(0));

} else {
amrex::Copy(lhs, rhs, 0, 0, lhs.nComp(), IntVect(0));
LocalCopy(lhs, rhs, 0, 0, nComp(lhs), IntVect(0));
}
}

Expand Down
11 changes: 7 additions & 4 deletions Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.H
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,12 @@ namespace amrex {
* scalar, and beta is a non-negative scalar.
*
* It's the caller's responsibility to make sure rhs has consistent nodal
* data. If needed, one could use FabArray::OverrideSync to synchronize
* nodal data.
* data. If needed, one could call prepareRHS for this.
*
* The smoother is based on the 4-color Gauss-Seidel smoother of Li
* et. al. 2020. "An Efficient Preconditioner for 3-D Finite Difference
* Modeling of the Electromagnetic Diffusion Process in the Frequency
* Domain", IEEE Transactions on Geoscience and Remote Sensing, 58, 500-509.
*
* TODO: If beta is zero, the system could be singular.
*/
class MLCurlCurl
: public MLLinOpT<Array<MultiFab,3> >
Expand All @@ -48,6 +45,12 @@ public:

void setScalars (RT a_alpha, RT a_beta) noexcept;

//! Synchronize RHS on nodal points and set to zero on Dirichlet
//! boundaries. If the user can guarantee these requirements on RHS,
//! this function does not need to be called. If this is called, it
//! should only be called after setDomainBC is called.
void prepareRHS (Vector<MF*> const& rhs) const;

[[nodiscard]] std::string name () const override {
return std::string("curl of curl");
}
Expand Down
56 changes: 56 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,62 @@ void MLCurlCurl::setScalars (RT a_alpha, RT a_beta) noexcept
{
m_alpha = a_alpha;
m_beta = a_beta;
AMREX_ASSERT(m_beta > RT(0));
}

void MLCurlCurl::prepareRHS (Vector<MF*> const& rhs) const
{
MFItInfo mfi_info{};
#ifdef AMREX_USE_GPU
Vector<Array4BoxTag<RT>> tags;
mfi_info.DisableDeviceSync();
#endif

for (int amrlev = 0; amrlev < m_num_amr_levels; ++amrlev) {
for (auto& mf : *rhs[amrlev]) {
mf.OverrideSync(m_geom[amrlev][0].periodicity());

auto const idxtype = mf.ixType();
Box const domain = amrex::convert(m_geom[amrlev][0].Domain(), idxtype);

#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
for (MFIter mfi(mf,mfi_info); mfi.isValid(); ++mfi) {
auto const& vbx = mfi.validbox();
auto const& a = mf.array(mfi);
for (OrientationIter oit; oit; ++oit) {
Orientation const face = oit();
int const idim = face.coordDir();
bool is_dirichlet = face.isLow()
? m_lobc[0][idim] == LinOpBCType::Dirichlet
: m_hibc[0][idim] == LinOpBCType::Dirichlet;
if (is_dirichlet && domain[face] == vbx[face] &&
idxtype.nodeCentered(idim))
{
Box b = vbx;
b.setRange(idim, vbx[face], 1);
#ifdef AMREX_USE_GPU
tags.emplace_back(Array4BoxTag<RT>{a,b});
#else
amrex::LoopOnCpu(b, [&] (int i, int j, int k)
{
a(i,j,k) = RT(0.0);
});
#endif
}
}
}
}
}

#ifdef AMREX_USE_GPU
ParallelFor(tags,
[=] AMREX_GPU_DEVICE (int i, int j, int k, Array4BoxTag<RT> const& tag) noexcept
{
tag.dfab(i,j,k) = RT(0.0);
});
#endif
}

void MLCurlCurl::setLevelBC (int amrlev, const MF* levelbcdata, // TODO
Expand Down
Loading
Loading