Skip to content

Commit

Permalink
Fix pickling error in Numba exception handling & improve cache checks (
Browse files Browse the repository at this point in the history
…#1083)

* attempt to fix recursion error (1)

* better checks for caching

* add better check if a numba function is cached or needs to be re-compiled

* add case for non-numba functions
  • Loading branch information
philipc2 authored Nov 21, 2024
1 parent b6b1eb8 commit 2f193a4
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 35 deletions.
25 changes: 0 additions & 25 deletions uxarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
import uxarray.constants
import sys
#
# # TODO: numba recursion limit ?


from .core.api import open_grid, open_dataset, open_mfdataset

from .core.dataset import UxDataset
Expand All @@ -25,23 +19,6 @@
__version__ = "999"


# Flag for enabling FMA instructions across the package
def enable_fma():
"""Enables Fused-Multiply-Add (FMA) instructions using the ``pyfma``
package."""
uxarray.constants.ENABLE_FMA = True


def disable_fma():
"""Disable Fused-Multiply-Add (FMA) instructions using the ``pyfma``
package."""
uxarray.constants.ENABLE_FMA = False


disable_fma()
sys.setrecursionlimit(10000)


__all__ = (
"open_grid",
"open_dataset",
Expand All @@ -55,6 +32,4 @@ def disable_fma():
"diverging",
"sequential_blue",
"sequential_green",
"enable_fma",
"disable_fma",
)
4 changes: 1 addition & 3 deletions uxarray/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,9 +827,7 @@ def pole_point_inside_polygon(pole, face_edges_xyz, face_edges_lonlat):
return ((north_intersections + south_intersections) % 2) != 0

else:
raise ValueError(
f"Invalid pole point query. Current location: {location}, query pole point: {pole}"
)
raise ValueError("Invalid pole point query.")


@njit(cache=True)
Expand Down
15 changes: 8 additions & 7 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@
_check_normalization,
)

from uxarray.utils.numba import is_numba_function_cached


from uxarray.conventions import ugrid

Expand Down Expand Up @@ -1367,13 +1369,12 @@ def bounds(self):
Dimensions ``(n_face", two, two)``
"""
if "bounds" not in self._ds:
if hasattr(compute_temp_latlon_array, "inspect_llvm"):
if len(compute_temp_latlon_array.inspect_llvm()) == 0:
warn(
"Necessary functions for computing face bounds are not translated yet with Numba. This initial"
"translation may take some time.",
RuntimeWarning,
)
if not is_numba_function_cached(compute_temp_latlon_array):
warn(
"Necessary functions for computing the bounds of each face are not yet compiled with Numba. "
"This initial execution will be significantly longer.",
RuntimeWarning,
)

_populate_bounds(self)

Expand Down
44 changes: 44 additions & 0 deletions uxarray/utils/numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import pickle
import numba


def is_numba_function_cached(func):
"""
Determines if a numba function is cached and up-to-date.
Returns:
- True if cache exists and is valid or the input is not a Numba function.
- False if cache doesn't exist or needs recompilation
"""

if not hasattr(func, "_cache"):
return True

cache = func._cache
cache_file = cache._cache_file

# Check if cache file exists
full_path = os.path.join(cache._cache_path, cache_file._index_name)
if not os.path.isfile(full_path):
return False

try:
# Load and check version
with open(full_path, "rb") as f:
version = pickle.load(f)
if version != numba.__version__:
return False

# Load and check source stamp
data = f.read()
stamp, _ = pickle.loads(data)

# Get current source stamp
current_stamp = cache._impl.locator.get_source_stamp()

# Compare stamps
return stamp == current_stamp

except (OSError, pickle.PickleError):
return False

0 comments on commit 2f193a4

Please sign in to comment.