Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive hessian vector product #87

Merged
merged 34 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
c22fd55
To deploy docs only in the default repo
rzyu45 Jun 4, 2024
1956cb5
feat: add inline hvp
rzyu45 Jun 7, 2024
0d3d349
feat: fix hvp initializer
rzyu45 Jun 9, 2024
d99c4cc
test: add the hvp test.
rzyu45 Jun 11, 2024
b18f39f
feat: add the module printer of hvp
rzyu45 Jun 11, 2024
c7df3fc
test: add hvp module generator tests
rzyu45 Jun 12, 2024
eac2a5a
fix: to sort jac blocks in lecico order
rzyu45 Jun 12, 2024
21017a3
fix: simplify the `dae.M` property
rzyu45 Jun 12, 2024
f64fb12
feat: add the sicnm solver
rzyu45 Jun 15, 2024
5718bf9
fix: resolve typos in sicnm
rzyu45 Jun 17, 2024
d6f561a
fix: resolve typos in scinm
rzyu45 Jun 18, 2024
8f62513
fix: resolve sicnm typos
rzyu45 Jun 19, 2024
f2dd614
ENH: try spsolve_triangular() in scipy 1.14.0
rzyu45 Jun 25, 2024
7bdecf3
feat: add octave code printers
rzyu45 Jun 26, 2024
441a392
feat: add heaviside func
rzyu45 Jun 28, 2024
c213270
feat: add event detection in scinm
rzyu45 Jul 1, 2024
1cf2596
feat: add `ln`
rzyu45 Jul 1, 2024
b1af424
EHN: improve stats of sicnm and nr
rzyu45 Jul 1, 2024
d54a075
EHN: improve `lm.stats`
rzyu45 Jul 2, 2024
1008345
EHN: improve `cnr.stats`
rzyu45 Jul 2, 2024
3624495
feat: add __repr__ of `sol` class
rzyu45 Jul 4, 2024
a07b37d
Update doc-deploy.yml
rzyu45 Jul 31, 2024
b643dc4
Update doc-deploy.yml
rzyu45 Jul 31, 2024
8c0a478
update workflow trigger
rzyu45 Jul 31, 2024
4e852c6
Update doc-deploy.yml
rzyu45 Jul 31, 2024
6061539
fix: prepare Solverz for numpy 2.0
rzyu45 Aug 4, 2024
9c686b9
fix: resolve issues arose from scipy compabaility
rzyu45 Aug 4, 2024
00bdfd7
remove matplotlib requirement
rzyu45 Aug 4, 2024
3118275
fix: deprecate explicit call of matplotlib
rzyu45 Aug 4, 2024
5cec43e
Revert "remove matplotlib requirement"
rzyu45 Aug 4, 2024
091b48d
fix: update fdaesolver
rzyu45 Aug 4, 2024
512c0e1
fix: resolve #79
rzyu45 Aug 4, 2024
aed59d5
doc: add docstring of `sicnm()`
rzyu45 Aug 6, 2024
d9780d8
docs: add `HVP` in gettingstarted.md
rzyu45 Aug 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/cookbook-practice.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: Run tests in cookbook

on: [push]
on:
push:
pull_request:
types: [opened, synchronize, reopened, closed]
branches:
- '*'

jobs:
tests_in_cookbook:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/doc-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ on:

jobs:
build_and_deploy_job:
if: github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.action != 'closed')
if: github.repository == 'smallbunnies/Solverz' &&(github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.action != 'closed'))
runs-on: ubuntu-latest
name: Build and Deploy Job
steps:
Expand Down
7 changes: 6 additions & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: Run tests

on: [push]
on:
push:
pull_request:
types: [opened, synchronize, reopened, closed]
branches:
- '*'

