Skip to content

Commit

Permalink
interactive plotting with lonboard (#67)
Browse files Browse the repository at this point in the history
* draft of the explore functionality based on `lonboard`

* expose the draft

* docstring for `explore`

* use the proper import from `matplotlib` for the standard normalizer

* use `pandas` to compute the half range

* specify `vmin` and `vmax`

* explicitly specify `skipna`

* refactor to construct the polygons as geoarrow

This has speed implications: before, we've been using the shapely →
geopandas → geoarrow route, which adds significant overhead.

Co-authored-by: Kyle Barron <[email protected]>

* allow choosing the transparency

* move the healpix-specific code to `xdggs.healpix`

In the future, we might want to move the conversion code to a separate module.

* change the signature of `cell_boundaries`

* also have h3 support the geoarrow cell boundaries

* fix several typos

* adjust the tests to fit the new dateline fix algorithm

* update the list of extra dependencies for `explore`

* mention that this only works for 1D arrays for now

* add optional dependencies for `explore`

* explicitly use `pyproj` to construct the crs string

* raise an error if the geometry type is wrong

* wrap the output of `crs.to_json()` in a json dict

* verify that both backends produce the same polygons

* move the data normalization to a separate function

* move the table creation to a separate function

* check that the normalization works properly

* coverage configuration

* pass on `center` to the normalizing function

* skip the plotting module if `arro3.core` is not available

* use `json` to construct the arrow extension metadata

* Revert "skip the plotting module if `arro3.core` is not available"

This reverts commit 5a71559.

* workaround for missing macos-arm packages of arro3-core

* install arro3-core using `pip`

* always include the cell ids in the table data

closes keewis#1

---------

Co-authored-by: Kyle Barron <[email protected]>
  • Loading branch information
keewis and kylebarron authored Oct 14, 2024
1 parent b3b5a40 commit 774821d
Show file tree
Hide file tree
Showing 12 changed files with 408 additions and 30 deletions.
3 changes: 3 additions & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ dependencies:
- hypothesis
- ruff
- typing-extensions
- geoarrow-pyarrow
- lonboard
- pip
- pip:
- arro3-core
- h3ronpy
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ dependencies = [
"typing-extensions",
]

[project.optional-dependencies]
explore = [
"lonboard>=0.9.3",
"pyproj>=3.3",
"matplotlib",
"arro3-core>=0.4.0"
]

[project.urls]
# Home = "https://xdggs.readthedocs.io"
Repository = "https://github.com/xarray-contrib/xdggs"
Expand Down
36 changes: 36 additions & 0 deletions xdggs/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from xdggs.grid import DGGSInfo
from xdggs.index import DGGSIndex
from xdggs.plotting import explore


@xr.register_dataset_accessor("dggs")
Expand Down Expand Up @@ -115,3 +116,38 @@ def cell_boundaries(self):
return xr.DataArray(
boundaries, coords={self._name: self.cell_ids}, dims=self.cell_ids.dims
)

def explore(self, *, cmap="viridis", center=None, alpha=None):
    """Interactively explore the data using `lonboard`.

    Requires `lonboard`, `matplotlib`, and `arro3.core` to be installed
    (the ``explore`` extra additionally lists `pyproj`).

    Parameters
    ----------
    cmap : str
        The name of the color map to use.
    center : int or float, optional
        If set, will use this as the center value of a diverging color map.
    alpha : float, optional
        If set, controls the transparency of the polygons.

    Returns
    -------
    map : lonboard.Map
        The rendered map.

    Raises
    ------
    ValueError
        If the accessor wraps a ``Dataset`` instead of a ``DataArray``.

    Notes
    -----
    Plotting currently is restricted to 1D `DataArray` objects.
    """
    obj = self._obj
    if isinstance(obj, xr.Dataset):
        raise ValueError("does not work with Dataset objects, yet")

    # The first dimension of the cell id coordinate is used as the cell
    # dimension; the module-level ``explore`` does the actual plotting.
    cell_dim = obj[self._name].dims[0]
    plot_options = {"cmap": cmap, "center": center, "alpha": alpha}
    return explore(obj, cell_dim=cell_dim, **plot_options)
2 changes: 1 addition & 1 deletion xdggs/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ def cell_ids2geographic(self, cell_ids):
def geographic2cell_ids(self, lon, lat):
    """Placeholder that always raises.

    The h3 and healpix grid info classes define a method with the same
    name and signature that performs the actual conversion.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError

def cell_boundaries(self, cell_ids, backend="shapely"):
    """Placeholder that always raises.

    Parameters
    ----------
    cell_ids
        The cell ids to compute boundaries for (unused here).
    backend : str, default: "shapely"
        Name of the polygon backend (unused here). The h3 and healpix
        implementations accept ``"shapely"`` and ``"geoarrow"``.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError
50 changes: 47 additions & 3 deletions xdggs/h3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, ClassVar
Expand All @@ -8,7 +9,6 @@
from typing_extensions import Self

import numpy as np
import shapely
import xarray as xr
from h3ronpy.arrow.vector import (
cells_to_coordinates,
Expand All @@ -22,6 +22,42 @@
from xdggs.utils import _extract_cell_id_variable, register_dggs


def polygons_shapely(wkb):
    """Decode WKB-encoded geometries using `shapely`.

    Parameters
    ----------
    wkb
        The WKB data, passed unchanged to ``shapely.from_wkb``.

    Returns
    -------
    The result of ``shapely.from_wkb(wkb)``.
    """
    # imported lazily so that importing this module does not require shapely
    from shapely import from_wkb

    return from_wkb(wkb)


def polygons_geoarrow(wkb):
    """Convert WKB-encoded polygons into a geoarrow polygon array.

    The WKB data is decoded with `shapely`, flattened into coordinates and
    offsets, and re-assembled as a nested arrow list array that is tagged
    with the ``geoarrow.polygon`` extension metadata (EPSG:4326).

    Parameters
    ----------
    wkb
        The WKB data, passed unchanged to ``shapely.from_wkb``.

    Returns
    -------
    The arrow list array carrying the ``geoarrow.polygon`` extension name
    and the CRS as JSON-encoded extension metadata.

    Raises
    ------
    ValueError
        If the decoded geometries are not plain polygons.
    """
    import pyproj
    import shapely
    from arro3.core import list_array

    polygons = shapely.from_wkb(wkb)
    crs = pyproj.CRS.from_epsg(4326)

    geometry_type, coords, offsets = shapely.to_ragged_array(polygons)

    # Check the geometry type *before* unpacking the offsets: the number of
    # offset arrays depends on the geometry type, so unpacking into exactly
    # two names first would fail with a generic unpacking error for most
    # non-polygon inputs and hide the message below.
    if geometry_type != shapely.GeometryType.POLYGON:
        raise ValueError(f"unsupported geometry type found: {geometry_type}")

    ring_offsets, geom_offsets = offsets

    # polygons -> rings -> coordinates, expressed as nested list arrays
    polygon_array = list_array(
        geom_offsets.astype("int32"), list_array(ring_offsets.astype("int32"), coords)
    )
    polygon_array_with_geo_meta = polygon_array.cast(
        polygon_array.field.with_metadata(
            {
                "ARROW:extension:name": "geoarrow.polygon",
                "ARROW:extension:metadata": json.dumps({"crs": crs.to_json_dict()}),
            }
        )
    )

    return polygon_array_with_geo_meta


@dataclass(frozen=True)
class H3Info(DGGSInfo):
resolution: int
Expand Down Expand Up @@ -50,10 +86,18 @@ def cell_ids2geographic(
def geographic2cell_ids(self, lon, lat):
    """Look up the cell ids for geographic coordinates given in degrees.

    Delegates to ``coordinates_to_cells`` at ``self.resolution``; note that
    the helper is called with latitude first, then longitude.
    """
    resolution = self.resolution
    return coordinates_to_cells(lat, lon, resolution, radians=False)

def cell_boundaries(self, cell_ids, backend="shapely"):
    """Compute the boundary polygons of the given cells.

    Parameters
    ----------
    cell_ids
        The cell ids, passed to ``cells_to_wkb_polygons``.
    backend : {"shapely", "geoarrow"}, default: "shapely"
        Selects the function used to convert the WKB polygons.

    Returns
    -------
    The polygons as returned by the selected backend function.

    Raises
    ------
    ValueError
        If ``backend`` is not one of the supported names.
    """
    backends = {
        "shapely": polygons_shapely,
        "geoarrow": polygons_geoarrow,
    }
    # validate first so that an invalid backend fails before any work is done
    backend_func = backends.get(backend)
    if backend_func is None:
        raise ValueError(f"invalid backend: {backend!r}")

    # TODO: convert cell ids directly to geoarrow once h3ronpy supports it
    wkb = cells_to_wkb_polygons(cell_ids, radians=False, link_cells=False)

    return backend_func(wkb)


@register_dggs("h3")
Expand Down
93 changes: 83 additions & 10 deletions xdggs/healpix.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import operator
from collections.abc import Mapping
from dataclasses import dataclass, field
Expand Down Expand Up @@ -26,6 +27,72 @@
from exceptiongroup import ExceptionGroup


def polygons_shapely(vertices):
    """Build `shapely` polygons from an array of cell vertices.

    Parameters
    ----------
    vertices
        The vertex coordinates, passed unchanged to ``shapely.polygons``.

    Returns
    -------
    The result of ``shapely.polygons(vertices)``.
    """
    # imported lazily so that importing this module does not require shapely
    from shapely import polygons

    return polygons(vertices)


def polygons_geoarrow(vertices):
    """Build a geoarrow polygon array from an array of cell vertices.

    Parameters
    ----------
    vertices : numpy.ndarray
        Vertex coordinates; indexed as ``vertices[:, :1, :]``, so a 3D array
        is expected (presumably ``(cells, vertices, 2)`` -- the flattened
        coordinates are reshaped to pairs).

    Returns
    -------
    The arrow list array tagged with the ``geoarrow.polygon`` extension name
    and JSON metadata containing the EPSG:4326 CRS and ``"edges":
    "spherical"``.
    """
    import pyproj
    from arro3.core import list_array

    # close every ring by appending its first vertex at the end
    closed_rings = np.concatenate([vertices, vertices[:, :1, :]], axis=1)
    crs = pyproj.CRS.from_epsg(4326)

    # Flat coordinate buffer plus offsets: every polygon consists of exactly
    # one ring, and every ring has the same number of coordinates.
    coords = np.reshape(closed_rings, (-1, 2))
    ring_length = closed_rings.shape[1]
    geom_offsets = np.arange(vertices.shape[0] + 1, dtype="int32")
    ring_offsets = geom_offsets * ring_length

    rings = list_array(ring_offsets, coords)
    polygon_array = list_array(geom_offsets, rings)

    # The extension metadata (`geoarrow.polygon`) is what tells Lonboard that
    # this is a geospatial column.
    extension_metadata = {
        "ARROW:extension:name": "geoarrow.polygon",
        "ARROW:extension:metadata": json.dumps(
            {"crs": crs.to_json_dict(), "edges": "spherical"}
        ),
    }
    tagged_field = polygon_array.field.with_metadata(extension_metadata)
    return polygon_array.cast(tagged_field)


def center_around_prime_meridian(lon, lat):
    """Normalize cell vertex longitudes for plotting.

    Three steps are applied:

    - wrap the longitudes to be centered around the prime meridian
    - replace the longitude of vertices at the poles with the median
      longitude of the other vertices of the same cell
    - shift the negative longitudes of cells that cross the dateline by
      360 so the cell stays contiguous around 180

    Parameters
    ----------
    lon, lat : numpy.ndarray
        Longitudes and latitudes in degrees, one row per cell. The pole
        handling reshapes to rows of three, i.e. it assumes four vertices
        per cell with at most one of them at a pole.

    Returns
    -------
    numpy.ndarray
        The adjusted longitudes. The input arrays are not modified.
    """
    # step 1: wrap the longitudes
    wrapped = (lon + 180) % 360 - 180

    # step 2: a longitude at a pole is arbitrary, so borrow the median of
    # the remaining vertices of that cell
    at_pole = np.isin(lat, np.array([-90, 90]))
    touches_pole = at_pole.any(axis=-1)
    other_vertices = wrapped[touches_pole[:, None] & ~at_pole]
    wrapped[at_pole] = np.median(np.reshape(other_vertices, (-1, 3)), axis=-1)

    # step 3: a cell with longitudes on both far sides crosses the dateline
    far_west = (wrapped < -100).any(axis=-1)
    far_east = (wrapped > 100).any(axis=-1)
    crosses_dateline = far_west & far_east
    needs_shift = crosses_dateline[:, None] & (wrapped < 0)

    return np.where(needs_shift, wrapped + 360, wrapped)


@dataclass(frozen=True)
class HealpixInfo(DGGSInfo):
resolution: int
Expand Down Expand Up @@ -135,23 +202,29 @@ def cell_ids2geographic(self, cell_ids):
def geographic2cell_ids(self, lon, lat):
    """Look up the cell ids for geographic coordinates.

    Delegates to ``healpy.ang2pix`` with ``lonlat=True``, using the
    ``nside`` and ``nest`` attributes of this grid.
    """
    nside = self.nside
    nest = self.nest
    return healpy.ang2pix(nside, lon, lat, lonlat=True, nest=nest)

def cell_boundaries(self, cell_ids: Any, backend="shapely") -> np.ndarray:
    """Compute the boundary polygons of the given cells.

    Parameters
    ----------
    cell_ids
        The cell ids, passed to ``healpy.boundaries``.
    backend : {"shapely", "geoarrow"}, default: "shapely"
        Selects the function used to build the polygons from the vertices.

    Returns
    -------
    The polygons as returned by the selected backend function (the
    ``geoarrow`` backend returns an arrow array rather than a numpy array).

    Raises
    ------
    ValueError
        If ``backend`` is not one of the supported names.
    """
    backends = {
        "shapely": polygons_shapely,
        "geoarrow": polygons_geoarrow,
    }
    # validate first so that an invalid backend fails before any work is done
    backend_func = backends.get(backend)
    if backend_func is None:
        raise ValueError(f"invalid backend: {backend!r}")

    boundary_vectors = healpy.boundaries(
        self.nside, cell_ids, step=1, nest=self.nest
    )

    lon, lat = healpy.vec2ang(np.moveaxis(boundary_vectors, 1, -1), lonlat=True)
    # one row of four vertices per cell
    lon_reshaped = np.reshape(lon, (-1, 4))
    lat_reshaped = np.reshape(lat, (-1, 4))

    # fix the dateline / prime meridian / pole longitudes
    lon_ = center_around_prime_meridian(lon_reshaped, lat_reshaped)

    vertices = np.stack((lon_, lat_reshaped), axis=-1)

    return backend_func(vertices)


@register_dggs("healpix")
Expand Down
72 changes: 72 additions & 0 deletions xdggs/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np


def create_arrow_table(polygons, arr, coords=None):
    """Assemble the arrow table holding the geometries and the data.

    Parameters
    ----------
    polygons
        The cell polygons, converted with ``Array.from_arrow``.
    arr
        The data array. Its ``cell_ids`` coordinate and its data become
        columns; the data column is named ``arr.name`` or ``"data"`` if the
        name is falsy.
    coords : list of str, optional
        Names of additional coordinates to include if present on ``arr``.
        Defaults to ``["latitude", "longitude"]``.

    Returns
    -------
    arro3.core.Table
        The table with the columns ``geometry``, ``cell_ids``, the data
        column and any of the requested coordinates that exist.
    """
    from arro3.core import Array, ChunkedArray, Schema, Table

    def as_chunked(values):
        # wrap array data as a single-chunk arrow column
        return ChunkedArray([Array.from_numpy(values)])

    coord_names = ["latitude", "longitude"] if coords is None else coords

    geometry = Array.from_arrow(polygons)
    data_name = arr.name or "data"

    columns = {
        "geometry": geometry,
        "cell_ids": as_chunked(arr.coords["cell_ids"]),
        data_name: as_chunked(arr.data),
    }
    columns.update(
        (coord_name, as_chunked(arr.coords[coord_name].data))
        for coord_name in coord_names
        if coord_name in arr.coords
    )

    # every field keeps its metadata but takes the name of its column
    schema = Schema(
        [column.field.with_name(column_name) for column_name, column in columns.items()]
    )

    return Table.from_arrays(list(columns.values()), schema=schema)


def normalize(var, center=None):
    """Normalize the data of a variable for color mapping.

    Parameters
    ----------
    var
        The variable to normalize. It needs to support ``min`` / ``max``
        with ``skipna`` and expose the raw values as ``.data``.
    center : int or float, optional
        If given, a ``CenteredNorm`` around this value is used, with the
        half range set to the largest absolute deviation from ``center``.
        Otherwise a ``Normalize`` spanning the data minimum and maximum is
        used.

    Returns
    -------
    The result of applying the matplotlib normalizer to ``var.data``.
    """
    from matplotlib.colors import CenteredNorm, Normalize

    if center is not None:
        # symmetric range around the requested center
        max_deviation = np.abs(var - center).max(skipna=True)
        normalizer = CenteredNorm(vcenter=center, halfrange=max_deviation)
    else:
        lower = var.min(skipna=True)
        upper = var.max(skipna=True)
        normalizer = Normalize(vmin=lower, vmax=upper)

    return normalizer(var.data)


def explore(
    arr,
    cell_dim="cells",
    cmap="viridis",
    center=None,
    alpha=None,
):
    """Render a 1D data array as an interactive `lonboard` map.

    Parameters
    ----------
    arr
        The data array to plot. It must have exactly one dimension, and
        that dimension must be ``cell_dim``.
    cell_dim : str, default: "cells"
        The name of the cell dimension.
    cmap : str, default: "viridis"
        The name of the matplotlib color map.
    center : int or float, optional
        Passed on to ``normalize`` as the center of the normalization.
    alpha : float, optional
        Passed on to ``apply_continuous_cmap``.

    Returns
    -------
    lonboard.Map
        A map with a single filled polygon layer.

    Raises
    ------
    ValueError
        If ``arr`` is not 1D along ``cell_dim``.
    """
    import lonboard
    from lonboard import SolidPolygonLayer
    from lonboard.colormap import apply_continuous_cmap
    from matplotlib import colormaps

    is_1d_along_cells = len(arr.dims) == 1 and cell_dim in arr.dims
    if not is_1d_along_cells:
        raise ValueError(
            f"exploration only works with a single dimension ('{cell_dim}')"
        )

    # cell geometries as a geoarrow array
    cell_ids = arr.dggs.coord.data
    grid_info = arr.dggs.grid_info
    polygons = grid_info.cell_boundaries(cell_ids, backend="geoarrow")

    # one fill color per cell, derived from the normalized data
    normalized_data = normalize(arr.variable, center=center)
    colors = apply_continuous_cmap(normalized_data, colormaps[cmap], alpha=alpha)

    layer = SolidPolygonLayer(
        table=create_arrow_table(polygons, arr),
        filled=True,
        get_fill_color=colors,
    )

    return lonboard.Map(layer)
7 changes: 7 additions & 0 deletions xdggs/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import geoarrow.pyarrow as ga
import shapely

from xdggs.tests.matchers import ( # noqa: F401
Match,
MatchResult,
assert_exceptions_equal,
)


def geoarrow_to_shapely(arr):
    """Convert a geoarrow array to shapely geometries via WKB (test helper)."""
    wkb = ga.as_wkb(arr)
    return shapely.from_wkb(wkb)
11 changes: 8 additions & 3 deletions xdggs/tests/test_h3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from xarray.core.indexes import PandasIndex

from xdggs import h3
from xdggs.tests import geoarrow_to_shapely

# from the h3 gallery, at resolution 3
cell_ids = [
Expand Down Expand Up @@ -202,14 +203,18 @@ def test_geographic2cell_ids(self, cell_centers, cell_ids):
),
),
)
def test_cell_boundaries(self, resolution, cell_ids, expected_coords):
@pytest.mark.parametrize("backend", ["shapely", "geoarrow"])
def test_cell_boundaries(self, resolution, cell_ids, backend, expected_coords):
expected = shapely.polygons(expected_coords)

grid = h3.H3Info(resolution=resolution)

actual = grid.cell_boundaries(cell_ids)
backends = {"shapely": lambda arr: arr, "geoarrow": geoarrow_to_shapely}
converter = backends[backend]

shapely.testing.assert_geometries_equal(actual, expected)
actual = grid.cell_boundaries(cell_ids, backend=backend)

shapely.testing.assert_geometries_equal(converter(actual), expected)


@pytest.mark.parametrize("resolution", resolutions)
Expand Down
Loading

0 comments on commit 774821d

Please sign in to comment.