Skip to content

Commit

Permalink
fixed bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexlib committed Dec 12, 2023
1 parent 92d3c1e commit e2d7cb4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
11 changes: 6 additions & 5 deletions openptv_python/image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
import copy

import numpy as np
from numba import njit
from scipy import ndimage
from scipy.ndimage import uniform_filter

from .parameters import ControlPar

filter_t = np.zeros((3, 3), dtype=float)


@njit
def filter_3(img, kernel=None) -> np.ndarray:
"""Apply a 3x3 filter to an image."""
if kernel is None: # default is a low pass
kernel = np.ones((3, 3)) / 9
filtered_img = ndimage.convolve(img, kernel)
return filtered_img


@njit
def lowpass_3(img: np.ndarray) -> np.ndarray:
"""Lowpass filter of 3x3."""
# Define the 3x3 lowpass filter kernel
Expand All @@ -28,7 +29,7 @@ def lowpass_3(img: np.ndarray) -> np.ndarray:

return img_lp


@njit
def fast_box_blur(filt_span: int, src: np.ndarray, cpar: ControlPar) -> np.ndarray:
"""Fast box blur."""
n = 2 * filt_span + 1
Expand Down Expand Up @@ -56,7 +57,7 @@ def fast_box_blur(filt_span: int, src: np.ndarray, cpar: ControlPar) -> np.ndarr
# new_img = map_coordinates(img, [coords_y, coords_x], mode="constant", cval=0)
# return new_img


@njit
def subtract_img(img1: np.ndarray, img2: np.ndarray, img_new: np.ndarray) -> None:
"""
Subtract img2 from img1 and store the result in img_new.
Expand All @@ -68,7 +69,7 @@ def subtract_img(img1: np.ndarray, img2: np.ndarray, img_new: np.ndarray) -> Non
"""
img_new[:] = ndimage.maximum(img1 - img2, 0)


@njit
def subtract_mask(img: np.ndarray, img_mask: np.ndarray):
"""Subtract mask from image."""
img_new = np.where(img_mask == 0, 0, img)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def test_arguments(self):
"""Test that the function raises errors when it should."""
output_img = prepare_image(
self.input_img,
filter_hp=self.filter_hp,
dim_lp=True,
# filter_hp=self.filter_hp,
# dim_lp=True,
)
assert output_img.shape == (5, 5)

Expand Down Expand Up @@ -62,8 +62,8 @@ def test_preprocess_image(self):
res = prepare_image(
self.input_img,
dim_lp=1,
filter_hp=self.filter_hp,
filter_file='',
# filter_hp=self.filter_hp,
# filter_file='',
)

# print(res)
Expand Down

0 comments on commit e2d7cb4

Please sign in to comment.