diff --git a/core/solver/chebyshev.cpp b/core/solver/chebyshev.cpp index 97b9fbc241d..9825812271f 100644 --- a/core/solver/chebyshev.cpp +++ b/core/solver/chebyshev.cpp @@ -40,9 +40,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/distributed/helpers.hpp" #include "core/solver/ir_kernels.hpp" -#include "core/solver/residual_update.hpp" #include "core/solver/solver_base.hpp" #include "core/solver/solver_boilerplate.hpp" +#include "core/solver/update_residual.hpp" namespace gko { @@ -186,7 +186,6 @@ void Chebyshev::apply_dense_impl(const VectorType* dense_b, { using Vector = matrix::Dense; using ws = workspace_traits; - constexpr uint8 relative_stopping_id{1}; auto exec = this->get_executor(); this->setup_workspace(); @@ -229,7 +228,6 @@ void Chebyshev::apply_dense_impl(const VectorType* dense_b, auto beta_ref = ValueType{0.5} * (foci_direction_ * alpha_ref) * (foci_direction_ * alpha_ref); - bool one_changed{}; auto& stop_status = this->template create_workspace_array( ws::stop, dense_b->get_size()[1]); exec->run(chebyshev::make_initialize(&stop_status)); @@ -257,10 +255,9 @@ void Chebyshev::apply_dense_impl(const VectorType* dense_b, solver, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr, &stop_status, all_stopped); }; - bool all_stopped = residual_update( - this, iter, one_op, neg_one_op, dense_b, dense_x, residual, - residual_ptr, stop_criterion, relative_stopping_id, stop_status, - one_changed, log_func); + bool all_stopped = update_residual( + this, iter, dense_b, dense_x, residual, residual_ptr, + stop_criterion, stop_status, log_func); if (all_stopped) { break; } diff --git a/core/solver/ir.cpp b/core/solver/ir.cpp index a923bb8ccde..121de44180d 100644 --- a/core/solver/ir.cpp +++ b/core/solver/ir.cpp @@ -40,9 +40,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/distributed/helpers.hpp" #include "core/solver/ir_kernels.hpp" -#include "core/solver/residual_update.hpp" #include "core/solver/solver_base.hpp" #include "core/solver/solver_boilerplate.hpp" +#include "core/solver/update_residual.hpp" namespace gko { @@ -193,7 +193,6 @@ void Ir::apply_dense_impl(const VectorType* dense_b, { using Vector = matrix::Dense; using ws = workspace_traits; - constexpr uint8 relative_stopping_id{1}; auto exec = this->get_executor(); this->setup_workspace(); @@ -203,7 +202,6 @@ void Ir::apply_dense_impl(const VectorType* dense_b, GKO_SOLVER_ONE_MINUS_ONE(); - bool one_changed{}; auto& stop_status = this->template create_workspace_array( ws::stop, dense_b->get_size()[1]); exec->run(ir::make_initialize(&stop_status)); @@ -232,10 +230,9 @@ void Ir::apply_dense_impl(const VectorType* dense_b, solver, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr, &stop_status, all_stopped); }; - bool all_stopped = residual_update( - this, iter, one_op, neg_one_op, dense_b, dense_x, residual, - residual_ptr, stop_criterion, relative_stopping_id, stop_status, - one_changed, log_func); + bool all_stopped = update_residual( + this, iter, dense_b, dense_x, residual, residual_ptr, + stop_criterion, stop_status, log_func); if (all_stopped) { break; } diff --git a/core/solver/residual_update.hpp b/core/solver/update_residual.hpp similarity index 85% rename from core/solver/residual_update.hpp rename to core/solver/update_residual.hpp index fdba4fe924b..603bbd3cae9 100644 --- a/core/solver/residual_update.hpp +++ b/core/solver/update_residual.hpp @@ -30,8 +30,8 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *************************************************************/ -#ifndef GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_ -#define GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_ +#ifndef GKO_CORE_SOLVER_UPDATE_RESIDUAL_HPP_ +#define GKO_CORE_SOLVER_UPDATE_RESIDUAL_HPP_ #include @@ -43,17 +43,21 @@ namespace gko { namespace solver { -template -bool residual_update(SolverType* solver, int iter, const ScalarType* one_op, - const ScalarType* neg_one_op, const VectorType* dense_b, +template +bool update_residual(SolverType* solver, int iter, const VectorType* dense_b, VectorType* dense_x, VectorType* residual, const VectorType*& residual_ptr, std::unique_ptr& stop_criterion, - uint8 relative_stopping_id, - array& stop_status, bool& one_changed, - LogFunc log) + array& stop_status, LogFunc log) { + using ws = workspace_traits>; + constexpr uint8 relative_stopping_id{1}; + + // It's required to be initialized outside. + auto one_op = solver->get_workspace_op(ws::one); + auto neg_one_op = solver->get_workspace_op(ws::minus_one); + + bool one_changed{}; if (iter == 0) { // In iter 0, the iteration and residual are updated. bool all_stopped = @@ -100,4 +104,4 @@ bool residual_update(SolverType* solver, int iter, const ScalarType* one_op, } // namespace solver } // namespace gko -#endif // GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_ +#endif // GKO_CORE_SOLVER_UPDATE_RESIDUAL_HPP_