Skip to content

Commit

Permalink
refactor: fix the success logic
Browse files Browse the repository at this point in the history
  • Loading branch information
rzyu45 committed Dec 29, 2024
1 parent 08628f8 commit 774cf5c
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 26 deletions.
20 changes: 10 additions & 10 deletions Solverz/solvers/nlaesolver/cnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ def f(y_, p_) -> np.ndarray:
ite = 0
df = eqn.F(y, p)
stats.nfeval += 1
while max(abs(df)) > tol:
while np.max(np.abs(df)) > tol:

if ite > opt.max_it:
print(f"Cannot converge within 100 iterations. Deviation: {np.max(np.abs(df))}!")
break

ite = ite + 1
err = 2
nofailed = True
Expand All @@ -90,8 +95,8 @@ def f(y_, p_) -> np.ndarray:
# error control
# error estimation
err = dt * np.linalg.norm(
kE.reshape(-1, ) / np.maximum(np.maximum(abs(y), abs(ynew)).reshape(-1, ), threshold),
np.Inf)
kE.reshape(-1, ) / np.maximum(np.maximum(np.abs(y), np.abs(ynew)).reshape(-1, ), threshold),
np.inf)
if err > rtol: # failed step
if dt <= hmin:
raise ValueError(f'IntegrationTolNotMet step size: {dt} hmin: {hmin}')
Expand All @@ -116,13 +121,8 @@ def f(y_, p_) -> np.ndarray:

dt = np.min([dt, hmax])

if ite > 100:
print(f"Cannot converge within 100 iterations. Deviation: {max(abs(df))}!")
stats.succeed = False
break

if np.any(np.isnan(y)):
stats.succeed = False
if np.max(np.abs(df)) < tol:
stats.succeed = True
stats.nstep = ite

return aesol(y, stats)
9 changes: 7 additions & 2 deletions Solverz/solvers/nlaesolver/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ def lm(eqn: nAE,
p = eqn.p

# optimize.root func cannot handle callable jac that returns scipy.sparse.csc_array
sol = optimize.root(lambda x: eqn.F(x, p), y, jac=lambda x: eqn.J(x, p).toarray(), method='lm', tol=tol)
dF = eqn.F(sol.y, eqn.p)
sol = optimize.root(lambda x: eqn.F(x, p),
y,
jac=lambda x: eqn.J(x, p).toarray(),
method='lm',
tol=tol,
options={'maxiter': opt.max_it})
dF = eqn.F(sol.x, eqn.p)
if np.max(np.abs(dF)) < tol:
stats.succeed = True
stats.nfeval = sol.nfev
Expand Down
22 changes: 11 additions & 11 deletions Solverz/solvers/nlaesolver/nr.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,27 @@ def nr_method(eqn: nAE,
if opt is None:
opt = Opt()

stats = Stats('Newton')
tol = opt.ite_tol
p = eqn.p
df = eqn.F(y, p)

stats = Stats('Newton')
stats.nfeval += 1

# main loop
while max(abs(df)) > tol:
while np.max(np.abs(df)) > tol:

if stats.nstep > opt.max_it:
print(f"Cannot converge within 100 iterations. Deviation: {np.max(np.abs(df))}!")
break

stats.nstep += 1

y = y - solve(eqn.J(y, p), df)
stats.ndecomp += 1
df = eqn.F(y, p)
stats.nfeval += 1
stats.nstep += 1

if stats.nstep >= 100:
print(f"Cannot converge within 100 iterations. Deviation: {max(abs(df))}!")
stats.succeed = False
break

if np.any(np.isnan(y)):
stats.succeed = False
if np.max(np.abs(df)) < tol:
stats.succeed = True

return aesol(y, stats)
1 change: 0 additions & 1 deletion Solverz/solvers/nlaesolver/utilities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from numpy import abs, max
from scipy import optimize

from Solverz.num_api.num_eqn import nAE
Expand Down
4 changes: 3 additions & 1 deletion Solverz/solvers/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(self,
partial_decompose=False,
ode15smaxit=4,
normcontrol=False,
numJac=False):
numJac=False,
max_it=100):
self.atol = atol
self.rtol = rtol
self.f_savety = f_savety
Expand All @@ -43,3 +44,4 @@ def __init__(self,
self.ode15smaxit = ode15smaxit
self.normcontrol = normcontrol
self.numJac = numJac
self.max_it = max_it
2 changes: 1 addition & 1 deletion Solverz/solvers/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, scheme=None):
self.nreject = 0
self.nsolve = 0
self.ret = None
self.succeed = True
self.succeed = False

def __repr__(self):
return (f"Scheme: {self.scheme}, "
Expand Down

0 comments on commit 774cf5c

Please sign in to comment.