Skip to content

Commit

Permalink
Merge pull request #87 from rzyu45/Derive-Hessian-vector-product
Browse files Browse the repository at this point in the history
Derive hessian vector product
  • Loading branch information
rzyu45 authored Aug 6, 2024
2 parents 9b4853e + d9780d8 commit a453828
Show file tree
Hide file tree
Showing 44 changed files with 1,752 additions and 465 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/cookbook-practice.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: Run tests in cookbook

on: [push]
on:
push:
pull_request:
types: [opened, synchronize, reopened, closed]
branches:
- '*'

jobs:
tests_in_cookbook:
Expand Down
7 changes: 6 additions & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: Run tests

on: [push]
on:
push:
pull_request:
types: [opened, synchronize, reopened, closed]
branches:
- '*'

jobs:
built_in_tests:
Expand Down
3 changes: 2 additions & 1 deletion Solverz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from Solverz.equation.equations import AE, FDAE, DAE
from Solverz.equation.param import Param, IdxParam, TimeSeriesParam
from Solverz.sym_algebra.symbols import idx, Para, iVar, iAliasVar
from Solverz.sym_algebra.functions import Sign, Abs, transpose, exp, Diag, Mat_Mul, sin, cos, Min, AntiWindUp, Saturation
from Solverz.sym_algebra.functions import (Sign, Abs, transpose, exp, Diag, Mat_Mul, sin, cos, Min, AntiWindUp,
Saturation, heaviside, ln)
from Solverz.num_api.custom_function import minmod_flag, minmod
from Solverz.variable.variables import Vars, TimeVars, as_Vars
from Solverz.solvers import *
Expand Down
7 changes: 5 additions & 2 deletions Solverz/code_printer/make_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(self,
variables: Vars | List[Vars],
name: str,
lang='python',
make_hvp=False,
directory=None,
jit=False):
self.name = name
Expand All @@ -20,6 +21,7 @@ def __init__(self,
self.variables = [variables]
else:
self.variables = variables
self.make_hvp = make_hvp
self.directory = directory
self.jit = jit

Expand All @@ -29,6 +31,7 @@ def render(self):
*self.variables,
name=self.name,
directory=self.directory,
numba=self.jit)
numba=self.jit,
make_hvp=self.make_hvp)
else:
raise NotImplemented(f"{self.lang} module renderer not implemented!")
raise NotImplemented(f"{self.lang} module renderer not implemented!")
69 changes: 62 additions & 7 deletions Solverz/code_printer/python/inline/inline_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def print_J(eqs_type: str,
body.extend(param_decla)
body.extend(print_trigger(PARAM))
if not sparse:
body.append(Assignment(temp, zeros(EqnAddr.total_size, var_addr.total_size)))
body.append(Assignment(temp, zeros(
EqnAddr.total_size, var_addr.total_size)))
body.extend(print_J_blocks(jac, False))
body.append(Return(temp))
else:
Expand All @@ -67,7 +68,7 @@ def print_J(eqs_type: str,

def print_J_blocks(jac: Jac, sparse: bool):
eqn_declaration = []
for eqn_name, jbs_row in jac.blocks.items():
for eqn_name, jbs_row in jac.blocks_sorted.items():
for var, jb in jbs_row.items():
eqn_declaration.extend(print_J_block(jb,
sparse))
Expand All @@ -82,17 +83,55 @@ def print_J_block(jb: JacBlock, sparse: bool) -> List:
# extend(iVar('row', internal_use=True), eqn_address),
# extend(iVar('col', internal_use=True), var_address),
# extend(iVar('data', internal_use=True), data)]
raise NotImplementedError("Matrix parameters in sparse Jac not implemented yet!")
raise NotImplementedError(
"Matrix parameters in sparse Jac not implemented yet!")
case 'vector' | 'scalar':
return [extend(iVar('row', internal_use=True), SolList(*jb.SpEqnAddr.tolist())),
extend(iVar('col', internal_use=True), SolList(*jb.SpVarAddr.tolist())),
extend(iVar('col', internal_use=True),
SolList(*jb.SpVarAddr.tolist())),
extend(iVar('data', internal_use=True), jb.SpDeriExpr)]
else:
return [AddAugmentedAssignment(iVar('J_', internal_use=True)[jb.DenEqnAddr, jb.DenVarAddr],
jb.DenDeriExpr)]


def made_numerical(eqs: SymEquations, y: Vars, sparse=False, output_code=False):
def print_Hvp(eqs_type: str,
hvp: Hvp,
EqnAddr: Address,
var_addr: Address,
PARAM: Dict[str, ParamBase],
nstep: int = 0,
sparse=True):
fp = print_Hvp_prototype(eqs_type,
nstep=nstep)
# initialize temp
temp = iVar('Hvp_', internal_use=True)
body = list()
body.extend(print_var(var_addr,
nstep)[0])
param_decla = print_param(PARAM)[0]
body.extend(param_decla)
body.extend(print_trigger(PARAM))
if not sparse:
body.append(Assignment(temp, zeros(
EqnAddr.total_size, var_addr.total_size)))
body.extend(print_J_blocks(hvp, False))
body.append(Return(temp))
else:
body.extend([Assignment(iVar('row', internal_use=True), SolList()),
Assignment(iVar('col', internal_use=True), SolList()),
Assignment(iVar('data', internal_use=True), SolList())])
body.extend(print_J_blocks(hvp, True))
body.append(Return(coo_2_csc(EqnAddr.total_size, var_addr.total_size)))
Jd = FunctionDefinition.from_FunctionPrototype(fp, body)
return pycode(Jd, fully_qualified_modules=False)


def made_numerical(eqs: SymEquations,
y: Vars,
sparse=False,
output_code=False,
make_hvp=False):
"""
factory method of numerical equations
"""
Expand All @@ -111,11 +150,24 @@ def made_numerical(eqs: SymEquations, y: Vars, sparse=False, output_code=False):
eqs.PARAM,
eqs.nstep,
sparse)
code = {'F': code_F, 'J': code_J}
if make_hvp:
eqs.hvp = Hvp(eqs.jac)
code_HVP = print_Hvp(eqs.__class__.__name__,
eqs.hvp,
eqs.a,
eqs.var_address,
eqs.PARAM,
eqs.nstep,
sparse)
code['HVP'] = code_HVP
custom_func = dict()
custom_func.update(numerical_interface)
custom_func.update(parse_trigger_func(eqs.PARAM))
F = Solverzlambdify(code_F, 'F_', modules=[custom_func, 'numpy'])
J = Solverzlambdify(code_J, 'J_', modules=[custom_func, 'numpy'])
if make_hvp:
HVP = Solverzlambdify(code_HVP, 'Hvp_', modules=[custom_func, 'numpy'])
p = parse_p(eqs.PARAM)
print('Complete!')
if isinstance(eqs, SymAE) and not isinstance(eqs, SymFDAE):
Expand All @@ -126,8 +178,10 @@ def made_numerical(eqs: SymEquations, y: Vars, sparse=False, output_code=False):
num_eqn = nDAE(eqs.M, F, J, p)
else:
raise ValueError(f'Unknown equation type {type(eqs)}')
if make_hvp:
num_eqn.HVP = HVP
if output_code:
return num_eqn, {'F': code_F, 'J': code_J}
return num_eqn, code
else:
return num_eqn

Expand Down Expand Up @@ -216,7 +270,8 @@ class tolist(Function):
@classmethod
def eval(cls, *args):
if len(args) != 1:
raise ValueError(f"Solverz' tolist function accepts only one input.")
raise ValueError(
f"Solverz' tolist function accepts only one input.")

def _numpycode(self, printer, **kwargs):
return r'((' + printer._print(self.args[0]) + r').tolist())'
Expand Down
93 changes: 83 additions & 10 deletions Solverz/code_printer/python/inline/tests/test_inline_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
from Solverz.sym_algebra.symbols import iVar, idx, Para
from Solverz.variable.variables import combine_Vars, as_Vars
from Solverz.equation.jac import JacBlock, Ones, Jac
from Solverz.sym_algebra.functions import Diag
from Solverz.equation.hvp import Hvp
from Solverz.sym_algebra.functions import Diag, sin, cos, exp
from Solverz.code_printer.python.inline.inline_printer import print_J_block, extend, SolList, print_J_blocks, print_J, \
print_F, made_numerical
print_F, print_Hvp, made_numerical, Solverzlambdify
from Solverz.utilities.address import Address
from Solverz.num_api.custom_function import numerical_interface

# %%
row = iVar('row', internal_use=True)
Expand All @@ -41,7 +44,8 @@ def test_jb_printer_scalar_var_scalar_deri():
assert symJb[1] == extend(col, SolList(1, 1, 1))
assert symJb[2] == extend(data, iVar('y') * Ones(3))
symJb = print_J_block(jb, False)
assert symJb[0] == AddAugmentedAssignment(J_[0:3, 1:2], iVar('y') * Ones(3))
assert symJb[0] == AddAugmentedAssignment(
J_[0:3, 1], iVar('y') * Ones(3))


def test_jb_printer_vector_var_vector_deri():
Expand All @@ -57,7 +61,8 @@ def test_jb_printer_vector_var_vector_deri():
assert symJb[1] == extend(col, SolList(*np.arange(0, 9).tolist()))
assert symJb[2] == extend(data, iVar('y'))
symJb = print_J_block(jb, False)
assert symJb[0] == AddAugmentedAssignment(iVar('J_', internal_use=True)[1:10, 0:9], Diag(iVar('y')))
assert symJb[0] == AddAugmentedAssignment(
iVar('J_', internal_use=True)[1:10, 0:9], Diag(iVar('y')))


def test_jbs_printer():
Expand Down Expand Up @@ -91,7 +96,7 @@ def test_jbs_printer():
assert symJbs[4] == extend(col, SolList(7, 8, 9, 10, 11, 12, 13, 14, 15))
assert symJbs[5] == extend(data, y ** 2)
symJbs = print_J_blocks(jac, False)
assert symJbs[0] == AddAugmentedAssignment(J_[0:3, 1:2], y * Ones(3))
assert symJbs[0] == AddAugmentedAssignment(J_[0:3, 1], y * Ones(3))
assert symJbs[1] == AddAugmentedAssignment(J_[3:12, 7:16], Diag(y ** 2))


Expand All @@ -100,7 +105,7 @@ def test_jbs_printer():
v = y_[1:2]
g = p_["g"]
J_ = zeros((2, 2))
J_[0:1,1:2] += ones(1)
J_[0:1,1] += ones(1)
return J_
""".strip()

Expand Down Expand Up @@ -173,8 +178,8 @@ def test_print_F_J():
J_ = zeros((6, 6))
J_[0:2,1:3] += diagflat(ones(2))
J_[2:4,0:2] += diagflat(-ones(2))
J_[4:5,2:3] += -ones(1)
J_[5:6,2:3] += ones(1)
J_[4:5,2] += -ones(1)
J_[5:6,2] += ones(1)
return J_
""".strip()

Expand Down Expand Up @@ -263,5 +268,73 @@ def test_made_numerical():
nF, code = made_numerical(F, y, sparse=True, output_code=True)
F0 = nF.F(y, nF.p)
J0 = nF.J(y, nF.p)
np.testing.assert_allclose(F0, np.array([2 * 1 + 1, 1 + np.sin(1)]), rtol=1e-8)
np.testing.assert_allclose(J0.toarray(), np.array([[2, 1], [2, 0.54030231]]), rtol=1e-8)
np.testing.assert_allclose(F0, np.array(
[2 * 1 + 1, 1 + np.sin(1)]), rtol=1e-8)
np.testing.assert_allclose(J0.toarray(), np.array(
[[2, 1], [2, 0.54030231]]), rtol=1e-8)


def test_hvp_printer():
jac = Jac()
x = iVar("x")
jac.add_block(
"a",
x[0],
JacBlock(
"a",
slice(0, 1),
x[0],
np.array([1]),
slice(0, 2),
exp(x[0]),
np.array([2.71828183]),
),
)
jac.add_block(
"a",
x[1],
JacBlock(
"a",
slice(0, 1),
x[1],
np.array([1]),
slice(0, 2),
cos(x[1]),
np.array([0.54030231]),
),
)
jac.add_block(
"b",
x[0],
JacBlock("b",
slice(1, 2),
x[0],
np.ones(1),
slice(0, 2),
1,
np.array([1])),
)
jac.add_block(
"b",
x[1],
JacBlock("b", slice(1, 2), x[1], np.ones(1),
slice(0, 2), 2 * x[1], np.array([2])),
)

h = Hvp(jac)
eqn_addr = Address()
eqn_addr.add('a', 1)
eqn_addr.add('b', 1)
var_addr = Address()
var_addr.add('x', 2)

code = print_Hvp('AE',
h,
eqn_addr,
var_addr,
dict())

HVP = Solverzlambdify(code, 'Hvp_', [numerical_interface, 'numpy'])
np.testing.assert_allclose(HVP(np.array([1, 2]), dict(), np.array([1, 1])).toarray(),
np.array([[2.71828183, -0.90929743],
[0., 2.]]))
Loading

0 comments on commit a453828

Please sign in to comment.