diff --git a/raster2dggs/h3.py b/raster2dggs/h3.py index d8b8b1b..2fddf11 100644 --- a/raster2dggs/h3.py +++ b/raster2dggs/h3.py @@ -5,6 +5,7 @@ import multiprocessing from numbers import Number import numpy as np +import math from pathlib import Path import tempfile import threading @@ -15,16 +16,22 @@ import click_log import dask import dask.dataframe as dd +import geopandas as gpd +import h3 as h3py import h3pandas # Necessary import despite lack of explicit use import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +import pyproj import rasterio as rio from rasterio import crs from rasterio.enums import Resampling from rasterio.vrt import WarpedVRT from rasterio.warp import calculate_default_transform import rioxarray +from shapely import segmentize +from shapely.geometry import box, mapping +from shapely.ops import transform from tqdm import tqdm from tqdm.dask import TqdmCallback import xarray as xr @@ -45,7 +52,7 @@ "decimals": 1, "warp_mem_limit": 12000, "resampling": "average", - "tempdir": tempfile.tempdir, + "tempdir": tempfile.gettempdir(), } DEFAULT_PARENT_OFFSET = 6 @@ -70,12 +77,21 @@ def _get_parent_res(parent_res: Union[None, int], resolution: int) -> int: ) +def h3_coords_generator(h3_cells): + for cell in h3_cells: + yield h3py.h3_to_geo(cell) + +def sample_values_generator(sdf, coords_generator): + for lat, lon in coords_generator: + yield sdf.sel(y=lat, x=lon, method='nearest').values[0] + def _h3func( sdf: xr.DataArray, resolution: int, parent_res: int, nodata: Number = np.nan, band_labels: Tuple[str] = None, + coverage: bool = False ) -> pa.Table: """ Index a raster window to H3. @@ -83,16 +99,37 @@ def _h3func( If windows are very small, or in strips rather than blocks, processing may be slower than necessary and the recommendation is to write different windows in the source raster. """ - sdf: pd.DataFrame = sdf.to_dataframe().drop(columns=["spatial_ref"]).reset_index() - subset: pd.DataFrame = sdf.dropna() - subset = subset[subset.value != nodata] - subset = pd.pivot_table( - subset, values=DEFAULT_NAME, index=["x", "y"], columns=["band"] - ).reset_index() - # Primary H3 index - h3index = subset.h3.geo_to_h3(resolution, lat_col="y", lng_col="x").drop( - columns=["x", "y"] - ) + + if not coverage: + sdf: pd.DataFrame = sdf.to_dataframe().drop(columns=["spatial_ref"]).reset_index() + # Cells are considered points; there may be gaps between DGGS cells depending on raster resolution + subset: pd.DataFrame = sdf.dropna() + subset = subset[subset.value != nodata] + subset = pd.pivot_table( + subset, values=DEFAULT_NAME, index=["x", "y"], columns=["band"] + ).reset_index() + # Primary H3 index + h3index = subset.h3.geo_to_h3(resolution, lat_col="y", lng_col="x").drop( + columns=["x", "y"] + ) + else: + # Cells are considered areas; the raster is a coverage and there will not be gaps in the DGGS grid + h3_cells = h3py.polyfill_polygon(mapping(box(sdf['y'].min(), sdf['x'].min(), sdf['y'].max(), sdf['x'].max()))["coordinates"][0], res=resolution) + + coords_gen = h3_coords_generator(h3_cells) + values_gen = sample_values_generator(sdf, coords_gen) + sampled_values = list(values_gen) + + h3index = pd.DataFrame({ + f'h3_{resolution:02}': list(h3_cells), + DEFAULT_NAME: sampled_values, + 'band': sdf.band.values[0] + }) + h3index = pd.pivot_table( + h3index, values=DEFAULT_NAME, index=f'h3_{resolution:02}', columns=["band"] + ) + sdf: pd.DataFrame = sdf.to_dataframe().drop(columns=["spatial_ref"]).reset_index() + # Secondary (parent) H3 index, used later for partitioning h3index = h3index.h3.h3_to_parent(parent_res).reset_index() # Renaming columns to actual band labels @@ -191,6 +228,7 @@ def process(window): parent_res, vrt.nodata, band_labels=band_names, + coverage=kwargs['coverage'], ) with write_lock: @@ -369,6 +407,12 @@ def _address_boundary_issues( type=click.Choice(Resampling._member_names_), help="Input raster may be warped to EPSG:4326 if it is not already in this CRS. Or, if the upscale parameter is greater than 1, there is a need to resample. This setting specifies this resampling algorithm.", ) +@click.option( + "--raster_model", + type=click.Choice(['area', 'point']), + default='point', + help="If area, the input raster as a coverage, where cells are considered areas. The output grid will cover the extent of the input with no gaps. If point, raster cells are treated as point samples rather than coverages. The output grid may have gaps where DGGS cells do not coincide with any input cells." +) @click.option( "--tempdir", default=DEFAULTS["tempdir"], @@ -389,6 +433,7 @@ def h3( overwrite: bool, warp_mem_limit: int, resampling: str, + raster_model: bool, tempdir: Union[str, Path], ): """ @@ -397,7 +442,7 @@ def h3( RASTER_INPUT is the path to input raster data; prepend with protocol like s3:// or hdfs:// for remote data. OUTPUT_DIRECTORY should be a directory, not a file, as it will be the write location for an Apache Parquet data store, with partitions equivalent to parent cells of target cells at a fixed offset. However, this can also be remote (use the appropriate prefix, e.g. s3://). """ - tempfile.tempdir = tempdir if tempdir is not None else tempfile.tempdir + tempfile.tempdir = tempdir if tempdir is not None else tempfile.gettempdir() if parent_res is not None and not int(parent_res) < int(resolution): raise ParentResolutionException( "Parent resolution ({pr}) must be less than target resolution ({r})".format( @@ -437,6 +482,7 @@ def h3( "warp_mem_limit": warp_mem_limit, "resampling": resampling, "overwrite": overwrite, + "coverage": True if raster_model == 'area' else False, } _initial_index( raster_input,