Skip to content

Commit

Permalink
fix: add callable type in get_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn committed Dec 4, 2024
1 parent ee538d2 commit 2311ee6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
27 changes: 21 additions & 6 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1871,7 +1871,7 @@ def set_mask(self, mask: NDArrayBool | Mask) -> None:
else:
self.data[mask_arr > 0] = np.ma.masked

def _statistics(self, band: int = 1) -> dict[str, float]:
def _statistics(self, band: int = 1) -> dict[str, np.floating[Any]]:
"""
Calculate common statistics for a specified band in the raster.
Expand All @@ -1885,6 +1885,10 @@ def _statistics(self, band: int = 1) -> dict[str, float]:
else:
data = self.data[band - 1]

# If data is a MaskedArray, use the compressed version (without masked values)
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()

# Compute the statistics
stats_dict = {
"Mean": np.nanmean(data),
Expand All @@ -1901,8 +1905,12 @@ def _statistics(self, band: int = 1) -> dict[str, float]:
return stats_dict

def get_stats(
self, stats_name: str | list[str | Callable[[NDArrayNum], float]] | None = None, band: int = 1
) -> float | dict[str, float]:
self,
stats_name: (
str | Callable[[NDArrayNum], np.floating[Any]] | list[str | Callable[[NDArrayNum], np.floating[Any]]] | None
) = None,
band: int = 1,
) -> np.floating[Any] | dict[str, np.floating[Any]]:
"""
Retrieve specified statistics or all available statistics for the raster data. Allows passing custom callables
to calculate custom stats.
Expand Down Expand Up @@ -1944,10 +1952,15 @@ def get_stats(
result[name] = self._get_single_stat(stats_dict, stats_aliases, name)
return result
else:
return self._get_single_stat(stats_dict, stats_aliases, stats_name)
if callable(stats_name):
return stats_name(self.data[band] if self.count > 1 else self.data)
else:
return self._get_single_stat(stats_dict, stats_aliases, stats_name)

@staticmethod
def _get_single_stat(stats_dict: dict[str, float], stats_aliases: dict[str, str], stat_name: str) -> float:
def _get_single_stat(
stats_dict: dict[str, np.floating[Any]], stats_aliases: dict[str, str], stat_name: str
) -> np.floating[Any]:
"""
Retrieve a single statistic based on a flexible name or alias.
Expand All @@ -1964,7 +1977,7 @@ def _get_single_stat(stats_dict: dict[str, float], stats_aliases: dict[str, str]
return stats_dict[actual_name]
else:
logging.warning("Statistic name '%s' is not recognized", stat_name)
return np.nan
return np.floating(np.nan)

def _nmad(self, nfact: float = 1.4826, band: int = 0) -> np.floating[Any]:
"""
Expand All @@ -1977,6 +1990,8 @@ def _nmad(self, nfact: float = 1.4826, band: int = 0) -> np.floating[Any]:
data = self.data
else:
data = self.data[band]
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()
return nfact * np.nanmedian(np.abs(data - np.nanmedian(data)))

@overload
Expand Down
7 changes: 6 additions & 1 deletion tests/test_raster/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1971,10 +1971,15 @@ def test_stats(self, example: str) -> None:
stat = raster.get_stats(stats_name="Average")
assert isinstance(stat, np.floating)

# Selected stats and callable
def percentile_95(data: NDArrayNum) -> np.floating[Any]:
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()
return np.nanpercentile(data, 95)

stat = raster.get_stats(stats_name=percentile_95)
assert isinstance(stat, np.floating)

# Selected stats and callable
stats_name = ["mean", "maximum", "std", "percentile_95"]
stats = raster.get_stats(stats_name=["mean", "maximum", "std", percentile_95])
for name in stats_name:
Expand Down

0 comments on commit 2311ee6

Please sign in to comment.