From d919b341923563221e72ce535a0bd079616adb35 Mon Sep 17 00:00:00 2001 From: Ruizhi Yu Date: Wed, 9 Oct 2024 21:23:40 +0800 Subject: [PATCH] progress bar and bug fix (#97) * feat: progress bar of Rodas * feat: progress bar of Rodas * fix: resolve #96 * fix: update type checker --- Solverz/equation/eqn.py | 17 +++++++++------- Solverz/equation/jac.py | 5 ++++- Solverz/equation/test/test_eqn.py | 10 ++++++++++ Solverz/equation/test/test_jac.py | 11 +++++++++++ Solverz/solvers/daesolver/beuler.py | 14 +++++++------ Solverz/solvers/daesolver/rodas/rodas.py | 18 +++++++++++++++++ Solverz/solvers/daesolver/trapezoidal.py | 10 ++++++---- Solverz/solvers/daesolver/utilities.py | 1 + Solverz/solvers/fdesolver.py | 2 +- Solverz/solvers/solution.py | 3 +-- Solverz/utilities/address.py | 23 ++++++++++++++++++++++ Solverz/utilities/test/test_typechecker.py | 11 ++++++++++- Solverz/utilities/type_checker.py | 6 +++++- pyproject.toml | 2 +- requirements.txt | 2 +- 15 files changed, 110 insertions(+), 25 deletions(-) diff --git a/Solverz/equation/eqn.py b/Solverz/equation/eqn.py index b450c6c..640b99c 100644 --- a/Solverz/equation/eqn.py +++ b/Solverz/equation/eqn.py @@ -14,6 +14,7 @@ from Solverz.sym_algebra.transform import finite_difference, semi_descritize from Solverz.num_api.custom_function import numerical_interface from Solverz.variable.ssymbol import sSym2Sym +from Solverz.utilities.type_checker import is_zero def sVar2Var(var: Union[Var, iVar, List[iVar, Var]]) -> Union[iVar, List[iVar]]: @@ -76,18 +77,20 @@ def derive_derivative(self): diff = MixedEquationDiff(self.RHS, symbol_) else: diff = self.RHS.diff(symbol_) - self.derivatives[symbol_.name] = EqnDiff(name=f'Diff {self.name} w.r.t. {symbol_.name}', - eqn=diff, - diff_var=symbol_, - var_idx=idx_.name if isinstance(idx_, idx) else idx_) + if not is_zero(diff): + self.derivatives[symbol_.name] = EqnDiff(name=f'Diff {self.name} w.r.t. {symbol_.name}', + eqn=diff, + diff_var=symbol_, + var_idx=idx_.name if isinstance(idx_, idx) else idx_) elif isinstance(symbol_, iVar): if self.mixed_matrix_vector: diff = MixedEquationDiff(self.RHS, symbol_) else: diff = self.RHS.diff(symbol_) - self.derivatives[symbol_.name] = EqnDiff(name=f'Diff {self.name} w.r.t. {symbol_.name}', - eqn=diff, - diff_var=symbol_) + if not is_zero(diff): + self.derivatives[symbol_.name] = EqnDiff(name=f'Diff {self.name} w.r.t. {symbol_.name}', + eqn=diff, + diff_var=symbol_) @property def expr(self): diff --git a/Solverz/equation/jac.py b/Solverz/equation/jac.py index 66878f4..5ff84ba 100644 --- a/Solverz/equation/jac.py +++ b/Solverz/equation/jac.py @@ -5,7 +5,7 @@ from sympy import Expr, Function, Integer from Solverz.sym_algebra.symbols import iVar, IdxVar -from Solverz.utilities.type_checker import is_vector, is_scalar, is_integer, is_number, PyNumber +from Solverz.utilities.type_checker import is_vector, is_scalar, is_integer, is_number, PyNumber, is_zero from Solverz.sym_algebra.functions import Diag SolVar = Union[iVar, IdxVar] @@ -99,6 +99,9 @@ def __init__(self, 1. a diagonal matrix 2. a matrix """ + if is_zero(DeriExpr): + raise ValueError(f"We wont allow {DeriExpr} derivative!") + self.EqnName = EqnName self.EqnAddr: slice = EqnAddr self.DiffVar: SolVar = DiffVar diff --git a/Solverz/equation/test/test_eqn.py b/Solverz/equation/test/test_eqn.py index 6032123..bb21504 100644 --- a/Solverz/equation/test/test_eqn.py +++ b/Solverz/equation/test/test_eqn.py @@ -9,6 +9,7 @@ from Solverz.equation.eqn import Eqn, Ode from Solverz.equation.equations import DAE, AE from Solverz.variable.variables import combine_Vars, as_Vars +from Solverz import AntiWindUp, Param, Var x = iVar('x', value=[1]) f = Eqn('f', x - 1) @@ -20,3 +21,12 @@ f1 = f.NUM_EQN(np.array([3, 4])) assert isinstance(f1, np.ndarray) assert f1.ndim == 1 + +# discard zero derivative +u = Var('u', [1, 1, 1]) +e = Var('e', [1, 1, 1]) +umin = Param('umin', [0,0,0]) +umax = Param('umax', [5,5,5]) +F = Ode('Anti', AntiWindUp(u, umin, umax, e), u) +F.derive_derivative() +assert 'u' not in F.derivatives diff --git a/Solverz/equation/test/test_jac.py b/Solverz/equation/test/test_jac.py index 3768300..4a400b9 100644 --- a/Solverz/equation/test/test_jac.py +++ b/Solverz/equation/test/test_jac.py @@ -294,6 +294,17 @@ def test_jb_vector_var_scalar_deri(): iVar('y'), np.array([1])) +def test_jb_vector_var_zero_deri(): + + with pytest.raises(ValueError, + match=re.escape("We wont allow 0.0 derivative!")): + jb = JacBlock('a', + slice(0, 3), + iVar('x'), + np.ones(3), + slice(1, 4), + 0., + np.array([0])) # %% vector var and vector derivative def test_jb_vector_var_vector_deri(): diff --git a/Solverz/solvers/daesolver/beuler.py b/Solverz/solvers/daesolver/beuler.py index 3516ef0..61bfa86 100644 --- a/Solverz/solvers/daesolver/beuler.py +++ b/Solverz/solvers/daesolver/beuler.py @@ -51,13 +51,16 @@ def backward_euler(dae: nDAE, nt = 0 tt = T_initial t0 = tt + Nt = int((T_end-T_initial)/dt) + 1 - y = np.zeros((10000, y0.shape[0])) + y = np.zeros((Nt, y0.shape[0])) + y0 = DaeIc(dae, y0, t0, opt.rtol) # check and modify initial values y[0, :] = y0 - T = np.zeros((10000,)) + T = np.zeros((Nt,)) + T[0] = t0 p = dae.p - while abs(tt - T_end) > abs(dt) / 10: + while T_end - tt > abs(dt) / 10: My0 = dae.M @ y0 ae = nAE(lambda y_, p_: dae.M @ y_ - My0 - dt * dae.F(t0 + dt, y_, p_), lambda y_, p_: dae.M - dt * dae.J(t0 + dt, y_, p_), @@ -65,9 +68,8 @@ def backward_euler(dae: nDAE, sol = nr_method(ae, y0, Opt(stats=True)) y1 = sol.y - ite = sol.stats - stats.ndecomp = stats.ndecomp + ite - stats.nfeval = stats.nfeval + ite + stats.ndecomp = stats.ndecomp + sol.stats.nstep + stats.nfeval = stats.nfeval + sol.stats.nstep tt = tt + dt nt = nt + 1 diff --git a/Solverz/solvers/daesolver/rodas/rodas.py b/Solverz/solvers/daesolver/rodas/rodas.py index 22303ff..2e4c9eb 100644 --- a/Solverz/solvers/daesolver/rodas/rodas.py +++ b/Solverz/solvers/daesolver/rodas/rodas.py @@ -44,6 +44,8 @@ def Rodas(dae: nDAE, Initial step size - hmax: None(default)|float Maximum step size. + - pbar: False(default)|bool + To display progress bar Returns ======= @@ -69,6 +71,8 @@ def Rodas(dae: nDAE, tspan = np.array(tspan) tend = tspan[-1] t0 = tspan[0] + if t0 > tend: + raise ValueError(f't0: {t0} > tend: {tend}') if opt.hmax is None: opt.hmax = np.abs(tend - t0) nt = 0 @@ -81,6 +85,9 @@ def Rodas(dae: nDAE, y0 = DaeIc(dae, y0, t0, opt.rtol) # check and modify initial values y[0, :] = y0 + if opt.pbar: + pbar = tqdm(total=tend - t0) + dense_output = False n_tspan = len(tspan) told = t0 @@ -185,6 +192,7 @@ def Rodas(dae: nDAE, reject = 0 told = t t = t + dt + stats.nstep = stats.nstep + 1 # events if haveEvent: @@ -261,6 +269,9 @@ def Rodas(dae: nDAE, T[nt] = tnext y[nt] = ynext + if opt.pbar: + pbar.update(T[nt] - T[nt - 1]) + if haveEvent and stop: if tnext >= tevent: break @@ -278,6 +289,9 @@ def Rodas(dae: nDAE, T[nt] = t y[nt] = ynew + if opt.pbar: + pbar.update(T[nt] - T[nt - 1]) + if nt == 10000: warnings.warn("Time steps more than 10000! Rodas breaks. Try input a smaller tspan!") done = True @@ -295,6 +309,10 @@ def Rodas(dae: nDAE, T = T[0:nt + 1] y = y[0:nt + 1] + + if opt.pbar: + pbar.close() + if haveEvent: te = te[0:nevent + 1] ye = ye[0:nevent + 1] diff --git a/Solverz/solvers/daesolver/trapezoidal.py b/Solverz/solvers/daesolver/trapezoidal.py index 5333058..5aee9d9 100644 --- a/Solverz/solvers/daesolver/trapezoidal.py +++ b/Solverz/solvers/daesolver/trapezoidal.py @@ -50,21 +50,23 @@ def implicit_trapezoid(dae: nDAE, nt = 0 tt = T_initial t0 = tt + Nt = int((T_end-T_initial)/dt) + 1 - y = np.zeros((10000, y0.shape[0])) + y = np.zeros((Nt, y0.shape[0])) y0 = DaeIc(dae, y0, t0, opt.rtol) # check and modify initial values y[0, :] = y0 - T = np.zeros((10000,)) + T = np.zeros((Nt,)) + T[0] = t0 p = dae.p - while abs(tt - T_end) > abs(dt) / 10: + while T_end - tt > abs(dt) / 10: My0 = dae.M @ y0 F0 = dae.F(t0, y0, p).copy() ae = nAE(lambda y_, p_: dt / 2 * (dae.F(t0 + dt, y_, p_) + F0) - dae.M @ y_ + My0, lambda y_, p_: -dae.M + dt / 2 * dae.J(t0 + dt, y_, p_), p) - sol = nr_method(ae, y0, Opt(stats=True)) + sol = nr_method(ae, y0, Opt(stats=True, ite_tol=opt.ite_tol)) y1 = sol.y stats.ndecomp = stats.ndecomp + sol.stats.nstep stats.nfeval = stats.nfeval + sol.stats.nstep diff --git a/Solverz/solvers/daesolver/utilities.py b/Solverz/solvers/daesolver/utilities.py index f68341a..8777547 100644 --- a/Solverz/solvers/daesolver/utilities.py +++ b/Solverz/solvers/daesolver/utilities.py @@ -4,6 +4,7 @@ import numpy as np from numpy import abs, linalg +from tqdm import tqdm from Solverz.equation.equations import DAE from Solverz.num_api.num_eqn import nDAE, nAE diff --git a/Solverz/solvers/fdesolver.py b/Solverz/solvers/fdesolver.py index 0a065c4..a890686 100644 --- a/Solverz/solvers/fdesolver.py +++ b/Solverz/solvers/fdesolver.py @@ -67,7 +67,7 @@ def fdae_solver(fdae: nFDAE, T = np.zeros((nstep,)) T[0] = t0 if opt.pbar: - bar = tqdm.tqdm(total=tend) + bar = tqdm.tqdm(total=tend-t0) done = False p = fdae.p diff --git a/Solverz/solvers/solution.py b/Solverz/solvers/solution.py index 90330a1..0a8f1d5 100644 --- a/Solverz/solvers/solution.py +++ b/Solverz/solvers/solution.py @@ -54,8 +54,7 @@ def __getitem__(self, item): elif is_number(item): item = int(item) - if self.ie is not None: - return daesol(self.T[item], self.Y[item]) + return daesol(self.T[item], self.Y[item]) else: raise NotImplementedError(f"Index type {type(item)} not implemented!") diff --git a/Solverz/utilities/address.py b/Solverz/utilities/address.py index 83c0e19..7dba5ed 100644 --- a/Solverz/utilities/address.py +++ b/Solverz/utilities/address.py @@ -3,6 +3,7 @@ import numpy as np from typing import Dict from copy import deepcopy +from .type_checker import is_integer def combine_Address(a1: Address, a2: Address) -> Address: @@ -56,6 +57,28 @@ def update_v_cache(self): address_list.append(np.arange(start, start + self.length_array[idx], dtype=int)) self.v_cache = dict(zip(self.object_list, address_list)) + def inquiry_eqn_name(self, addr: int): + """ + Given eqn address (number), find the equation name. + """ + if not is_integer(addr): + raise ValueError(f"Address should be integer but {addr}!") + if addr < 0: + raise ValueError(f"No negative address allowed!") + + current_sum = -1 # The address should start from 0 + for i, value in enumerate(self.length_array): + current_sum += value + if current_sum >= addr: + break + if addr > current_sum: + raise ValueError(f"Input address bigger than maximum address {current_sum}!") + eqn_name = self.object_list[i] + if addr in self.v[eqn_name].tolist(): + return eqn_name + else: + raise ValueError(f"How could this happen?") + @property def v(self) -> Dict[str, np.ndarray]: return self.v_cache diff --git a/Solverz/utilities/test/test_typechecker.py b/Solverz/utilities/test/test_typechecker.py index 6487d07..f9ff8d1 100644 --- a/Solverz/utilities/test/test_typechecker.py +++ b/Solverz/utilities/test/test_typechecker.py @@ -1,7 +1,16 @@ import numpy as np -from Solverz.utilities.type_checker import is_integer +from sympy import Integer, Float +from Solverz.utilities.type_checker import is_integer, is_zero def test_is_number(): assert is_integer(np.array([1.0]).astype(int)[0]) assert not is_integer(np.array([1.0])[0]) + +def test_is_zero(): + assert is_zero(Integer(0)) + assert not is_zero(Float(1)) + assert is_zero(Float(0)) + assert is_zero(0) + assert is_zero(0.) + assert is_zero(np.array(0)) diff --git a/Solverz/utilities/type_checker.py b/Solverz/utilities/type_checker.py index 74ecfcb..2d932b4 100644 --- a/Solverz/utilities/type_checker.py +++ b/Solverz/utilities/type_checker.py @@ -1,4 +1,5 @@ from sympy import Number as SymNumber, Integer +from sympy import Float import numpy as np from numpy import integer as NpInteger from numbers import Number as PyNumber @@ -36,4 +37,7 @@ def is_scalar(a): def is_zero(a): - return a == 0 + if is_number(a): + return Float(a) == Float(0) + else: + return False diff --git a/pyproject.toml b/pyproject.toml index e852b30..4713733 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ version_file = "Solverz/_version.py" name = "Solverz" dynamic = ["version"] dependencies = [ - "sympy>=1.11.1", + "sympy>=1.13.0", "numba >= 0.58.1", "numpy>=2.0.0", "scipy>=1.14.0", diff --git a/requirements.txt b/requirements.txt index d23f44f..665c50d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy>=2.0.0 -sympy>=1.11.1 +sympy>=1.13.0 pandas>=1.4.2 openpyxl>=3.0.10 matplotlib == 3.9.0