Skip to content

Commit

Permalink
Add cyano band ratios
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 12, 2024
1 parent 7f3c1d6 commit e8572ce
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions hypercoast/pace.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,3 +698,75 @@ def pace_chla_to_image(data, output=None, **kwargs):
image_to_geotiff(image, output, dtype="float32")

return image


def cyano_band_ratios(
dataset: xr.Dataset | str,
plot: bool = True,
extent: List[float] = None,
figsize: tuple[int, int] = (12, 6),
**kwargs,
) -> xr.DataArray:
"""
Calculates cyanobacteria band ratios from PACE data.
Args:
dataset (xr.Dataset or str): The dataset containing the PACE data or the file path to the dataset.
plot (bool, optional): Whether to plot the data. Defaults to True.
extent (list, optional): The extent of the plot. Defaults to None.
figsize (tuple, optional): Figure size. Defaults to (12, 6).
**kwargs: Additional keyword arguments to pass to the `plt.subplots` function.
Returns:
xr.DataArray: The cyanobacteria band ratios.
"""
import cartopy.crs as ccrs
import cartopy.feature as cfeature

if isinstance(dataset, str):
dataset = read_pace(dataset)
elif not isinstance(dataset, xr.Dataset):
raise ValueError("dataset must be an xarray Dataset")

da = dataset["Rrs"]
data = (
(da.sel(wavelength=650) > da.sel(wavelength=620))
& (da.sel(wavelength=701) > da.sel(wavelength=681))
& (da.sel(wavelength=701) > da.sel(wavelength=450))
)

if plot:
# Create a plot
_, ax = plt.subplots(
figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, **kwargs
)

if extent is not None:
ax.set_extent(extent, crs=ccrs.PlateCarree())

# Plot the data
data.plot(
ax=ax,
transform=ccrs.PlateCarree(),
cmap="coolwarm",
cbar_kwargs={"label": "Cyano"},
)

# Add coastlines
ax.coastlines()

# Add state boundaries
states_provinces = cfeature.NaturalEarthFeature(
category="cultural",
name="admin_1_states_provinces_lines",
scale="50m",
facecolor="none",
)

ax.add_feature(states_provinces, edgecolor="gray")

# Optionally, add gridlines, labels, etc.
ax.gridlines(draw_labels=True)
plt.show()

return data

0 comments on commit e8572ce

Please sign in to comment.