Skip to content

Commit

Permalink
add h3 spatial index (#24)
Browse files Browse the repository at this point in the history
* add h3 spatial index

* Update polars_hash/polars_hash/__init__.py

* Update polars_hash/polars_hash/__init__.py
  • Loading branch information
MarcoGorelli authored Apr 11, 2024
1 parent d8406e0 commit 24dc552
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 0 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,30 @@ shape: (1, 1)
└───────────────────────┘
```

### H3 Spatial Index
```python
df = pl.DataFrame(
    {"coord": [{"longitude": -120.6623, "latitude": 35.3003}]},
    schema={
        "coord": pl.Struct(
            [pl.Field("longitude", pl.Float64), pl.Field("latitude", pl.Float64)]
        ),
    },
)

df.with_columns(
    plh.col('coord').h3.from_coords().alias('h3')
)
shape: (1, 2)
┌─────────────────────┬─────────────────┐
│ coord               ┆ h3              │
│ ---                 ┆ ---             │
│ struct[2]           ┆ str             │
╞═════════════════════╪═════════════════╡
│ {-120.6623,35.3003} ┆ 8c29adc423821ff │
└─────────────────────┴─────────────────┘
```


## Create hash from multiple columns
```python
Expand Down
1 change: 1 addition & 0 deletions polars_hash/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ sha2 = { version = "0.10.8" }
sha3 = { version = "0.10.8" }
blake3 = { version = "1.5.0" }
md5 = {version = "0.7.0"}
h3o = { version = "0.6.2" }


[target.'cfg(target_os = "linux")'.dependencies]
Expand Down
19 changes: 19 additions & 0 deletions polars_hash/polars_hash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,21 @@ def neighbors(self) -> pl.Expr:
)


@pl.api.register_expr_namespace("h3")
class H3NameSpace:
    """Expression namespace (``.h3``) for H3 spatial-index operations."""

    def __init__(self, expr: pl.Expr):
        # Keep the wrapped expression so namespace methods can forward it
        # to the plugin.
        self._expr = expr

    def from_coords(self, len: int = 12) -> pl.Expr:
        """Encode a coordinate Struct as an H3 cell index string.

        The wrapped expression should be a Struct with ``latitude`` and
        ``longitude`` fields. ``len`` is handed to the ``h3_encode`` plugin
        function as its second argument, where it selects the H3 resolution.
        """
        plugin_args = [self._expr, len]
        plugin_dir = Path(__file__).parent
        return register_plugin_function(
            plugin_path=plugin_dir,
            function_name="h3_encode",
            args=plugin_args,
            is_elementwise=True,
        )


class HExpr(pl.Expr):
@property
def chash(self) -> CryptographicHashingNameSpace:
Expand All @@ -191,6 +206,10 @@ def nchash(self) -> NonCryptographicHashingNameSpace:
def geohash(self) -> GeoHashingNameSpace:
return GeoHashingNameSpace(self)

@property
def h3(self) -> H3NameSpace:
    """Return the H3 spatial-index namespace bound to this expression."""
    namespace = H3NameSpace(self)
    return namespace


