Skip to content

Commit

Permalink
fix: resolve smallbunnies#79
Browse files Browse the repository at this point in the history
  • Loading branch information
rzyu45 committed Aug 4, 2024
1 parent 091b48d commit 512c0e1
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
10 changes: 5 additions & 5 deletions Solverz/code_printer/python/inline/tests/test_inline_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_jb_printer_scalar_var_scalar_deri():
assert symJb[2] == extend(data, iVar('y') * Ones(3))
symJb = print_J_block(jb, False)
assert symJb[0] == AddAugmentedAssignment(
J_[0:3, 1:2], iVar('y') * Ones(3))
J_[0:3, 1], iVar('y') * Ones(3))


def test_jb_printer_vector_var_vector_deri():
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_jbs_printer():
assert symJbs[4] == extend(col, SolList(7, 8, 9, 10, 11, 12, 13, 14, 15))
assert symJbs[5] == extend(data, y ** 2)
symJbs = print_J_blocks(jac, False)
assert symJbs[0] == AddAugmentedAssignment(J_[0:3, 1:2], y * Ones(3))
assert symJbs[0] == AddAugmentedAssignment(J_[0:3, 1], y * Ones(3))
assert symJbs[1] == AddAugmentedAssignment(J_[3:12, 7:16], Diag(y ** 2))


Expand All @@ -105,7 +105,7 @@ def test_jbs_printer():
v = y_[1:2]
g = p_["g"]
J_ = zeros((2, 2))
J_[0:1,1:2] += ones(1)
J_[0:1,1] += ones(1)
return J_
""".strip()

Expand Down Expand Up @@ -178,8 +178,8 @@ def test_print_F_J():
J_ = zeros((6, 6))
J_[0:2,1:3] += diagflat(ones(2))
J_[2:4,0:2] += diagflat(-ones(2))
J_[4:5,2:3] += -ones(1)
J_[5:6,2:3] += ones(1)
J_[4:5,2] += -ones(1)
J_[5:6,2] += ones(1)
return J_
""".strip()

Expand Down
4 changes: 3 additions & 1 deletion Solverz/equation/hvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def __init__(self, jac: Jac) -> None:
self.blocks_sorted = self.jac1.blocks_sorted


def parse_den_var_addr(den_var_addr: slice):
def parse_den_var_addr(den_var_addr: slice | int):
if isinstance(den_var_addr, int):
den_var_addr = slice(den_var_addr, den_var_addr + 1)
if den_var_addr.stop - den_var_addr.start == 1:
return den_var_addr.start
else:
Expand Down
8 changes: 4 additions & 4 deletions Solverz/equation/jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(self,
self.SpEleSize = 0
self.SpDeriExpr: Expr = Integer(0)
self.DenEqnAddr: slice = slice(0)
self.DenVarAddr: slice = slice(0)
self.DenVarAddr: slice | int = slice(0)
self.DenDeriExpr: Expr = Integer(0)

EqnSize = self.EqnAddr.stop - self.EqnAddr.start
Expand Down Expand Up @@ -271,17 +271,17 @@ def ParseDen(self):
case 'vector' | 'scalar':
self.DenEqnAddr = self.EqnAddr
if isinstance(self.DiffVar, iVar):
self.DenVarAddr = self.VarAddr
self.DenVarAddr = self.VarAddr.start
else:
if isinstance(self.DiffVar.index, slice):
VarArange = slice2array(self.VarAddr)[self.DiffVar.index]
if VarArange.size > 1:
raise ValueError(f"Length of scalar variable {self.DiffVar} > 1!")
else:
self.DenVarAddr = slice(VarArange[0], VarArange[-1] + 1)
self.DenVarAddr = VarArange[0]
elif is_integer(self.DiffVar.index):
idx = int(slice2array(self.VarAddr)[self.DiffVar.index])
self.DenVarAddr = slice(idx, idx + 1)
self.DenVarAddr = idx
else:
raise TypeError(f"Index type {type(self.DiffVar.index)} not supported!")
self.DenDeriExpr = self.DeriExprBc
Expand Down
12 changes: 6 additions & 6 deletions Solverz/equation/test/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_jb_scalar_var_scalar_deri():
np.array([1]))

assert jb.DenEqnAddr == slice(0, 3)
assert jb.DenVarAddr == slice(1, 2)
assert jb.DenVarAddr == 1
assert jb.DenDeriExpr == iVar('y') * Ones(3)
assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2]))
assert_allclose(jb.SpVarAddr, np.array([1, 1, 1]))
Expand All @@ -89,7 +89,7 @@ def test_jb_scalar_var_scalar_deri():
np.array([1]))

assert jb.DenEqnAddr == slice(0, 3)
assert jb.DenVarAddr == slice(2, 3)
assert jb.DenVarAddr == 2
assert jb.DenDeriExpr == iVar('y') * Ones(3)
assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2]))
assert_allclose(jb.SpVarAddr, np.array([2, 2, 2]))
Expand All @@ -107,7 +107,7 @@ def test_jb_scalar_var_scalar_deri():
np.array([1]))

assert jb.DenEqnAddr == slice(0, 3)
assert jb.DenVarAddr == slice(2, 3)
assert jb.DenVarAddr == 2
assert jb.DenDeriExpr == iVar('y') * Ones(3)
assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2]))
assert_allclose(jb.SpVarAddr, np.array([2, 2, 2]))
Expand All @@ -128,7 +128,7 @@ def test_jb_scalar_var_vector_deri():
np.array([1, 1, 1]))

assert jb.DenEqnAddr == slice(0, 3)
assert jb.DenVarAddr == slice(1, 2)
assert jb.DenVarAddr == 1
assert jb.DenDeriExpr == iVar('y')
assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2]))
assert_allclose(jb.SpVarAddr, np.array([1, 1, 1]))
Expand All @@ -146,7 +146,7 @@ def test_jb_scalar_var_vector_deri():
np.array([1, 1, 1]))

assert jb.DenEqnAddr == slice(0, 3)
assert jb.DenVarAddr == slice(2, 3)
assert jb.DenVarAddr == 2
assert jb.DenDeriExpr == iVar('y')
assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2]))
assert_allclose(jb.SpVarAddr, np.array([2, 2, 2]))
Expand All @@ -164,7 +164,7 @@ def test_jb_scalar_var_vector_deri():
np.array([1, 1, 1]))

assert jb.DenEqnAddr == slice(0, 3)
assert jb.DenVarAddr == slice(2, 3)
assert jb.DenVarAddr == 2
assert jb.DenDeriExpr == iVar('y')
assert_allclose(jb.SpEqnAddr, np.array([0, 1, 2]))
assert_allclose(jb.SpVarAddr, np.array([2, 2, 2]))
Expand Down

0 comments on commit 512c0e1

Please sign in to comment.