Skip to content

Commit

Permalink
refactor: implement UDM using editable package
Browse files Browse the repository at this point in the history
  • Loading branch information
rzyu45 committed Nov 22, 2024
1 parent f7702df commit f12e3fe
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 184 deletions.
1 change: 0 additions & 1 deletion Solverz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from Solverz.utilities.profile import count_time
from Solverz.variable.ssymbol import Var, AliasVar
from Solverz.model.basic import Model
from Solverz.num_api.user_function_parser import add_my_module, reset_my_module_paths

from importlib.metadata import version, PackageNotFoundError

Expand Down
17 changes: 5 additions & 12 deletions Solverz/num_api/module_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,13 @@
warnings.warn(f'Failed to import num api from SolMuseum: {e}')

# parse user defined functions
from .user_function_parser import load_my_module_paths

user_module_paths = load_my_module_paths()
if user_module_paths:
try:
import myfunc
print('User module detected.')
import os, sys
for path in user_module_paths:
module_name = os.path.splitext(os.path.basename(path))[0]
module_dir = os.path.dirname(path)

sys.path.insert(0, module_dir)
exec('import ' + module_name)
module_dict[module_name] = globals()[module_name]

module_dict['myfunc'] = myfunc
except ModuleNotFoundError as e:
pass

modules = [module_dict, 'numpy']
# We preserve the 'numpy' here in case one uses functions from sympy instead of from Solverz
Expand Down
90 changes: 31 additions & 59 deletions Solverz/num_api/test/test_udm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,74 +2,46 @@
Test the user defined modules.
"""

import importlib
import os
import re
import shutil
from pathlib import Path

import pytest

from Solverz.num_api.user_function_parser import add_my_module, reset_my_module_paths

mymodule_code = """import numpy as np
from numba import njit
@njit(cache=True)
def c(x, y):
x = np.asarray(x).reshape((-1,))
y = np.asarray(y).reshape((-1,))
z = np.zeros_like(x)
for i in range(len(x)):
if x[i] <= y[i]:
z[i] = x[i]
else:
z[i] = y[i]
return z
"""


def test_udm():
# Create a .Solverz_test_temp directory in the user's home directory
user_home = str(Path.home())
solverz_dir = os.path.join(user_home, '.Solverz_test_temp')

# Create the .Solverz directory if it does not exist
if not os.path.exists(solverz_dir):
os.makedirs(solverz_dir)
from Solverz import Model, Var, Eqn, made_numerical, MulVarFunc
import numpy as np

file_path = os.path.join(solverz_dir, r'your_module.py')
file_path1 = os.path.join(solverz_dir, r'fake1.jl')
class Min(MulVarFunc):
arglength = 2

# Write the new paths to the file, but only if they are not already present
with open(file_path, 'a') as file:
file.write(mymodule_code)
def fdiff(self, argindex=1):
if argindex == 1:
return dMindx(*self.args)
elif argindex == 2:
return dMindy(*self.args)

with open(file_path1, 'a') as file:
file.write(mymodule_code)
def _numpycode(self, printer, **kwargs):
return (f'myfunc.Min' + r'(' +
', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')')

with pytest.raises(ValueError,
match=re.escape(f"The path {solverz_dir} is not a file.")):
add_my_module([solverz_dir])
class dMindx(MulVarFunc):
arglength = 2

with pytest.raises(ValueError,
match=re.escape(f"The path {os.path.join(user_home, '.Solverz_test_temp1')} does not exist.")):
add_my_module([os.path.join(user_home, '.Solverz_test_temp1')])
def _numpycode(self, printer, **kwargs):
return (f'myfunc.dMindx' + r'(' +
', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')')

with pytest.raises(ValueError,
match=re.escape(f"The file {file_path1} is not a Python file.")):
add_my_module([file_path1])
class dMindy(MulVarFunc):
arglength = 2

add_my_module([file_path])
def _numpycode(self, printer, **kwargs):
return (f'myfunc.dMindy' + r'(' +
', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')')

import Solverz
importlib.reload(Solverz.num_api.module_parser)
from Solverz.num_api.module_parser import your_module
import numpy as np
np.testing.assert_allclose(your_module.c(np.array([1, 0]), np.array([2, -1])), np.array([1, -1]))
m = Model()
m.x = Var('x', [1, 2])
m.y = Var('y', [3, 4])
m.f = Eqn('f', Min(m.x, m.y))
sae, y0 = m.create_instance()
ae = made_numerical(sae, y0, sparse=True)
np.testing.assert_allclose(ae.F(y0, ae.p), np.array([1.0, 2.0]))
np.testing.assert_allclose(ae.J(y0, ae.p).toarray(), np.array([[1., 0., 0., 0.],
[0., 1., 0., 0.]]))

shutil.rmtree(solverz_dir)
reset_my_module_paths()
111 changes: 0 additions & 111 deletions Solverz/num_api/user_function_parser.py

This file was deleted.

2 changes: 1 addition & 1 deletion docs/src/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,6 @@ ae = made_numerical(sae, y0, sparse=True)
We will have the output

```shell
>>> ae.F(y0)
>>> ae.F(y0, ae.p)
array([1.0, 2.0])
```

0 comments on commit f12e3fe

Please sign in to comment.