From 512c0e17fe0f4878a9e5bc7806a299df19142e18 Mon Sep 17 00:00:00 2001 From: rzyu45 Date: Sun, 4 Aug 2024 23:24:48 +0800 Subject: [PATCH] fix: resolve #79 --- .../python/inline/tests/test_inline_printer.py | 10 +++++----- Solverz/equation/hvp.py | 4 +++- Solverz/equation/jac.py | 8 ++++---- Solverz/equation/test/test_jac.py | 12 ++++++------ 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/Solverz/code_printer/python/inline/tests/test_inline_printer.py b/Solverz/code_printer/python/inline/tests/test_inline_printer.py index 64afd16..b01025d 100644 --- a/Solverz/code_printer/python/inline/tests/test_inline_printer.py +++ b/Solverz/code_printer/python/inline/tests/test_inline_printer.py @@ -45,7 +45,7 @@ def test_jb_printer_scalar_var_scalar_deri(): assert symJb[2] == extend(data, iVar('y') * Ones(3)) symJb = print_J_block(jb, False) assert symJb[0] == AddAugmentedAssignment( - J_[0:3, 1:2], iVar('y') * Ones(3)) + J_[0:3, 1], iVar('y') * Ones(3)) def test_jb_printer_vector_var_vector_deri(): @@ -96,7 +96,7 @@ def test_jbs_printer(): assert symJbs[4] == extend(col, SolList(7, 8, 9, 10, 11, 12, 13, 14, 15)) assert symJbs[5] == extend(data, y ** 2) symJbs = print_J_blocks(jac, False) - assert symJbs[0] == AddAugmentedAssignment(J_[0:3, 1:2], y * Ones(3)) + assert symJbs[0] == AddAugmentedAssignment(J_[0:3, 1], y * Ones(3)) assert symJbs[1] == AddAugmentedAssignment(J_[3:12, 7:16], Diag(y ** 2)) @@ -105,7 +105,7 @@ def test_jbs_printer(): v = y_[1:2] g = p_["g"] J_ = zeros((2, 2)) - J_[0:1,1:2] += ones(1) + J_[0:1,1] += ones(1) return J_ """.strip() @@ -178,8 +178,8 @@ def test_print_F_J(): J_ = zeros((6, 6)) J_[0:2,1:3] += diagflat(ones(2)) J_[2:4,0:2] += diagflat(-ones(2)) - J_[4:5,2:3] += -ones(1) - J_[5:6,2:3] += ones(1) + J_[4:5,2] += -ones(1) + J_[5:6,2] += ones(1) return J_ """.strip() diff --git a/Solverz/equation/hvp.py b/Solverz/equation/hvp.py index 9395b43..b29980f 100644 --- a/Solverz/equation/hvp.py +++ b/Solverz/equation/hvp.py @@ -74,7 +74,9 @@ def __init__(self, jac: Jac) -> None: self.blocks_sorted = self.jac1.blocks_sorted -def parse_den_var_addr(den_var_addr: slice): +def parse_den_var_addr(den_var_addr: slice | int): + if isinstance(den_var_addr, int): + den_var_addr = slice(den_var_addr, den_var_addr + 1) if den_var_addr.stop - den_var_addr.start == 1: return den_var_addr.start else: diff --git a/Solverz/equation/jac.py b/Solverz/equation/jac.py index f83d62c..66878f4 100644 --- a/Solverz/equation/jac.py +++ b/Solverz/equation/jac.py @@ -112,7 +112,7 @@ def __init__(self, self.SpEleSize = 0 self.SpDeriExpr: Expr = Integer(0) self.DenEqnAddr: slice = slice(0) - self.DenVarAddr: slice = slice(0) + self.DenVarAddr: slice | int = slice(0) self.DenDeriExpr: Expr = Integer(0) EqnSize = self.EqnAddr.stop - self.EqnAddr.start @@ -271,17 +271,17 @@ def ParseDen(self): case 'vector' | 'scalar': self.DenEqnAddr = self.EqnAddr if isinstance(self.DiffVar, iVar): - self.DenVarAddr = self.VarAddr + self.DenVarAddr = self.VarAddr.start else: if isinstance(self.DiffVar.index, slice): VarArange = slice2array(self.VarAddr)[self.DiffVar.index] if VarArange.size > 1: raise ValueError(f"Length of scalar variable {self.DiffVar} > 1!") else: - self.DenVarAddr = slice(VarArange[0], VarArange[-1] + 1) + self.DenVarAddr = VarArange[0] elif is_integer(self.DiffVar.index): idx = int(slice2array(self.VarAddr)[self.DiffVar.index]) - self.DenVarAddr = slice(idx, idx + 1) + self.DenVarAddr = idx else: raise TypeError(f"Index type {type(self.DiffVar.index)} not supported!") self.DenDeriExpr = self.DeriExprBc diff --git a/Solverz/equation/test/test_jac.py b/Solverz/equation/test/test_jac.py index 454302f..3768300 100644 --- a/Solverz/equation/test/test_jac.py +++ b/Solverz/equation/test/test_jac.py @@ -71,7 +71,7 @@ def test_jb_scalar_var_scalar_deri(): np.array([1])) assert jb.DenEqnAddr == slice(0, 3) - assert jb.DenVarAddr == slice(1, 2) + assert jb.DenVarAddr == 1 assert jb.DenDeriExpr == iVar('y') * Ones(3) assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2])) assert_allclose(jb.SpVarAddr, np.array([1, 1, 1])) @@ -89,7 +89,7 @@ def test_jb_scalar_var_scalar_deri(): np.array([1])) assert jb.DenEqnAddr == slice(0, 3) - assert jb.DenVarAddr == slice(2, 3) + assert jb.DenVarAddr == 2 assert jb.DenDeriExpr == iVar('y') * Ones(3) assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2])) assert_allclose(jb.SpVarAddr, np.array([2, 2, 2])) @@ -107,7 +107,7 @@ def test_jb_scalar_var_scalar_deri(): np.array([1])) assert jb.DenEqnAddr == slice(0, 3) - assert jb.DenVarAddr == slice(2, 3) + assert jb.DenVarAddr == 2 assert jb.DenDeriExpr == iVar('y') * Ones(3) assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2])) assert_allclose(jb.SpVarAddr, np.array([2, 2, 2])) @@ -128,7 +128,7 @@ def test_jb_scalar_var_vector_deri(): np.array([1, 1, 1])) assert jb.DenEqnAddr == slice(0, 3) - assert jb.DenVarAddr == slice(1, 2) + assert jb.DenVarAddr == 1 assert jb.DenDeriExpr == iVar('y') assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2])) assert_allclose(jb.SpVarAddr, np.array([1, 1, 1])) @@ -146,7 +146,7 @@ def test_jb_scalar_var_vector_deri(): np.array([1, 1, 1])) assert jb.DenEqnAddr == slice(0, 3) - assert jb.DenVarAddr == slice(2, 3) + assert jb.DenVarAddr == 2 assert jb.DenDeriExpr == iVar('y') assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2])) assert_allclose(jb.SpVarAddr, np.array([2, 2, 2])) @@ -164,7 +164,7 @@ def test_jb_scalar_var_vector_deri(): np.array([1, 1, 1])) assert jb.DenEqnAddr == slice(0, 3) - assert jb.DenVarAddr == slice(2, 3) + assert jb.DenVarAddr == 2 assert jb.DenDeriExpr == iVar('y') assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2])) assert_allclose(jb.SpVarAddr, np.array([2, 2, 2]))