
Merge branch '421-add-env-var-for-numba-cache' into 'release'
Resolve "add possibility to not use numba cache"

See merge request 3d/PandoraBox/pandora!366
lecontm committed Aug 21, 2024
2 parents e64d003 + 403c9ba commit 93ddf20
Showing 11 changed files with 54 additions and 25 deletions.
8 changes: 7 additions & 1 deletion docs/source/userguide/faq.rst
@@ -125,4 +125,10 @@ How can I disable numba parallelization?
****************************************
Some functions of Pandora are parallelized using the numba package. To disable this parallelization, set
an environment variable named **PANDORA_NUMBA_PARALLEL** to **False**.
How can I disable the numba cache?
**********************************
Some Pandora functions use the numba cache, which can speed up execution by avoiding recompilation between runs. To disable this cache, set
an environment variable named **PANDORA_NUMBA_CACHE** to **False**.
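For example, a minimal sketch of setting both flags from Python (the variable names and defaults come from the decorators in this commit; the flags must be set before Pandora is imported, because the @njit decorators read them at import time):

    import os

    # Values are parsed with ast.literal_eval, so they must be Python literals
    # such as "True" or "False" (a lowercase "false" would raise a ValueError).
    os.environ["PANDORA_NUMBA_PARALLEL"] = "False"  # disable numba parallelization
    os.environ["PANDORA_NUMBA_CACHE"] = "False"     # disable the numba on-disk cache

    import pandora  # imported after setting the flags, so the decorators see them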
18 changes: 13 additions & 5 deletions pandora/aggregation/cbca.py
@@ -24,6 +24,8 @@
"""

from typing import Dict, Union, Tuple, List
import os
from ast import literal_eval

import numpy as np
import xarray as xr
@@ -305,7 +307,7 @@ def computes_cross_supports(
return cross_left, cross_right


@njit("f4[:, :](f4[:, :])", cache=True)
@njit("f4[:, :](f4[:, :])", cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def cbca_step_1(cv: np.ndarray) -> np.ndarray:
"""
Giving the matching cost for one disparity, build a horizontal integral image storing the cumulative row sum,
@@ -332,7 +334,10 @@ def cbca_step_1(cv: np.ndarray) -> np.ndarray:
return step1


@njit("(f4[:, :], i2[:, :, :], i2[:, :, :], i8[:], i8[:])", cache=True)
@njit(
"(f4[:, :], i2[:, :, :], i2[:, :, :], i8[:], i8[:])",
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def cbca_step_2(
step1: np.ndarray,
cross_left: np.ndarray,
@@ -382,7 +387,7 @@ def cbca_step_2(
return step2, sum_step2


@njit("f4[:, :](f4[:, :])", cache=True)
@njit("f4[:, :](f4[:, :])", cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def cbca_step_3(step2: np.ndarray) -> np.ndarray:
"""
Giving the horizontal matching cost, build a vertical integral image for one disparity,
@@ -406,7 +411,10 @@ def cbca_step_3(step2: np.ndarray) -> np.ndarray:
return step3


@njit("(f4[:, :], f4[:, :], i2[:, :, :], i2[:, :, :], i8[:], i8[:])", cache=True)
@njit(
"(f4[:, :], f4[:, :], i2[:, :, :], i2[:, :, :], i8[:], i8[:])",
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def cbca_step_4(
step3: np.ndarray,
sum2: np.ndarray,
@@ -463,7 +471,7 @@ def cbca_step_4(
return step4, sum4


@njit("i2[:, :, :](f4[:, :], i2, f4)", cache=True)
@njit("i2[:, :, :](f4[:, :], i2, f4)", cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def cross_support(image: np.ndarray, len_arms: int, intensity: float) -> np.ndarray:
"""
Compute the cross support for an image: find the 4 arms.
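Each decorator above reads its flag with the same small pattern: the environment variable's string value is converted to a Python boolean by ast.literal_eval, with "True" as the fallback when the variable is unset. A minimal standalone sketch of that conversion (the name cache_enabled is illustrative):

    import os
    from ast import literal_eval

    # With PANDORA_NUMBA_CACHE unset, the default string "True" parses to the
    # boolean True, so numba caching stays enabled; exporting "False" disables it.
    cache_enabled = literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True"))
    print(type(cache_enabled), cache_enabled)  # -> <class 'bool'> True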
5 changes: 2 additions & 3 deletions pandora/cost_volume_confidence/ambiguity.py
@@ -186,7 +186,7 @@ def normalize_with_percentile(self, ambiguity: np.ndarray) -> np.ndarray:
@njit(
"f4[:, :](f4[:, :, :], f8[:], i8, i8[:, :, :],f4[:], bool_)",
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "False")),
cache=True,
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def compute_ambiguity(
cv: np.ndarray,
@@ -247,7 +247,6 @@ def compute_ambiguity(
if np.isnan(normalized_min_cost):
ambiguity[row, col] = nbr_etas * nb_disps
else:

idx_disp_min = np.searchsorted(disparity_range, grids[0][row, col])
idx_disp_max = np.searchsorted(disparity_range, grids[1][row, col]) + 1

@@ -276,7 +275,7 @@
@njit(
"Tuple((f4[:, :],f4[:, :, :]))(f4[:, :, :], f8[:], i8, i8[:, :, :], f4[:])",
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "False")),
cache=True,
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def compute_ambiguity_and_sampled_ambiguity(
cv: np.ndarray,
2 changes: 1 addition & 1 deletion pandora/cost_volume_confidence/interval_bounds.py
@@ -194,7 +194,7 @@ def confidence_prediction(
@njit(
"UniTuple(f4[:, :], 2)(f4[:, :, :], f4[:], f4, f4)",
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")),
cache=True,
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def compute_interval_bounds(
cv: np.ndarray,
4 changes: 2 additions & 2 deletions pandora/cost_volume_confidence/risk.py
@@ -170,7 +170,7 @@ def confidence_prediction(
@njit(
"Tuple((f4[:, :],f4[:, :]))(f4[:, :, :], f4[:, :, :], f8[:], i8, i8[:, :, :], f4[:])",
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")),
cache=True,
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def compute_risk(
cv: np.ndarray,
@@ -260,7 +260,7 @@ def compute_risk(
@njit(
"Tuple((f4[:, :],f4[:, :],f4[:, :, :],f4[:, :, :]))(f4[:, :, :], f4[:, :, :], f8[:], i8, i8[:, :, :], f4[:])",
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")),
cache=True,
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def compute_risk_and_sampled_risk(
cv: np.ndarray,
6 changes: 4 additions & 2 deletions pandora/img_tools.py
@@ -27,6 +27,8 @@

import warnings
from typing import List, Union, Tuple, cast, Dict
from ast import literal_eval
import os

import numpy as np
import rasterio
@@ -536,7 +538,7 @@ def fill_nodata_image(dataset: xr.Dataset) -> Tuple[np.ndarray, np.ndarray]:
return img, msk


@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def interpolate_nodata_sgm(img: np.ndarray, valid: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Interpolation of the input image to resolve invalid (nodata) pixels.
@@ -825,7 +827,7 @@ def compute_mean_raster(img: xr.Dataset, win_size: int, band: str = None) -> np.
return r_mean / float(win_size * win_size)


@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def find_valid_neighbors(dirs: np.ndarray, disp: np.ndarray, valid: np.ndarray, row: int, col: int):
"""
Find valid neighbors along directions
6 changes: 4 additions & 2 deletions pandora/interval_tools.py
@@ -32,7 +32,9 @@


@njit(
"b1[:,:](i8[:,:], i8[:,:], i8)", parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")), cache=True
"b1[:,:](i8[:,:], i8[:,:], i8)",
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")),
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def create_connected_graph(border_left: np.ndarray, border_right: np.ndarray, depth: int) -> np.ndarray:
"""
@@ -80,7 +82,7 @@ def create_connected_graph(border_left: np.ndarray, border_right: np.ndarray, de
@njit(
"Tuple([f4[:,:],f4[:,:],b1[:,:]])(f4[:,:],f4[:,:],i8[:,:],i8[:,:],b1[:,:],f8)",
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")),
cache=True,
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def graph_regularization(
interval_inf: np.ndarray,
4 changes: 3 additions & 1 deletion pandora/refinement/quadratic.py
@@ -24,6 +24,8 @@
"""

from typing import Dict, Tuple
from ast import literal_eval
import os

import numpy as np
from json_checker import Checker, And
@@ -72,7 +74,7 @@ def desc(self) -> None:
print("Quadratic refinement method")

@staticmethod
@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def refinement_method(cost: np.ndarray, disp: float, measure: str) -> Tuple[float, float, int]:
"""
Return the subpixel disparity and cost, by fitting a quadratic curve
12 changes: 9 additions & 3 deletions pandora/refinement/refinement.py
@@ -212,7 +212,7 @@ def desc(self) -> None:
print("Subpixel method description")

@staticmethod
@njit(parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")), cache=True)
@njit(
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")),
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def loop_refinement(
cv: np.ndarray,
disp: np.ndarray,
@@ -281,7 +284,7 @@ def loop_refinement(

@staticmethod
@abstractmethod
@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def refinement_method(cost: np.ndarray, disp: float, measure: str) -> Tuple[float, float, int]:
"""
Return the subpixel disparity and cost
@@ -298,7 +301,10 @@ def refinement_method(cost: np.ndarray, disp: float, measure: str) -> Tuple[floa
"""

@staticmethod
@njit(parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")), cache=True)
@njit(
parallel=literal_eval(os.environ.get("PANDORA_NUMBA_PARALLEL", "True")),
cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")),
)
def loop_approximate_refinement(
cv: np.ndarray,
disp: np.ndarray,
4 changes: 3 additions & 1 deletion pandora/refinement/vfit.py
@@ -24,6 +24,8 @@
"""

from typing import Dict, Tuple
import os
from ast import literal_eval

import numpy as np
from json_checker import Checker, And
@@ -72,7 +74,7 @@ def desc(self) -> None:
print("Vfit refinement method")

@staticmethod
@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def refinement_method(cost: np.ndarray, disp: float, measure: str) -> Tuple[float, float, int]:
"""
Return the subpixel disparity and cost, by matching a symmetric V shape (linear interpolation)
10 changes: 6 additions & 4 deletions pandora/validation/interpolated_disparity.py
@@ -27,6 +27,8 @@
import math
from abc import ABCMeta, abstractmethod
from typing import Tuple, Dict
import os
from ast import literal_eval

import numpy as np
import xarray as xr
@@ -242,7 +244,7 @@ def interpolated_disparity(
left["validity_mask"] = mask_border(left)

@staticmethod
@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def interpolate_occlusion_mc_cnn(disp: np.ndarray, valid: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Interpolation of the left disparity map to resolve occlusion conflicts.
@@ -299,7 +301,7 @@ def interpolate_occlusion_mc_cnn(disp: np.ndarray, valid: np.ndarray) -> Tuple[n
return out_disp, out_val

@staticmethod
@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def interpolate_mismatch_mc_cnn(disp: np.ndarray, valid: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Interpolation of the left disparity map to resolve mismatch conflicts.
@@ -464,7 +466,7 @@ def interpolated_disparity(
left.attrs["interpolated_disparity"] = "sgm"

@staticmethod
@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def interpolate_occlusion_sgm(disp: np.ndarray, valid: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Interpolation of the left disparity map to resolve occlusion conflicts.
@@ -510,7 +512,7 @@ def interpolate_occlusion_sgm(disp: np.ndarray, valid: np.ndarray) -> Tuple[np.n
return out_disp, out_val

@staticmethod
@njit(cache=True)
@njit(cache=literal_eval(os.environ.get("PANDORA_NUMBA_CACHE", "True")))
def interpolate_mismatch_sgm(disp: np.ndarray, valid: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Interpolation of the left disparity map to resolve mismatch conflicts. Interpolate mismatch by finding the
