Skip to content

Commit

Permalink
Merge pull request #100 from rzyu45/dev
Browse files Browse the repository at this point in the history
Support of user defined functions
  • Loading branch information
rzyu45 authored Nov 22, 2024
2 parents 8b27e4f + 9a43302 commit 411752d
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 127 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/cookbook-practice.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ jobs:
run: |
pip install --upgrade pip setuptools wheel
pip install git+https://github.com/${{github.repository}}.git@${{ github.sha }}
pip install pytest-xdist
- name: Run Tests
run: |
pytest
pytest -n auto
6 changes: 2 additions & 4 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ jobs:
python-version: '3.11'
cache: 'pip' # caching pip dependencies
- run: |
pip install -r requirements.txt
cd docs/
pip install -r requirements.txt
cd ..
pip install -e .
pip install git+https://github.com/smallbunnies/myfunc.git@main
- run: | # run both independent pytest and doctest
pytest
2 changes: 1 addition & 1 deletion Solverz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from Solverz.equation.param import Param, IdxParam, TimeSeriesParam
from Solverz.sym_algebra.symbols import idx, Para, iVar, iAliasVar
from Solverz.sym_algebra.functions import (Sign, Abs, transpose, exp, Diag, Mat_Mul, sin, cos, Min, AntiWindUp,
Saturation, heaviside, ln)
Saturation, heaviside, ln, MulVarFunc, UniVarFunc)
from Solverz.variable.variables import Vars, TimeVars, as_Vars
from Solverz.solvers import *
from Solverz.code_printer import made_numerical, module_printer
Expand Down
23 changes: 7 additions & 16 deletions Solverz/code_printer/python/module/module_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def render_modules(eqs: SymEquations,
code_dict["inner_Hvp"] = inner_Hvp['code_inner_Hvp']
code_dict["sub_inner_Hvp"] = inner_Hvp['code_sub_inner_Hvp']



def print_trigger_func_code():
code_tfuc = dict()
trigger_func = parse_trigger_func(eqs.PARAM)
Expand Down Expand Up @@ -176,7 +174,11 @@ def create_python_module(module_name,


def print_init_code(eqn_type: str, module_name, eqn_param):
code = 'from .num_func import F_, J_\n'
code = '"""\n'
from ...._version import __version__
code += f'Python module generated by Solverz {__version__}\n'
code += '"""\n'
code += 'from .num_func import F_, J_\n'
code += 'from .dependency import setting, p_, y\n'
code += 'import time\n'
match eqn_type:
Expand Down Expand Up @@ -207,7 +209,7 @@ def print_init_code(eqn_type: str, module_name, eqn_param):
code += f'start = time.perf_counter()\n'
code += code_compile.format(alpha=code_compile_args_F_J, beta=code_compile_args_Hvp)
code += f'end = time.perf_counter()\n'
code += "print(f'Compiling time elapsed: {end-start}s')\n"
code += "print(f'Compiling time elapsed: {end - start}s')\n"
return code


Expand Down Expand Up @@ -264,25 +266,14 @@ def print_module_code(code_dict: Dict[str, str], numba=False):

return code

#
code_from_SolMuseum="""
try:
import SolMuseum.num_api as SolMF
except ImportError as e:
pass

"""
def print_dependency_code(modules):
code = "import os\n"
code += "current_module_dir = os.path.dirname(os.path.abspath(__file__))\n"
code += 'from Solverz import load\n'
code += 'auxiliary = load(f"{current_module_dir}\\\\param_and_setting.pkl")\n'
code += 'from numpy import *\n'
code += 'import numpy as np\n'
code += 'import Solverz.num_api.custom_function as SolCF\n' # import Solverz built-in func
code += code_from_SolMuseum
code += 'import scipy.sparse as sps\n'
code += 'from numba import njit\n'
code += 'from Solverz.num_api.module_parser import *\n'
code += 'setting = auxiliary["eqn_param"]\n'
code += 'row = setting["row"]\n'
code += 'col = setting["col"]\n'
Expand Down
22 changes: 9 additions & 13 deletions Solverz/code_printer/python/module/test/test_module_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,7 @@
from Solverz import load
auxiliary = load(f"{current_module_dir}\\param_and_setting.pkl")
from numpy import *
import numpy as np
import Solverz.num_api.custom_function as SolCF
try:
import SolMuseum.num_api as SolMF
except ImportError as e:
pass
import scipy.sparse as sps
from numba import njit
from Solverz.num_api.module_parser import *
setting = auxiliary["eqn_param"]
row = setting["row"]
col = setting["col"]
Expand All @@ -38,7 +29,12 @@
y = auxiliary["vars"]
"""

expected_init = r"""from .num_func import F_, J_
from ....._version import __version__

expected_init = r'''"""
Python module generated by Solverz {vs}
"""
from .num_func import F_, J_
from .dependency import setting, p_, y
import time
from Solverz.num_api.num_eqn import nAE
Expand All @@ -60,8 +56,8 @@
v = ones_like(y)
mdl.HVP(y, p_, v)
end = time.perf_counter()
print(f'Compiling time elapsed: {end-start}s')
"""
print(f'Compiling time elapsed: {{end - start}}s')
'''.format(vs=__version__)


def test_AE_module_printer():
Expand Down
21 changes: 18 additions & 3 deletions Solverz/num_api/module_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,28 @@
import warnings

import Solverz.num_api.custom_function as SolCF
import numpy as np
import scipy.sparse as sps
from numba import njit

modules = [{'SolCF': SolCF, 'np': numpy, 'sps': scipy.sparse}, 'numpy']
# We preserve the 'numpy' here in case one uses functions from sympy instead of from Solverz
module_dict = {'SolCF': SolCF, 'np': np, 'sps': sps, 'njit': njit}

# parse modules from museum
try:
import SolMuseum.num_api as SolMF
modules[0]['SolMF'] = SolMF
module_dict['SolMF'] = SolMF
except ModuleNotFoundError as e:
warnings.warn(f'Failed to import num api from SolMuseum: {e}')

# parse user defined functions

try:
import myfunc
print('User module detected.')
module_dict['myfunc'] = myfunc
except ModuleNotFoundError as e:
pass

modules = [module_dict, 'numpy']
# We preserve the 'numpy' here in case one uses functions from sympy instead of from Solverz
__all__ = list(module_dict.keys())
47 changes: 47 additions & 0 deletions Solverz/num_api/test/test_udm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Test the user defined modules.
"""


def test_udm():

from Solverz import Model, Var, Eqn, made_numerical, MulVarFunc
import numpy as np

class Min(MulVarFunc):
arglength = 2

def fdiff(self, argindex=1):
if argindex == 1:
return dMindx(*self.args)
elif argindex == 2:
return dMindy(*self.args)

def _numpycode(self, printer, **kwargs):
return (f'myfunc.Min' + r'(' +
', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')')

class dMindx(MulVarFunc):
arglength = 2

def _numpycode(self, printer, **kwargs):
return (f'myfunc.dMindx' + r'(' +
', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')')

class dMindy(MulVarFunc):
arglength = 2

def _numpycode(self, printer, **kwargs):
return (f'myfunc.dMindy' + r'(' +
', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')')

m = Model()
m.x = Var('x', [1, 2])
m.y = Var('y', [3, 4])
m.f = Eqn('f', Min(m.x, m.y))
sae, y0 = m.create_instance()
ae = made_numerical(sae, y0, sparse=True)
np.testing.assert_allclose(ae.F(y0, ae.p), np.array([1.0, 2.0]))
np.testing.assert_allclose(ae.J(y0, ae.p).toarray(), np.array([[1., 0., 0., 0.],
[0., 1., 0., 0.]]))

Loading

0 comments on commit 411752d

Please sign in to comment.