Skip to content

Commit

Permalink
test: add UDM test
Browse files Browse the repository at this point in the history
  • Loading branch information
rzyu45 committed Nov 21, 2024
1 parent 9d4c176 commit 50f28aa
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 12 deletions.
Empty file.
75 changes: 75 additions & 0 deletions Solverz/num_api/test/test_udm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Test the user defined modules.
"""

import importlib
import os
import re
import shutil
from pathlib import Path

import pytest

from Solverz.num_api.user_function_parser import add_my_module, reset_my_module_paths

mymodule_code = """import numpy as np
from numba import njit
@njit(cache=True)
def c(x, y):
x = np.asarray(x).reshape((-1,))
y = np.asarray(y).reshape((-1,))
z = np.zeros_like(x)
for i in range(len(x)):
if x[i] <= y[i]:
z[i] = x[i]
else:
z[i] = y[i]
return z
"""


def test_udm():
# Create a .Solverz_test_temp directory in the user's home directory
user_home = str(Path.home())
solverz_dir = os.path.join(user_home, '.Solverz_test_temp')

# Create the .Solverz directory if it does not exist
if not os.path.exists(solverz_dir):
os.makedirs(solverz_dir)

file_path = os.path.join(solverz_dir, r'your_module.py')
file_path1 = os.path.join(solverz_dir, r'fake1.jl')

# Write the new paths to the file, but only if they are not already present
with open(file_path, 'a') as file:
file.write(mymodule_code)

with open(file_path1, 'a') as file:
file.write(mymodule_code)

with pytest.raises(ValueError,
match=re.escape(f"The path {solverz_dir} is not a file.")):
add_my_module([solverz_dir])

with pytest.raises(ValueError,
match=re.escape(f"The path {os.path.join(user_home, '.Solverz_test_temp1')} does not exist.")):
add_my_module([os.path.join(user_home, '.Solverz_test_temp1')])

with pytest.raises(ValueError,
match=re.escape(f"The file {file_path1} is not a Python file.")):
add_my_module([file_path1])

add_my_module([file_path])

import Solverz
importlib.reload(Solverz.num_api.module_parser)
from Solverz.num_api.module_parser import your_module
import numpy as np
np.testing.assert_allclose(your_module.c(np.array([1, 0]), np.array([2, -1])), np.array([1, -1]))

shutil.rmtree(solverz_dir)
reset_my_module_paths()
21 changes: 9 additions & 12 deletions Solverz/num_api/user_function_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def validate_module_paths(paths):
if not os.path.exists(path):
raise ValueError(f"The path {path} does not exist.")

# Check if the path is a file
if not os.path.isfile(path):
raise ValueError(f"The path {path} is not a file.")
# Check if the path is a file
if not os.path.isfile(path):
raise ValueError(f"The path {path} is not a file.")

# Check if the file is a Python file
if not path.endswith('.py'):
raise ValueError(f"The file {path} is not a Python file.")
# Check if the file is a Python file
if not path.endswith('.py'):
raise ValueError(f"The file {path} is not a Python file.")

# If all checks pass, add the path to the valid paths list
valid_paths.append(path)
Expand All @@ -38,12 +38,9 @@ def add_my_module(paths, filename='user_modules.txt'):
:param paths: List of user-provided module paths
:param filename: Name of the file to save, default is 'user_modules.txt'
"""
try:
# Validate paths
validated_paths = validate_module_paths(paths)
except ValueError as e:
print(e)
return

# Validate paths
validated_paths = validate_module_paths(paths)

# Get the path to the .Solverz directory in the user's home directory
user_home = str(Path.home())
Expand Down

0 comments on commit 50f28aa

Please sign in to comment.