Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed error in warn len test #314

Merged
merged 1 commit into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars_ds"
version = "0.7.1"
version = "0.8.0"
edition = "2021"

[lib]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "maturin"
[project]
name = "polars_ds"
requires-python = ">=3.9"
version = "0.7.1"
version = "0.8.0"

license = { file = "LICENSE.txt" }
classifiers = [
Expand Down
2 changes: 1 addition & 1 deletion python/polars_ds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ._utils import str_to_expr
from polars_ds.exprs import *

__version__ = "0.7.1"
__version__ = "0.8.0"

def frame(size: int = 2_000, index_name: str = "row_num") -> pl.DataFrame:
"""
Expand Down
1 change: 1 addition & 0 deletions python/polars_ds/compat/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"pl",
"annotations",
"__version__",
"warn_len_compare"
}

__all__ = ["compat"]
Expand Down
47 changes: 25 additions & 22 deletions python/polars_ds/exprs/expr_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,30 @@
"within_dist_from",
]

def warn_len_compare(item1: Iterable[Any], item2: Iterable[Any]) -> bool:
"""
Compares the len of two Iterables if they have len returning true and warning if no len.

Parameters
----------
item1: Iterable[Any]
Any iterable
item2: Iterable[Any])
Any iterable

Returns:
bool: If both items have __len__ then it will simply return whether or not
they have equal size. If they don't have len then it returns True with a
warning
"""
# print()
if hasattr(item1, "__len__") and hasattr(item2, "__len__"):
return len(cast(Sequence, item1)) == len(cast(Sequence, item2))
else:
msg = "The inputs do not each have len so can't be compared, unexpected results may follow."
warnings.warn(msg, stacklevel=2)
return True


def query_dist_from_kth_nb(
*features: str | pl.Expr,
Expand Down Expand Up @@ -309,27 +333,6 @@ def query_knn_avg(
)


def warn_len_compare(item1: Iterable[Any], item2: Iterable[Any]) -> bool:
"""
Compares the len of two Iterables if they have len returning true and warning if no len.

Args:
item1 (Iterable[Any]): Any iterable
item2 (Iterable[Any]): Any iterable

Returns:
bool: If both items have __len__ then it will simply return whether or not
they have equal size. If they don't have len then it returns True with a
warning
"""
if hasattr(item1, "__len__") and hasattr(item2, "__len__"):
return len(cast(Sequence, item1)) == len(cast(Sequence, item2))
else:
msg = "The inputs do not each have len so can't be compared, unexpected results may follow"
warnings.warn(msg)
return True


def within_dist_from(
*features: str | pl.Expr,
pt: Sequence[float] | Iterable[float],
Expand Down Expand Up @@ -419,7 +422,7 @@ def is_knn_from(
"""
# For a single point, it is faster to just do it in native polars
oth = [str_to_expr(x) for x in features]
if warn_len_compare(pt, oth):
if not warn_len_compare(pt, oth):
raise ValueError("Dimension does not match.")

if dist == "l1":
Expand Down
Loading