jobs:
built_in_tests:
Expand Down
3 changes: 2 additions & 1 deletion Solverz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from Solverz.equation.equations import AE, FDAE, DAE
from Solverz.equation.param import Param, IdxParam, TimeSeriesParam
from Solverz.sym_algebra.symbols import idx, Para, iVar, iAliasVar
from Solverz.sym_algebra.functions import Sign, Abs, transpose, exp, Diag, Mat_Mul, sin, cos, Min, AntiWindUp, Saturation
from Solverz.sym_algebra.functions import (Sign, Abs, transpose, exp, Diag, Mat_Mul, sin, cos, Min, AntiWindUp,
Saturation, heaviside, ln)
from Solverz.num_api.custom_function import minmod_flag, minmod
from Solverz.variable.variables import Vars, TimeVars, as_Vars
from Solverz.solvers import *
Expand Down
7 changes: 5 additions & 2 deletions Solverz/code_printer/make_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(self,
variables: Vars | List[Vars],
name: str,
lang='python',
make_hvp=False,
directory=None,
jit=False):
self.name = name
Expand All @@ -20,6 +21,7 @@ def __init__(self,
self.variables = [variables]
else:
self.variables = variables
self.make_hvp = make_hvp
self.directory = directory
self.jit = jit

Expand All @@ -29,6 +31,7 @@ def render(self):
*self.variables,
name=self.name,
directory=self.directory,
numba=self.jit)
numba=self.jit,
make_hvp=self.make_hvp)
else:
raise NotImplemented(f"{self.lang} module renderer not implemented!")
raise NotImplemented(f"{self.lang} module renderer not implemented!")
69 changes: 62 additions & 7 deletions Solverz/code_printer/python/inline/inline_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def print_J(eqs_type: str,
body.extend(param_decla)
body.extend(print_trigger(PARAM))
if not sparse:
body.append(Assignment(temp, zeros(EqnAddr.total_size, var_addr.total_size)))
body.append(Assignment(temp, zeros(
EqnAddr.total_size, var_addr.total_size)))
body.extend(print_J_blocks(jac, False))
body.append(Return(temp))
else:
Expand All @@ -67,7 +68,7 @@ def print_J(eqs_type: str,

def print_J_blocks(jac: Jac, sparse: bool):
eqn_declaration = []
for eqn_name, jbs_row in jac.blocks.items():
for eqn_name, jbs_row in jac.blocks_sorted.items():
for var, jb in jbs_row.items():
eqn_declaration.extend(print_J_block(jb,
sparse))
Expand All @@ -82,17 +83,55 @@ def print_J_block(jb: JacBlock, sparse: bool) -> List:
# extend(iVar('row', internal_use=True), eqn_address),
# extend(iVar('col', internal_use=True), var_address),
# extend(iVar('data', internal_use=True), data)]
raise NotImplementedError("Matrix parameters in sparse Jac not implemented yet!")
raise NotImplementedError(
"Matrix parameters in sparse Jac not implemented yet!")
case 'vector' | 'scalar':
return [extend(iVar('row', internal_use=True), SolList(*jb.SpEqnAddr.tolist())),
extend(iVar('col', internal_use=True), SolList(*jb.SpVarAddr.tolist())),
extend(iVar('col', internal_use=True),
SolList(*jb.SpVarAddr.tolist())),
extend(iVar('data', internal_use=True), jb.SpDeriExpr)]
else:
return [AddAugmentedAssignment(iVar('J_', internal_use=True)[jb.DenEqnAddr, jb.DenVarAddr],
jb.DenDeriExpr)]


def made_numerical(eqs: SymEquations, y: Vars, sparse=False, output_code=False):
def print_Hvp(eqs_type: str,
hvp: Hvp,
EqnAddr: Address,
var_addr: Address,
PARAM: Dict[str, ParamBase],
nstep: int = 0,
sparse=True):
fp = print_Hvp_prototype(eqs_type,
nstep=nstep)
# initialize temp
temp = iVar('Hvp_', internal_use=True)
body = list()
body.extend(print_var(var_addr,
nstep)[0])
param_decla = print_param(PARAM)[0]
body.extend(param_decla)
body.extend(print_trigger(PARAM))
if not sparse:
body.append(Assignment(temp, zeros(
EqnAddr.total_size, var_addr.total_size)))
body.extend(print_J_blocks(hvp, False))
body.append(Return(temp))
else:
body.extend([Assignment(iVar('row', internal_use=True), SolList()),
Assignment(iVar('col', internal_use=True), SolList()),
Assignment(iVar('data', internal_use=True), SolList())])
body.extend(print_J_blocks(hvp, True))
body.append(Return(coo_2_csc(EqnAddr.total_size, var_addr.total_size)))
Jd = FunctionDefinition.from_FunctionPrototype(fp, body)
return pycode(Jd, fully_qualified_modules=False)


def made_numerical(eqs: SymEquations,
y: Vars,
sparse=False,
output_code=False,
make_hvp=False):
"""
factory method of numerical equations
"""
Expand All @@ -111,11 +150,24 @@ def made_numerical(eqs: SymEquations, y: Vars, sparse=False, output_code=False):
eqs.PARAM,
eqs.nstep,
sparse)
code = {'F': code_F, 'J': code_J}
if make_hvp:
eqs.hvp = Hvp(eqs.jac)
code_HVP = print_Hvp(eqs.__class__.__name__,
eqs.hvp,
eqs.a,
eqs.var_address,
eqs.PARAM,
eqs.nstep,
sparse)
code['HVP'] = code_HVP
custom_func = dict()
custom_func.update(numerical_interface)
custom_func.update(parse_trigger_func(eqs.PARAM))
F = Solverzlambdify(code_F, 'F_', modules=[custom_func, 'numpy'])
J = Solverzlambdify(code_J, 'J_', modules=[custom_func, 'numpy'])
if make_hvp:
HVP = Solverzlambdify(code_HVP, 'Hvp_', modules=[custom_func, 'numpy'])
p = parse_p(eqs.PARAM)
print('Complete!')
if isinstance(eqs, SymAE) and not isinstance(eqs, SymFDAE):
Expand All @@ -126,8 +178,10 @@ def made_numerical(eqs: SymEquations, y: Vars, sparse=False, output_code=False):
num_eqn = nDAE(eqs.M, F, J, p)
else:
raise ValueError(f'Unknown equation type {type(eqs)}')
if make_hvp:
num_eqn.HVP = HVP
if output_code:
return num_eqn, {'F': code_F, 'J': code_J}
return num_eqn, code
else:
return num_eqn

Expand Down Expand Up @@ -216,7 +270,8 @@ class tolist(Function):
@classmethod
def eval(cls, *args):
if len(args) != 1:
raise ValueError(f"Solverz' tolist function accepts only one input.")
raise ValueError(
f"Solverz' tolist function accepts only one input.")

def _numpycode(self, printer, **kwargs):
return r'((' + printer._print(self.args[0]) + r').tolist())'
Expand Down
93 changes: 83 additions & 10 deletions Solverz/code_printer/python/inline/tests/test_inline_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
from Solverz.sym_algebra.symbols import iVar, idx, Para
from Solverz.variable.variables import combine_Vars, as_Vars
from Solverz.equation.jac import JacBlock, Ones, Jac
from Solverz.sym_algebra.functions import Diag
from Solverz.equation.hvp import Hvp
from Solverz.sym_algebra.functions import Diag, sin, cos, exp
from Solverz.code_printer.python.inline.inline_printer import print_J_block, extend, SolList, print_J_blocks, print_J, \
print_F, made_numerical
print_F, print_Hvp, made_numerical, Solverzlambdify
from Solverz.utilities.address import Address
from Solverz.num_api.custom_function import numerical_interface

# %%
row = iVar('row', internal_use=True)
Expand All @@ -41,7 +44,8 @@ def test_jb_printer_scalar_var_scalar_deri():
assert symJb[1] == extend(col, SolList(1, 1, 1))
assert symJb[2] == extend(data, iVar('y') * Ones(3))
symJb = print_J_block(jb, False)
assert symJb[0] == AddAugmentedAssignment(J_[0:3, 1:2], iVar('y') * Ones(3))
assert symJb[0] == AddAugmentedAssignment(
J_[0:3, 1], iVar('y') * Ones(3))


def test_jb_printer_vector_var_vector_deri():
Expand All @@ -57,7 +61,8 @@ def test_jb_printer_vector_var_vector_deri():
assert symJb[1] == extend(col, SolList(*np.arange(0, 9).tolist()))
assert symJb[2] == extend(data, iVar('y'))
symJb = print_J_block(jb, False)
assert symJb[0] == AddAugmentedAssignment(iVar('J_', internal_use=True)[1:10, 0:9], Diag(iVar('y')))
assert symJb[0] == AddAugmentedAssignment(
iVar('J_', internal_use=True)[1:10, 0:9], Diag(iVar('y')))


def test_jbs_printer():
Expand Down Expand Up @@ -91,7 +96,7 @@ def test_jbs_printer():
assert symJbs[4] == extend(col, SolList(7, 8, 9, 10, 11, 12, 13, 14, 15))
assert symJbs[5] == extend(data, y ** 2)
symJbs = print_J_blocks(jac, False)
assert symJbs[0] == AddAugmentedAssignment(J_[0:3, 1:2], y * Ones(3))
assert symJbs[0] == AddAugmentedAssignment(J_[0:3, 1], y * Ones(3))
assert symJbs[1] == AddAugmentedAssignment(J_[3:12, 7:16], Diag(y ** 2))


Expand All @@ -100,7 +105,7 @@ def test_jbs_printer():
v = y_[1:2]
g = p_["g"]
J_ = zeros((2, 2))
J_[0:1,1:2] += ones(1)
J_[0:1,1] += ones(1)
return J_
""".strip()

Expand Down Expand Up @@ -173,8 +178,8 @@ def test_print_F_J():
J_ = zeros((6, 6))
J_[0:2,1:3] += diagflat(ones(2))
J_[2:4,0:2] += diagflat(-ones(2))
J_[4:5,2:3] += -ones(1)
J_[5:6,2:3] += ones(1)
J_[4:5,2] += -ones(1)
J_[5:6,2] += ones(1)
return J_
""".strip()

Expand Down Expand Up @@ -263,5 +268,73 @@ def test_made_numerical():
nF, code = made_numerical(F, y, sparse=True, output_code=True)
F0 = nF.F(y, nF.p)
J0 = nF.J(y, nF.p)
np.testing.assert_allclose(F0, np.array([2 * 1 + 1, 1 + np.sin(1)]), rtol=1e-8)
np.testing.assert_allclose(J0.toarray(), np.array([[2, 1], [2, 0.54030231]]), rtol=1e-8)
np.testing.assert_allclose(F0, np.array(
[2 * 1 + 1, 1 + np.sin(1)]), rtol=1e-8)
np.testing.assert_allclose(J0.toarray(), np.array(
[[2, 1], [2, 0.54030231]]), rtol=1e-8)


def test_hvp_printer():
jac = Jac()
x = iVar("x")
jac.add_block(
"a",
x[0],
JacBlock(
"a",
slice(0, 1),
x[0],
np.array([1]),
slice(0, 2),
exp(x[0]),
np.array([2.71828183]),
),
)
jac.add_block(
"a",
x[1],
JacBlock(
"a",
slice(0, 1),
x[1],
np.array([1]),
slice(0, 2),
cos(x[1]),
np.array([0.54030231]),
),
)
jac.add_block(
"b",
x[0],
JacBlock("b",
slice(1, 2),
x[0],
np.ones(1),
slice(0, 2),
1,
np.array([1])),
)
jac.add_block(
"b",
x[1],
JacBlock("b", slice(1, 2), x[1], np.ones(1),
slice(0, 2), 2 * x[1], np.array([2])),
)

h = Hvp(jac)
eqn_addr = Address()
eqn_addr.add('a', 1)
eqn_addr.add('b', 1)
var_addr = Address()
var_addr.add('x', 2)

code = print_Hvp('AE',
h,
eqn_addr,
var_addr,
dict())

HVP = Solverzlambdify(code, 'Hvp_', [numerical_interface, 'numpy'])
np.testing.assert_allclose(HVP(np.array([1, 2]), dict(), np.array([1, 1])).toarray(),
np.array([[2.71828183, -0.90929743],
[0., 2.]]))
Loading
Loading