diff --git a/Solverz/solvers/nlaesolver/cnr.py b/Solverz/solvers/nlaesolver/cnr.py index 04bcb44..cf3b056 100644 --- a/Solverz/solvers/nlaesolver/cnr.py +++ b/Solverz/solvers/nlaesolver/cnr.py @@ -69,7 +69,12 @@ def f(y_, p_) -> np.ndarray: ite = 0 df = eqn.F(y, p) stats.nfeval += 1 - while max(abs(df)) > tol: + while np.max(np.abs(df)) > tol: + + if ite > opt.max_it: + print(f"Cannot converge within 100 iterations. Deviation: {np.max(np.abs(df))}!") + break + ite = ite + 1 err = 2 nofailed = True @@ -90,8 +95,8 @@ def f(y_, p_) -> np.ndarray: # error control # error estimation err = dt * np.linalg.norm( - kE.reshape(-1, ) / np.maximum(np.maximum(abs(y), abs(ynew)).reshape(-1, ), threshold), - np.Inf) + kE.reshape(-1, ) / np.maximum(np.maximum(np.abs(y), np.abs(ynew)).reshape(-1, ), threshold), + np.inf) if err > rtol: # failed step if dt <= hmin: raise ValueError(f'IntegrationTolNotMet step size: {dt} hmin: {hmin}') @@ -116,13 +121,8 @@ def f(y_, p_) -> np.ndarray: dt = np.min([dt, hmax]) - if ite > 100: - print(f"Cannot converge within 100 iterations. Deviation: {max(abs(df))}!") - stats.succeed = False - break - - if np.any(np.isnan(y)): - stats.succeed = False + if np.max(np.abs(df)) < tol: + stats.succeed = True stats.nstep = ite return aesol(y, stats) diff --git a/Solverz/solvers/nlaesolver/lm.py b/Solverz/solvers/nlaesolver/lm.py index 51ae50c..1194249 100644 --- a/Solverz/solvers/nlaesolver/lm.py +++ b/Solverz/solvers/nlaesolver/lm.py @@ -48,8 +48,13 @@ def lm(eqn: nAE, p = eqn.p # optimize.root func cannot handle callable jac that returns scipy.sparse.csc_array - sol = optimize.root(lambda x: eqn.F(x, p), y, jac=lambda x: eqn.J(x, p).toarray(), method='lm', tol=tol) - dF = eqn.F(sol.y, eqn.p) + sol = optimize.root(lambda x: eqn.F(x, p), + y, + jac=lambda x: eqn.J(x, p).toarray(), + method='lm', + tol=tol, + options={'maxiter': opt.max_it}) + dF = eqn.F(sol.x, eqn.p) if np.max(np.abs(dF)) < tol: stats.succeed = True stats.nfeval = sol.nfev diff --git a/Solverz/solvers/nlaesolver/nr.py b/Solverz/solvers/nlaesolver/nr.py index ad47666..501967e 100644 --- a/Solverz/solvers/nlaesolver/nr.py +++ b/Solverz/solvers/nlaesolver/nr.py @@ -38,27 +38,27 @@ def nr_method(eqn: nAE, if opt is None: opt = Opt() + stats = Stats('Newton') tol = opt.ite_tol p = eqn.p df = eqn.F(y, p) - - stats = Stats('Newton') stats.nfeval += 1 # main loop - while max(abs(df)) > tol: + while np.max(np.abs(df)) > tol: + + if stats.nstep > opt.max_it: + print(f"Cannot converge within 100 iterations. Deviation: {np.max(np.abs(df))}!") + break + + stats.nstep += 1 + y = y - solve(eqn.J(y, p), df) stats.ndecomp += 1 df = eqn.F(y, p) stats.nfeval += 1 - stats.nstep += 1 - - if stats.nstep >= 100: - print(f"Cannot converge within 100 iterations. Deviation: {max(abs(df))}!") - stats.succeed = False - break - if np.any(np.isnan(y)): - stats.succeed = False + if np.max(np.abs(df)) < tol: + stats.succeed = True return aesol(y, stats) diff --git a/Solverz/solvers/nlaesolver/utilities.py b/Solverz/solvers/nlaesolver/utilities.py index f0432de..232491b 100644 --- a/Solverz/solvers/nlaesolver/utilities.py +++ b/Solverz/solvers/nlaesolver/utilities.py @@ -1,5 +1,4 @@ import numpy as np -from numpy import abs, max from scipy import optimize from Solverz.num_api.num_eqn import nAE diff --git a/Solverz/solvers/option.py b/Solverz/solvers/option.py index 0f1b4f6..1457c5e 100644 --- a/Solverz/solvers/option.py +++ b/Solverz/solvers/option.py @@ -22,7 +22,8 @@ def __init__(self, partial_decompose=False, ode15smaxit=4, normcontrol=False, - numJac=False): + numJac=False, + max_it=100): self.atol = atol self.rtol = rtol self.f_savety = f_savety @@ -43,3 +44,4 @@ def __init__(self, self.ode15smaxit = ode15smaxit self.normcontrol = normcontrol self.numJac = numJac + self.max_it = max_it diff --git a/Solverz/solvers/stats.py b/Solverz/solvers/stats.py index 2b4e8f5..c3db3a2 100644 --- a/Solverz/solvers/stats.py +++ b/Solverz/solvers/stats.py @@ -9,7 +9,7 @@ def __init__(self, scheme=None): self.nreject = 0 self.nsolve = 0 self.ret = None - self.succeed = True + self.succeed = False def __repr__(self): return (f"Scheme: {self.scheme}, "