Skip to content

Commit

Permalink
progress bar and bug fix (#97)
Browse files Browse the repository at this point in the history
* feat: progress bar of Rodas

* feat: progress bar of Rodas

* fix: resolve #96

* fix: update type checker
  • Loading branch information
rzyu45 authored Oct 9, 2024
1 parent c76ff80 commit d919b34
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 25 deletions.
17 changes: 10 additions & 7 deletions Solverz/equation/eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from Solverz.sym_algebra.transform import finite_difference, semi_descritize
from Solverz.num_api.custom_function import numerical_interface
from Solverz.variable.ssymbol import sSym2Sym
from Solverz.utilities.type_checker import is_zero


def sVar2Var(var: Union[Var, iVar, List[iVar, Var]]) -> Union[iVar, List[iVar]]:
Expand Down Expand Up @@ -76,18 +77,20 @@ def derive_derivative(self):
diff = MixedEquationDiff(self.RHS, symbol_)
else:
diff = self.RHS.diff(symbol_)
self.derivatives[symbol_.name] = EqnDiff(name=f'Diff {self.name} w.r.t. {symbol_.name}',
eqn=diff,
diff_var=symbol_,
var_idx=idx_.name if isinstance(idx_, idx) else idx_)
if not is_zero(diff):
self.derivatives[symbol_.name] = EqnDiff(name=f'Diff {self.name} w.r.t. {symbol_.name}',
eqn=diff,
diff_var=symbol_,
var_idx=idx_.name if isinstance(idx_, idx) else idx_)
elif isinstance(symbol_, iVar):
if self.mixed_matrix_vector:
diff = MixedEquationDiff(self.RHS, symbol_)
else:
diff = self.RHS.diff(symbol_)
self.derivatives[symbol_.name] = EqnDiff(name=f'Diff {self.name} w.r.t. {symbol_.name}',
eqn=diff,
diff_var=symbol_)
if not is_zero(diff):
self.derivatives[symbol_.name] = EqnDiff(name=f'Diff {self.name} w.r.t. {symbol_.name}',
eqn=diff,
diff_var=symbol_)

@property
def expr(self):
Expand Down
5 changes: 4 additions & 1 deletion Solverz/equation/jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sympy import Expr, Function, Integer
from Solverz.sym_algebra.symbols import iVar, IdxVar
from Solverz.utilities.type_checker import is_vector, is_scalar, is_integer, is_number, PyNumber
from Solverz.utilities.type_checker import is_vector, is_scalar, is_integer, is_number, PyNumber, is_zero
from Solverz.sym_algebra.functions import Diag

SolVar = Union[iVar, IdxVar]
Expand Down Expand Up @@ -99,6 +99,9 @@ def __init__(self,
1. a diagonal matrix
2. a matrix
"""
if is_zero(DeriExpr):
raise ValueError(f"We wont allow {DeriExpr} derivative!")

self.EqnName = EqnName
self.EqnAddr: slice = EqnAddr
self.DiffVar: SolVar = DiffVar
Expand Down
10 changes: 10 additions & 0 deletions Solverz/equation/test/test_eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from Solverz.equation.eqn import Eqn, Ode
from Solverz.equation.equations import DAE, AE
from Solverz.variable.variables import combine_Vars, as_Vars
from Solverz import AntiWindUp, Param, Var

x = iVar('x', value=[1])
f = Eqn('f', x - 1)
Expand All @@ -20,3 +21,12 @@
f1 = f.NUM_EQN(np.array([3, 4]))
assert isinstance(f1, np.ndarray)
assert f1.ndim == 1

# discard zero derivative
u = Var('u', [1, 1, 1])
e = Var('e', [1, 1, 1])
umin = Param('umin', [0,0,0])
umax = Param('umax', [5,5,5])
F = Ode('Anti', AntiWindUp(u, umin, umax, e), u)
F.derive_derivative()
assert 'u' not in F.derivatives
11 changes: 11 additions & 0 deletions Solverz/equation/test/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,17 @@ def test_jb_vector_var_scalar_deri():
iVar('y'),
np.array([1]))

def test_jb_vector_var_zero_deri():

with pytest.raises(ValueError,
match=re.escape("We wont allow 0.0 derivative!")):
jb = JacBlock('a',
slice(0, 3),
iVar('x'),
np.ones(3),
slice(1, 4),
0.,
np.array([0]))

# %% vector var and vector derivative
def test_jb_vector_var_vector_deri():
Expand Down
14 changes: 8 additions & 6 deletions Solverz/solvers/daesolver/beuler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,25 @@ def backward_euler(dae: nDAE,
nt = 0
tt = T_initial
t0 = tt
Nt = int((T_end-T_initial)/dt) + 1

y = np.zeros((10000, y0.shape[0]))
y = np.zeros((Nt, y0.shape[0]))
y0 = DaeIc(dae, y0, t0, opt.rtol) # check and modify initial values
y[0, :] = y0
T = np.zeros((10000,))
T = np.zeros((Nt,))
T[0] = t0

p = dae.p
while abs(tt - T_end) > abs(dt) / 10:
while T_end - tt > abs(dt) / 10:
My0 = dae.M @ y0
ae = nAE(lambda y_, p_: dae.M @ y_ - My0 - dt * dae.F(t0 + dt, y_, p_),
lambda y_, p_: dae.M - dt * dae.J(t0 + dt, y_, p_),
p)

sol = nr_method(ae, y0, Opt(stats=True))
y1 = sol.y
ite = sol.stats
stats.ndecomp = stats.ndecomp + ite
stats.nfeval = stats.nfeval + ite
stats.ndecomp = stats.ndecomp + sol.stats.nstep
stats.nfeval = stats.nfeval + sol.stats.nstep

tt = tt + dt
nt = nt + 1
Expand Down
18 changes: 18 additions & 0 deletions Solverz/solvers/daesolver/rodas/rodas.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def Rodas(dae: nDAE,
Initial step size
- hmax: None(default)|float
Maximum step size.
- pbar: False(default)|bool
To display progress bar
Returns
=======
Expand All @@ -69,6 +71,8 @@ def Rodas(dae: nDAE,
tspan = np.array(tspan)
tend = tspan[-1]
t0 = tspan[0]
if t0 > tend:
raise ValueError(f't0: {t0} > tend: {tend}')
if opt.hmax is None:
opt.hmax = np.abs(tend - t0)
nt = 0
Expand All @@ -81,6 +85,9 @@ def Rodas(dae: nDAE,
y0 = DaeIc(dae, y0, t0, opt.rtol) # check and modify initial values
y[0, :] = y0

if opt.pbar:
pbar = tqdm(total=tend - t0)

dense_output = False
n_tspan = len(tspan)
told = t0
Expand Down Expand Up @@ -185,6 +192,7 @@ def Rodas(dae: nDAE,
reject = 0
told = t
t = t + dt

stats.nstep = stats.nstep + 1
# events
if haveEvent:
Expand Down Expand Up @@ -261,6 +269,9 @@ def Rodas(dae: nDAE,
T[nt] = tnext
y[nt] = ynext

if opt.pbar:
pbar.update(T[nt] - T[nt - 1])

if haveEvent and stop:
if tnext >= tevent:
break
Expand All @@ -278,6 +289,9 @@ def Rodas(dae: nDAE,
T[nt] = t
y[nt] = ynew

if opt.pbar:
pbar.update(T[nt] - T[nt - 1])

if nt == 10000:
warnings.warn("Time steps more than 10000! Rodas breaks. Try input a smaller tspan!")
done = True
Expand All @@ -295,6 +309,10 @@ def Rodas(dae: nDAE,

T = T[0:nt + 1]
y = y[0:nt + 1]

if opt.pbar:
pbar.close()

if haveEvent:
te = te[0:nevent + 1]
ye = ye[0:nevent + 1]
Expand Down
10 changes: 6 additions & 4 deletions Solverz/solvers/daesolver/trapezoidal.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,23 @@ def implicit_trapezoid(dae: nDAE,
nt = 0
tt = T_initial
t0 = tt
Nt = int((T_end-T_initial)/dt) + 1

y = np.zeros((10000, y0.shape[0]))
y = np.zeros((Nt, y0.shape[0]))
y0 = DaeIc(dae, y0, t0, opt.rtol) # check and modify initial values
y[0, :] = y0
T = np.zeros((10000,))
T = np.zeros((Nt,))
T[0] = t0

p = dae.p
while abs(tt - T_end) > abs(dt) / 10:
while T_end - tt > abs(dt) / 10:
My0 = dae.M @ y0
F0 = dae.F(t0, y0, p).copy()
ae = nAE(lambda y_, p_: dt / 2 * (dae.F(t0 + dt, y_, p_) + F0) - dae.M @ y_ + My0,
lambda y_, p_: -dae.M + dt / 2 * dae.J(t0 + dt, y_, p_),
p)

sol = nr_method(ae, y0, Opt(stats=True))
sol = nr_method(ae, y0, Opt(stats=True, ite_tol=opt.ite_tol))
y1 = sol.y
stats.ndecomp = stats.ndecomp + sol.stats.nstep
stats.nfeval = stats.nfeval + sol.stats.nstep
Expand Down
1 change: 1 addition & 0 deletions Solverz/solvers/daesolver/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from numpy import abs, linalg
from tqdm import tqdm

from Solverz.equation.equations import DAE
from Solverz.num_api.num_eqn import nDAE, nAE
Expand Down
2 changes: 1 addition & 1 deletion Solverz/solvers/fdesolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def fdae_solver(fdae: nFDAE,
T = np.zeros((nstep,))
T[0] = t0
if opt.pbar:
bar = tqdm.tqdm(total=tend)
bar = tqdm.tqdm(total=tend-t0)

done = False
p = fdae.p
Expand Down
3 changes: 1 addition & 2 deletions Solverz/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def __getitem__(self, item):

elif is_number(item):
item = int(item)
if self.ie is not None:
return daesol(self.T[item], self.Y[item])
return daesol(self.T[item], self.Y[item])
else:
raise NotImplementedError(f"Index type {type(item)} not implemented!")

Expand Down
23 changes: 23 additions & 0 deletions Solverz/utilities/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from typing import Dict
from copy import deepcopy
from .type_checker import is_integer


def combine_Address(a1: Address, a2: Address) -> Address:
Expand Down Expand Up @@ -56,6 +57,28 @@ def update_v_cache(self):
address_list.append(np.arange(start, start + self.length_array[idx], dtype=int))
self.v_cache = dict(zip(self.object_list, address_list))

def inquiry_eqn_name(self, addr: int):
"""
Given eqn address (number), find the equation name.
"""
if not is_integer(addr):
raise ValueError(f"Address should be integer but {addr}!")
if addr < 0:
raise ValueError(f"No negative address allowed!")

current_sum = -1 # The address should start from 0
for i, value in enumerate(self.length_array):
current_sum += value
if current_sum >= addr:
break
if addr > current_sum:
raise ValueError(f"Input address bigger than maximum address {current_sum}!")
eqn_name = self.object_list[i]
if addr in self.v[eqn_name].tolist():
return eqn_name
else:
raise ValueError(f"How could this happen?")

@property
def v(self) -> Dict[str, np.ndarray]:
return self.v_cache
Expand Down
11 changes: 10 additions & 1 deletion Solverz/utilities/test/test_typechecker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import numpy as np
from Solverz.utilities.type_checker import is_integer
from sympy import Integer, Float
from Solverz.utilities.type_checker import is_integer, is_zero


def test_is_number():
assert is_integer(np.array([1.0]).astype(int)[0])
assert not is_integer(np.array([1.0])[0])

def test_is_zero():
assert is_zero(Integer(0))
assert not is_zero(Float(1))
assert is_zero(Float(0))
assert is_zero(0)
assert is_zero(0.)
assert is_zero(np.array(0))
6 changes: 5 additions & 1 deletion Solverz/utilities/type_checker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sympy import Number as SymNumber, Integer
from sympy import Float
import numpy as np
from numpy import integer as NpInteger
from numbers import Number as PyNumber
Expand Down Expand Up @@ -36,4 +37,7 @@ def is_scalar(a):


def is_zero(a):
return a == 0
if is_number(a):
return Float(a) == Float(0)
else:
return False
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ version_file = "Solverz/_version.py"
name = "Solverz"
dynamic = ["version"]
dependencies = [
"sympy>=1.11.1",
"sympy>=1.13.0",
"numba >= 0.58.1",
"numpy>=2.0.0",
"scipy>=1.14.0",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy>=2.0.0
sympy>=1.11.1
sympy>=1.13.0
pandas>=1.4.2
openpyxl>=3.0.10
matplotlib == 3.9.0
Expand Down

0 comments on commit d919b34

Please sign in to comment.