diff --git a/README.md b/README.md index 1779a19..7de3f8a 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,30 @@ shape: (1, 1) └───────────────────────┘ ``` +### H3 Spatial Index +```python +df = pl.DataFrame( + {"coord": [{"longitude": -120.6623, "latitude": 35.3003}]}, + schema={ + "coord": pl.Struct( + [pl.Field("longitude", pl.Float64), pl.Field("latitude", pl.Float64)] + ), + }, +) + +df.with_columns( + plh.col('coord').h3.from_coords().alias('h3') +) +shape: (1, 2) +┌─────────────────────┬─────────────────┐ +│ coord ┆ h3 │ +│ --- ┆ --- │ +│ struct[2] ┆ str │ +╞═════════════════════╪═════════════════╡ +│ {-120.6623,35.3003} ┆ 8c29adc423821ff │ +└─────────────────────┴─────────────────┘ +``` + ## Create hash from multiple columns ```python diff --git a/polars_hash/Cargo.toml b/polars_hash/Cargo.toml index df63ab7..e7c88d9 100644 --- a/polars_hash/Cargo.toml +++ b/polars_hash/Cargo.toml @@ -20,6 +20,7 @@ sha2 = { version = "0.10.8" } sha3 = { version = "0.10.8" } blake3 = { version = "1.5.0" } md5 = {version = "0.7.0"} +h3o = { version = "0.6.2" } [target.'cfg(target_os = "linux")'.dependencies] diff --git a/polars_hash/polars_hash/__init__.py b/polars_hash/polars_hash/__init__.py index b8c3907..0c04be5 100644 --- a/polars_hash/polars_hash/__init__.py +++ b/polars_hash/polars_hash/__init__.py @@ -178,6 +178,21 @@ def neighbors(self) -> pl.Expr: ) +@pl.api.register_expr_namespace("h3") +class H3NameSpace: + def __init__(self, expr: pl.Expr): + self._expr = expr + + def from_coords(self, len: int = 12) -> pl.Expr: + """Takes Struct with latitude, longitude as input and returns utf8 H3 spatial index.""" + return register_plugin_function( + plugin_path=Path(__file__).parent, + args=[self._expr, len], + function_name="h3_encode", + is_elementwise=True, + ) + + class HExpr(pl.Expr): @property def chash(self) -> CryptographicHashingNameSpace: @@ -191,6 +206,10 @@ def nchash(self) -> NonCryptographicHashingNameSpace: def geohash(self) -> GeoHashingNameSpace: return GeoHashingNameSpace(self) + @property + def h3(self) -> H3NameSpace: + return H3NameSpace(self) + class HashColumn(Protocol): def __call__( diff --git a/polars_hash/src/expressions.rs b/polars_hash/src/expressions.rs index 78b4207..47ab732 100644 --- a/polars_hash/src/expressions.rs +++ b/polars_hash/src/expressions.rs @@ -1,4 +1,5 @@ use crate::geohashers::{geohash_decoder, geohash_encoder, geohash_neighbors}; +use crate::h3::h3_encoder; use crate::sha_hashers::*; use polars::{ chunked_array::ops::arity::{try_binary_elementwise, try_ternary_elementwise}, @@ -212,6 +213,49 @@ fn ghash_encode(inputs: &[Series]) -> PolarsResult { Ok(out.into_series()) } +#[polars_expr(output_type=String)] +fn h3_encode(inputs: &[Series]) -> PolarsResult { + let ca = inputs[0].struct_()?; + let len = match inputs[1].dtype() { + DataType::Int64 => inputs[1].clone(), + DataType::Int32 => inputs[1].cast(&DataType::Int64)?, + DataType::Int16 => inputs[1].cast(&DataType::Int64)?, + DataType::Int8 => inputs[1].cast(&DataType::Int64)?, + _ => polars_bail!(InvalidOperation:"Length input needs to be integer"), + }; + let len = len.i64()?; + + let lat = ca.field_by_name("latitude")?; + let long = ca.field_by_name("longitude")?; + let lat = match lat.dtype() { + DataType::Float32 => lat.cast(&DataType::Float64)?, + DataType::Float64 => lat, + _ => polars_bail!(InvalidOperation:"Latitude input needs to be float"), + }; + + let long = match long.dtype() { + DataType::Float32 => long.cast(&DataType::Float64)?, + DataType::Float64 => long, + _ => polars_bail!(InvalidOperation:"Longitude input needs to be float"), + }; + + let ca_lat = lat.f64()?; + let ca_long = long.f64()?; + + let out: StringChunked = match len.len() { + 1 => match unsafe { len.get_unchecked(0) } { + Some(len) => try_binary_elementwise(ca_lat, ca_long, |ca_lat_opt, ca_long_opt| { + h3_encoder(ca_lat_opt, ca_long_opt, Some(len)) + }), + _ => Err(PolarsError::ComputeError( + "Length may not be null".to_string().into(), + )), + }, + _ => try_ternary_elementwise(ca_lat, ca_long, len, h3_encoder), + }?; + Ok(out.into_series()) +} + pub fn geohash_decode_output(field: &[Field]) -> PolarsResult { let v: Vec = vec![ Field::new("longitude", Float64), diff --git a/polars_hash/src/h3.rs b/polars_hash/src/h3.rs new file mode 100644 index 0000000..3fbee1f --- /dev/null +++ b/polars_hash/src/h3.rs @@ -0,0 +1,53 @@ +use h3o::{LatLng, Resolution}; +use polars::prelude::*; + +fn get_resolution(resolution: i64) -> PolarsResult { + match resolution { + 1 => Ok(Resolution::One), + 2 => Ok(Resolution::Two), + 3 => Ok(Resolution::Three), + 4 => Ok(Resolution::Four), + 5 => Ok(Resolution::Five), + 6 => Ok(Resolution::Six), + 7 => Ok(Resolution::Seven), + 8 => Ok(Resolution::Eight), + 9 => Ok(Resolution::Nine), + 10 => Ok(Resolution::Ten), + 11 => Ok(Resolution::Eleven), + 12 => Ok(Resolution::Twelve), + 13 => Ok(Resolution::Thirteen), + 14 => Ok(Resolution::Fourteen), + 15 => Ok(Resolution::Fifteen), + _ => { + polars_bail!(InvalidOperation: "expected resolution between 1 and 15, got {}", resolution) + } + } +} + +pub fn h3_encoder( + lat: Option, + long: Option, + len: Option, +) -> PolarsResult> { + match (lat, long) { + (Some(lat), Some(long)) => match len { + Some(len) => Ok(Some( + LatLng::new(lat, long) + .expect("valid coord") + .to_cell(get_resolution(len)?) + .to_string(), + )), + _ => Err(PolarsError::ComputeError( + "Length may not be null".to_string().into(), + )), + }, + _ => Err(PolarsError::ComputeError( + format!( + "Coordinates cannot be null. + Provided latitude: {:?}, longitude: {:?}", + lat, long + ) + .into(), + )), + } +} diff --git a/polars_hash/src/lib.rs b/polars_hash/src/lib.rs index b163907..f3060a0 100644 --- a/polars_hash/src/lib.rs +++ b/polars_hash/src/lib.rs @@ -1,5 +1,6 @@ mod expressions; mod geohashers; +mod h3; mod sha_hashers; use pyo3::types::PyModule; use pyo3::{pymodule, PyResult, Python}; diff --git a/polars_hash/tests/test_hash.py b/polars_hash/tests/test_hash.py index 7f0687e..4394c93 100644 --- a/polars_hash/tests/test_hash.py +++ b/polars_hash/tests/test_hash.py @@ -138,6 +138,26 @@ def test_geohash(): ) +def test_h3(): + df = pl.DataFrame( + {"coord": [{"longitude": -120.6623, "latitude": 35.3003}]}, + schema={ + "coord": pl.Struct( + [pl.Field("longitude", pl.Float64), pl.Field("latitude", pl.Float64)] + ), + }, + ) + + result = df.select(pl.col("coord").h3.from_coords(5)) # type: ignore + + expected = pl.DataFrame( + [ + pl.Series("coord", ["8529adc7fffffff"], dtype=pl.Utf8), + ] + ) + assert_frame_equal(result, expected) + + def test_lazy_name(): result = ( pl.from_dicts({"h1": "sp1xk2m6194y"})