class HashColumn(Protocol):
def __call__(
Expand Down
44 changes: 44 additions & 0 deletions polars_hash/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::geohashers::{geohash_decoder, geohash_encoder, geohash_neighbors};
use crate::h3::h3_encoder;
use crate::sha_hashers::*;
use polars::{
chunked_array::ops::arity::{try_binary_elementwise, try_ternary_elementwise},
Expand Down Expand Up @@ -212,6 +213,49 @@ fn ghash_encode(inputs: &[Series]) -> PolarsResult<Series> {
Ok(out.into_series())
}

/// Polars expression plugin that encodes coordinates as H3 cell index strings.
///
/// `inputs[0]` must be a Struct series with `latitude` and `longitude` fields
/// (Float32 or Float64). `inputs[1]` is the resolution ("length") as a signed
/// integer series; a single value is applied to every row, otherwise it is
/// zipped element-wise with the coordinates.
///
/// # Errors
/// Fails when the first input is not a Struct, when a field is missing or not
/// a float, when the length is not a signed integer or is null, or when
/// `h3_encoder` rejects a row.
#[polars_expr(output_type=String)]
fn h3_encode(inputs: &[Series]) -> PolarsResult<Series> {
    let ca = inputs[0].struct_()?;

    // Normalise every accepted integer width to Int64 so one code path follows.
    let len = match inputs[1].dtype() {
        DataType::Int64 => inputs[1].clone(),
        DataType::Int32 | DataType::Int16 | DataType::Int8 => {
            inputs[1].cast(&DataType::Int64)?
        }
        _ => polars_bail!(InvalidOperation:"Length input needs to be integer"),
    };
    let len = len.i64()?;

    let lat = ca.field_by_name("latitude")?;
    let long = ca.field_by_name("longitude")?;
    let lat = match lat.dtype() {
        DataType::Float32 => lat.cast(&DataType::Float64)?,
        DataType::Float64 => lat,
        _ => polars_bail!(InvalidOperation:"Latitude input needs to be float"),
    };

    let long = match long.dtype() {
        DataType::Float32 => long.cast(&DataType::Float64)?,
        DataType::Float64 => long,
        _ => polars_bail!(InvalidOperation:"Longitude input needs to be float"),
    };

    let ca_lat = lat.f64()?;
    let ca_long = long.f64()?;

    let out: StringChunked = match len.len() {
        // A single length value is broadcast over all coordinate rows.
        // The safe, bounds-checked `get` is enough here: it runs once per
        // call, so there is nothing to gain from an unchecked access.
        1 => match len.get(0) {
            Some(resolution) => {
                try_binary_elementwise(ca_lat, ca_long, |ca_lat_opt, ca_long_opt| {
                    h3_encoder(ca_lat_opt, ca_long_opt, Some(resolution))
                })
            }
            None => Err(PolarsError::ComputeError(
                "Length may not be null".to_string().into(),
            )),
        },
        _ => try_ternary_elementwise(ca_lat, ca_long, len, h3_encoder),
    }?;
    Ok(out.into_series())
}

pub fn geohash_decode_output(field: &[Field]) -> PolarsResult<Field> {
let v: Vec<Field> = vec![
Field::new("longitude", Float64),
Expand Down
53 changes: 53 additions & 0 deletions polars_hash/src/h3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use h3o::{LatLng, Resolution};
use polars::prelude::*;

fn get_resolution(resolution: i64) -> PolarsResult<Resolution> {
match resolution {
1 => Ok(Resolution::One),
2 => Ok(Resolution::Two),
3 => Ok(Resolution::Three),
4 => Ok(Resolution::Four),
5 => Ok(Resolution::Five),
6 => Ok(Resolution::Six),
7 => Ok(Resolution::Seven),
8 => Ok(Resolution::Eight),
9 => Ok(Resolution::Nine),
10 => Ok(Resolution::Ten),
11 => Ok(Resolution::Eleven),
12 => Ok(Resolution::Twelve),
13 => Ok(Resolution::Thirteen),
14 => Ok(Resolution::Fourteen),
15 => Ok(Resolution::Fifteen),
_ => {
polars_bail!(InvalidOperation: "expected resolution between 1 and 15, got {}", resolution)
}
}
}

/// Encodes a single latitude/longitude pair as an H3 cell index string.
///
/// `len` is the H3 resolution and is validated by `get_resolution`.
///
/// # Errors
/// Returns a `ComputeError` when either coordinate is null, when `len` is
/// null, or when `LatLng::new` rejects the coordinates. Propagates the error
/// from `get_resolution` for an out-of-range resolution.
pub fn h3_encoder(
    lat: Option<f64>,
    long: Option<f64>,
    len: Option<i64>,
) -> PolarsResult<Option<String>> {
    let (lat_val, long_val) = match (lat, long) {
        (Some(lat_val), Some(long_val)) => (lat_val, long_val),
        _ => {
            return Err(PolarsError::ComputeError(
                format!(
                    "Coordinates cannot be null.\nProvided latitude: {:?}, longitude: {:?}",
                    lat, long
                )
                .into(),
            ))
        }
    };

    let len = match len {
        Some(len) => len,
        None => {
            return Err(PolarsError::ComputeError(
                "Length may not be null".to_string().into(),
            ))
        }
    };

    // Coordinates come from user data, so a rejected value is reported as a
    // regular error instead of panicking inside the plugin.
    let coord = LatLng::new(lat_val, long_val).map_err(|err| {
        PolarsError::ComputeError(
            format!(
                "Invalid coordinates (latitude: {:?}, longitude: {:?}): {}",
                lat_val, long_val, err
            )
            .into(),
        )
    })?;

    Ok(Some(coord.to_cell(get_resolution(len)?).to_string()))
}
1 change: 1 addition & 0 deletions polars_hash/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod expressions;
mod geohashers;
mod h3;
mod sha_hashers;
use pyo3::types::PyModule;
use pyo3::{pymodule, PyResult, Python};
Expand Down
20 changes: 20 additions & 0 deletions polars_hash/tests/test_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,26 @@ def test_geohash():
)


def test_h3():
    """Encoding a known coordinate at resolution 5 yields the expected H3 cell."""
    coord_dtype = pl.Struct(
        [pl.Field("longitude", pl.Float64), pl.Field("latitude", pl.Float64)]
    )
    frame = pl.DataFrame(
        {"coord": [{"longitude": -120.6623, "latitude": 35.3003}]},
        schema={"coord": coord_dtype},
    )

    actual = frame.select(pl.col("coord").h3.from_coords(5))  # type: ignore

    # The output column keeps the input column's name.
    wanted = pl.DataFrame(
        [pl.Series("coord", ["8529adc7fffffff"], dtype=pl.Utf8)]
    )
    assert_frame_equal(actual, wanted)


def test_lazy_name():
result = (
pl.from_dicts({"h1": "sp1xk2m6194y"})
Expand Down

0 comments on commit 24dc552

Please sign in to comment.