Skip to content

Commit

Permalink
Merge pull request #49 from neuro-ml/C++
Browse files Browse the repository at this point in the history
C++
  • Loading branch information
vovaf709 authored Feb 2, 2024
2 parents 48ac786 + 5212e88 commit bab650d
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: actions/setup-python@v4
- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.10.0
- name: Install gcc for mac
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: actions/setup-python@v4
- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.10.0
- name: Install gcc for mac
Expand Down
2 changes: 1 addition & 1 deletion _build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class LazyImport(dict):
# https://github.com/cython/cython/blob/6ad6ca0e9e7d030354b7fe7d7b56c3f6e6a4bc23/Cython/Compiler/ModuleNode.py#L773
def __init__(self, module_name):
self.module_name = module_name
return super().__init__(self, description=self.__doc__)
super().__init__(self, description=self.__doc__)

# Must be hashable due to
# https://github.com/cython/cython/blob/6ad6ca0e9e7d030354b7fe7d7b56c3f6e6a4bc23/Cython/Compiler/Main.py#L307
Expand Down
2 changes: 1 addition & 1 deletion imops/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.8.5'
__version__ = '0.8.6'
31 changes: 19 additions & 12 deletions imops/interp2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,15 @@ def __init__(
if triangles is not None:
if not isinstance(triangles, np.ndarray):
raise TypeError(f'Wrong type of `triangles` argument, expected np.ndarray. Got {type(triangles)}')
if triangles.ndim != 2 or triangles.shape[1] != 3 or triangles.shape[0] * 3 != triangles.size:
if triangles.ndim != 2 or triangles.shape[1] != 3:
raise ValueError('Passed `triangles` argument has an incorrect shape')

if not isinstance(points, np.ndarray):
raise TypeError(f'Wrong type of `points` argument, expected np.ndarray. Got {type(points)}')

if points.ndim != 2 or points.shape[1] != 2:
raise ValueError('Passed `points` argument has an incorrect shape')

if values is not None:
if not isinstance(values, np.ndarray):
raise TypeError(f'Wrong type of `values` argument, expected np.ndarray. Got {type(values)}')
Expand All @@ -81,7 +84,7 @@ def __call__(self, points: np.ndarray, values: np.ndarray = None, fill_value: fl
points: np.ndarray
2-D array of data point coordinates to interpolate at
values: np.ndarray
1-D array of fp32/fp64 values to use at initial points. If passed, existing values will be rewritten
1-D array of fp32/fp64 values to use at initial points
fill_value: float
value to fill past edges
Expand All @@ -90,22 +93,26 @@ def __call__(self, points: np.ndarray, values: np.ndarray = None, fill_value: fl
new_values: np.ndarray
interpolated values at given points
"""
self.values = values or self.values
if values is None:
values = self.values

if self.values is None:
if values is None:
raise ValueError('`values` argument was never passed neither in __init__ or __call__ methods')

if not isinstance(self.values, np.ndarray):
raise TypeError(f'Wrong type of `values` argument, expected np.ndarray. Got {type(self.values)}')
if not isinstance(values, np.ndarray):
raise TypeError(f'Wrong type of `values` argument, expected np.ndarray. Got {type(values)}')

if values.ndim > 1:
raise ValueError(f'Wrong shape of `values` argument, expected ndim=1. Got shape {values.shape}')

if not isinstance(points, np.ndarray):
raise TypeError(f'Wrong type of `points` argument, expected np.ndarray. Got {type(points)}')

if self.values.ndim > 1:
raise ValueError(f'Wrong shape of `values` argument, expected ndim=1. Got shape {self.values.shape}')
if points.ndim != 2 or points.shape[1] != 2:
raise ValueError('Passed `points` argument has an incorrect shape')

_, neighbors = self.kdtree.query(
points, 1, **{'workers': self.num_threads} if python_version()[:3] != '3.6' else {}
)

if not isinstance(points, np.ndarray):
raise TypeError(f'Wrong type of `points` argument, expected np.ndarray. Got {type(points)}')

return super().__call__(points, self.values, neighbors, fill_value)
return super().__call__(points, values, neighbors, fill_value)
33 changes: 32 additions & 1 deletion tests/test_interp2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def test_test_data(num_threads):
scipy_values = griddata(x_points, x_values, int_points, method='linear', fill_value=0.0)

delta_ds = np.abs(delaunay_values - scipy_values)
# delta_di = np.abs(delaunay_values - imops_values)
delta_si = np.abs(scipy_values - imops_values)

assert delta_ds.max() <= 1e-10 and delta_si.max() <= 5, f'Failed with big case, arr_{i}'
Expand Down Expand Up @@ -74,6 +73,15 @@ def test_no_values(example):
Linear2DInterpolator(x_points)(int_points)


def test_no_changes_in_values(example):
x_points, int_points = example
first_values = np.ones((x_points.shape[0],), dtype=float)
second_values = 2.0 * np.ones((x_points.shape[0],), dtype=float)
interpolator = Linear2DInterpolator(x_points, first_values)
interpolator(int_points, second_values)
assert np.all(interpolator.values == first_values), 'Failed with changes in self.values after __call__'


def test_bad_values_dtype(example):
x_points, int_points = example

Expand All @@ -99,3 +107,26 @@ def test_bad_values_lenght(example):
Linear2DInterpolator(x_points, values=values)(int_points)
with pytest.raises(ValueError):
Linear2DInterpolator(x_points)(int_points, values=values)


def test_bad_triangles_dtype(example):
x_points, _ = example

with pytest.raises(TypeError):
Linear2DInterpolator(x_points, triangles=[1, 2, 3])
with pytest.raises(ValueError):
Linear2DInterpolator(x_points, triangles=np.ones((3, 2)))


def test_bad_points_dtype(example):
x_points, _ = example

values = np.ones((x_points.shape[0],))
with pytest.raises(TypeError):
Linear2DInterpolator(points=[1, 2, 3])
with pytest.raises(TypeError):
Linear2DInterpolator(points=x_points, values=values)(points=[1, 2, 3])
with pytest.raises(ValueError):
Linear2DInterpolator(points=np.ones((3, 3)))
with pytest.raises(ValueError):
Linear2DInterpolator(points=x_points, values=values)(points=np.ones((3, 3)))

0 comments on commit bab650d

Please sign in to comment.