diff --git a/pandora/img_tools.py b/pandora/img_tools.py index a701bec..dc44878 100644 --- a/pandora/img_tools.py +++ b/pandora/img_tools.py @@ -656,7 +656,7 @@ def convert_pyramid_to_dataset( return pyramid -def shift_right_img(img_right: xr.Dataset, subpix: int, band: str = None) -> List[xr.Dataset]: +def shift_right_img(img_right: xr.Dataset, subpix: int, band: str = None, order: int = 1) -> List[xr.Dataset]: """ Return an array that contains the shifted right images @@ -666,6 +666,8 @@ def shift_right_img(img_right: xr.Dataset, subpix: int, band: str = None) -> Lis :type subpix: int :param band: User's value for selected band :type band: str + :param order: order parameter on zoom method + :type order: int :return: an array that contains the shifted right images :rtype: array of xarray.Dataset """ @@ -683,7 +685,7 @@ def shift_right_img(img_right: xr.Dataset, subpix: int, band: str = None) -> Lis for ind in np.arange(1, subpix): shift = 1 / subpix # For each index, shift the right image for subpixel precision 1/subpix*index - data = zoom(selected_band, (1, (nx_ * subpix - (subpix - 1)) / float(nx_)), order=1)[:, ind::subpix] + data = zoom(selected_band, (1, (nx_ * subpix - (subpix - 1)) / float(nx_)), order=order)[:, ind::subpix] col = np.arange( img_right.coords["col"].values[0] + shift * ind, img_right.coords["col"].values[-1], step=1 ) # type: np.ndarray diff --git a/pandora/matching_cost/census.py b/pandora/matching_cost/census.py index 37b0c38..9bd69c4 100644 --- a/pandora/matching_cost/census.py +++ b/pandora/matching_cost/census.py @@ -103,7 +103,7 @@ def compute_cost_volume( self.check_band_input_mc(img_left, img_right) # Contains the shifted right images - img_right_shift = shift_right_img(img_right, self._subpix, self._band) + img_right_shift = shift_right_img(img_right, self._subpix, self._band, self._spline_order) # Maximal cost of the cost volume with census measure cmax = int(self._window_size**2) diff --git a/pandora/matching_cost/matching_cost.py b/pandora/matching_cost/matching_cost.py index c028435..17c5939 100644 --- a/pandora/matching_cost/matching_cost.py +++ b/pandora/matching_cost/matching_cost.py @@ -53,18 +53,21 @@ class AbstractMatchingCost: _band = None _step_col = None _method = None + _spline_order = None # Default configuration, do not change these values _WINDOW_SIZE = 5 _SUBPIX = 1 _BAND = None _STEP_COL = 1 + _SPLINE_ORDER = 1 # Matching cost schema confi schema = { "subpix": And(int, lambda sp: sp in [1, 2, 4]), "band": Or(str, lambda input: input is None), "step": And(int, lambda y: y >= 1), + "spline_order": And(int, lambda y: 1 <= y <= 5), } margins = HalfWindowMargins() @@ -144,6 +147,10 @@ def instantiate_class(self, **cfg: Union[str, int]) -> None: self._band = self.cfg["band"] self._step_col = int(self.cfg["step"]) self._method = str(self.cfg["matching_cost_method"]) + self._spline_order = int(self.cfg["spline_order"]) + + # Remove spline_order key because it is a pandora2d setting and a need + del self.cfg["spline_order"] def check_conf(self, **cfg: Dict[str, Union[str, int]]) -> Dict: """ @@ -168,6 +175,8 @@ def check_conf(self, **cfg: Dict[str, Union[str, int]]) -> Dict: raise ValueError("Step parameter cannot be different from 1") if "step" not in cfg: cfg["step"] = self._STEP_COL # type: ignore + if "spline_order" not in cfg: + cfg["spline_order"] = self._SPLINE_ORDER # type: ignore return cfg diff --git a/pandora/matching_cost/sad_ssd.py b/pandora/matching_cost/sad_ssd.py index 07e4602..11330bd 100644 --- a/pandora/matching_cost/sad_ssd.py +++ b/pandora/matching_cost/sad_ssd.py @@ -105,7 +105,7 @@ def compute_cost_volume( self.check_band_input_mc(img_left, img_right) # Contains the shifted right images - img_right_shift = shift_right_img(img_right, self._subpix, self._band) + img_right_shift = shift_right_img(img_right, self._subpix, self._band, self._spline_order) if self._band is not None: band_index_left = list(img_left.band_im.data).index(self._band) band_index_right = list(img_right.band_im.data).index(self._band) diff --git a/pandora/matching_cost/zncc.py b/pandora/matching_cost/zncc.py index da7db7e..2dcd56d 100644 --- a/pandora/matching_cost/zncc.py +++ b/pandora/matching_cost/zncc.py @@ -142,7 +142,7 @@ def compute_cost_volume( self.check_band_input_mc(img_left, img_right) # Contains the shifted right images - img_right_shift = shift_right_img(img_right, self._subpix, self._band) + img_right_shift = shift_right_img(img_right, self._subpix, self._band, self._spline_order) # Computes the standard deviation raster for the whole images # The standard deviation raster is truncated for points that are not calculable diff --git a/tests/test_matching_cost/test_matching_cost.py b/tests/test_matching_cost/test_matching_cost.py index 9fff4d7..40467ca 100644 --- a/tests/test_matching_cost/test_matching_cost.py +++ b/tests/test_matching_cost/test_matching_cost.py @@ -32,6 +32,7 @@ import numpy as np import xarray as xr +import json_checker import pytest from pandora import matching_cost @@ -592,6 +593,54 @@ def test_find_nearest_multiple_of_step(self, step_col, value, expected): assert result == expected +class TestSplineOrder: + """ + Description : Test spline_order in matching_cost configuration + """ + + def test_nominal_case(self): + matching_cost.AbstractMatchingCost(**{"matching_cost_method": "zncc", "window_size": 5}) + + def test_default_spline_order(self): + result = matching_cost.AbstractMatchingCost(**{"matching_cost_method": "zncc", "window_size": 5}) + + assert result._spline_order == 1 # pylint:disable=protected-access + + def test_fails_with_negative_spline_order(self): + """ + Description : Test if the spline_order is negative + """ + with pytest.raises(json_checker.core.exceptions.DictCheckerError) as err: + matching_cost.AbstractMatchingCost(**{"matching_cost_method": "zncc", "window_size": 5, "spline_order": -2}) + assert "spline_order" in err.value.args[0] + + def test_fails_with_null_spline_order(self): + """ + Description : Test if the spline_order is null + """ + with pytest.raises(json_checker.core.exceptions.DictCheckerError) as err: + matching_cost.AbstractMatchingCost(**{"matching_cost_method": "zncc", "window_size": 5, "spline_order": 0}) + assert "spline_order" in err.value.args[0] + + def test_fails_with_more_than_five(self): + """ + Description : Test if the spline_order is > 5 + """ + with pytest.raises(json_checker.core.exceptions.DictCheckerError) as err: + matching_cost.AbstractMatchingCost(**{"matching_cost_method": "zncc", "window_size": 5, "spline_order": 6}) + assert "spline_order" in err.value.args[0] + + def test_fails_with_string_element(self): + """ + Description : Test fails if the spline_order is a string element + """ + with pytest.raises(json_checker.core.exceptions.DictCheckerError) as err: + matching_cost.AbstractMatchingCost( + **{"matching_cost_method": "zncc", "window_size": 5, "spline_order": "1"} + ) + assert "spline_order" in err.value.args[0] + + def make_image(data, disparity): """Make an image with a disparity range.""" return xr.Dataset(