[ENH] Improve performance for polars' pivot_longer #1402

Merged · 4 commits · Sep 14, 2024
145 changes: 110 additions & 35 deletions janitor/polars/pivot_longer.py
@@ -26,7 +26,11 @@ def pivot_longer_spec(
A declarative interface to pivot a Polars Frame
from wide to long form,
where you describe how the data will be unpivoted,
using a DataFrame. This gives you, the user,
using a DataFrame.

It is modeled after tidyr's `pivot_longer_spec`.

This gives you, the user,
more control over the transformation to long form,
using a *spec* DataFrame that describes exactly
how data stored in the column names
@@ -108,41 +112,56 @@ def pivot_longer_spec(
corresponding to columns pivoted from the wide format.
Note that these additional columns should not already exist
in the source DataFrame.
If there are additional columns, the combination of these columns
and the `.value` column must be unique.

Raises:
KeyError: If `.name` or `.value` is missing from the spec's columns.
ValueError: If the labels in `spec['.name']` is not unique.
ValueError: If the labels in the spec's `.name` column are not unique.

Returns:
A polars DataFrame/LazyFrame.
"""
check("spec", spec, [pl.DataFrame])
if ".name" not in spec.columns:
spec_columns = spec.collect_schema().names()
if ".name" not in spec_columns:
raise KeyError(
"Kindly ensure the spec DataFrame has a `.name` column."
)
if ".value" not in spec.columns:
if ".value" not in spec_columns:
raise KeyError(
"Kindly ensure the spec DataFrame has a `.value` column."
)
if spec.select(pl.col(".name").is_duplicated().any()).item():
if spec.get_column(".name").is_duplicated().any():
raise ValueError("The labels in the `.name` column should be unique.")

exclude = set(df.columns).intersection(spec.columns)
df_columns = df.collect_schema().names()
exclude = set(df_columns).intersection(spec_columns)
if exclude:
raise ValueError(
f"Labels {*exclude, } in the spec dataframe already exist "
"as column labels in the source dataframe. "
"Kindly ensure the spec DataFrame's columns "
"are not present in the source DataFrame."
)

index = [
label for label in df.columns if label not in spec.get_column(".name")
label for label in df_columns if label not in spec.get_column(".name")
]
others = [
label for label in spec.columns if label not in {".name", ".value"}
label for label in spec_columns if label not in {".name", ".value"}
]
variable_name = "".join(df.columns + spec.columns)

if (len(others) == 1) & (spec.get_column(others[0]).dtype == pl.String):
# shortcut that avoids the implode/explode approach - and is faster
# if the requirements are met
# inspired by https://github.com/pola-rs/polars/pull/18519#issue-2500860927
return _pivot_longer_dot_value_string(
df=df,
index=index,
spec=spec,
variable_name=others[0],
)
variable_name = "".join(df_columns + spec_columns)
variable_name = f"{variable_name}_"
if others:
dot_value_only = False
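To make the spec contract described in this hunk concrete, here is a minimal hedged sketch with invented data; it assumes pyjanitor's polars backend exposes `pivot_longer_spec` from `janitor.polars`, as this file suggests:

```python
import polars as pl
from janitor.polars import pivot_longer_spec  # assumed import path

# invented data
df = pl.DataFrame(
    {
        "Species": ["setosa", "virginica"],
        "Sepal.Length": [5.1, 5.9],
        "Sepal.Width": [3.5, 3.0],
    }
)

# `.name` lists the source columns, `.value` the output columns they feed,
# and any extra spec column (here `part`) becomes a key column in the result.
# The combination of `part` and `.value` is unique, as required above.
spec = pl.DataFrame(
    {
        ".name": ["Sepal.Length", "Sepal.Width"],
        ".value": ["Length", "Width"],
        "part": ["Sepal", "Sepal"],
    }
)

out = pivot_longer_spec(df=df, spec=spec)  # assumed keyword signature
# expected columns: Species, part, Length, Width
```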
@@ -219,7 +238,7 @@ def pivot_longer(
│ 5.9 ┆ 3.0 ┆ 5.1 ┆ 1.8 ┆ virginica │
└──────────────┴─────────────┴──────────────┴─────────────┴───────────┘

Replicate polars' [melt](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.melt.html#polars-dataframe-melt):
Replicate polars' [melt](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.unpivot.html#polars-dataframe-melt):
>>> df.pivot_longer(index = 'Species').sort(by=pl.all())
shape: (8, 3)
┌───────────┬──────────────┬───────┐
@@ -375,8 +394,8 @@ def pivot_longer(
specification as polars' `str.split` method.
names_pattern: Determines how the column name is broken up.
It can be a regular expression containing matching groups.
It takes the same
specification as polars' `str.extract_groups` method.
It takes the same specification as
polars' `str.extract_groups` method.
names_transform: Use this option to change the types of columns that
have been transformed to rows.
This does not apply to the values' columns.
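A small sketch of the `names_sep` / `names_pattern` behaviour documented above, on invented data; it assumes that importing `janitor.polars` registers `pivot_longer` on polars frames, as the docstring examples imply:

```python
import polars as pl
import janitor.polars  # noqa: F401  # assumed to register pivot_longer

# invented data
df = pl.DataFrame(
    {
        "id": [1, 2],
        "x_mean": [1.0, 2.0],
        "x_sd": [0.1, 0.2],
    }
)

# Split "x_mean" into ("x", "mean") on the separator; the equivalent
# names_pattern would be r"(.+)_(.+)", following str.extract_groups semantics.
out = df.pivot_longer(
    index="id",
    names_to=("part", "stat"),
    names_sep="_",
)
# long form: id, part, stat plus the values column
```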
@@ -440,7 +459,7 @@ def _pivot_longer(
names_pattern=names_pattern,
)

variable_name = "".join(df.columns)
variable_name = "".join(df.collect_schema().names())
variable_name = f"{variable_name}_"
spec = _pivot_longer_create_spec(
column_names=column_names,
@@ -461,8 +480,25 @@ def _pivot_longer(
variable_name=variable_name,
names_transform=names_transform,
)

if {".name", ".value"}.symmetric_difference(spec.columns):
if {".name", ".value"}.symmetric_difference(spec.collect_schema().names()):
# shortcut that avoids the implode/explode approach - and is faster
# if the requirements are met
# inspired by https://github.com/pola-rs/polars/pull/18519#issue-2500860927
data = spec.get_column(variable_name)
others = data.struct.fields
data = data.struct[others[0]]
if (
(len(others) == 1)
& (data.dtype == pl.String)
& (names_transform is None)
):
spec = spec.unnest(variable_name)
return _pivot_longer_dot_value_string(
df=df,
index=index,
spec=spec,
variable_name=others[0],
)
dot_value_only = False
else:
dot_value_only = True
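A note on the recurring `.columns` to `.collect_schema().names()` change in this diff: `collect_schema()` works the same way on eager and lazy frames, and avoids the schema-resolution warning that accessing `LazyFrame.columns` can trigger in recent polars versions. A minimal sketch (the exact warning behaviour depends on your polars version):

```python
import polars as pl

lf = pl.LazyFrame({"a": [1, 2], "b": ["x", "y"]})

# Resolves the schema once; the same call works for pl.DataFrame too.
names = lf.collect_schema().names()  # ["a", "b"]
```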
@@ -552,7 +588,7 @@ def _pivot_longer_create_spec(
return spec.select(".name", ".value")
_spec = spec.get_column(variable_name)
_spec = _spec.struct.unnest()
fields = _spec.columns
fields = _spec.collect_schema().names()

if len(set(names_to)) == 1:
expression = pl.concat_str(fields).alias(".value")
@@ -591,7 +627,7 @@ def _pivot_longer_no_dot_value(
# do the operation on a smaller size
# and then blow it up after
# it is usually much faster
# than running on the actual data
# than unpivoting and running the string operations after
outcome = (
df.select(pl.all().implode())
.unpivot(
@@ -606,11 +642,44 @@
outcome = outcome.unnest(variable_name)
if names_transform is not None:
outcome = outcome.with_columns(names_transform)
columns = [name for name in outcome.columns if name not in names_to]
columns = [
name
for name in outcome.collect_schema().names()
if name not in names_to
]
outcome = outcome.explode(columns=columns)
return outcome


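The comment in `_pivot_longer_no_dot_value` above refers to the implode/unpivot/explode trick. A rough sketch of the idea on invented data; the real function builds the new columns from the spec rather than splitting strings inline:

```python
import polars as pl

# invented data
df = pl.DataFrame({"id": [1, 2], "x_1": [1.0, 2.0], "x_2": [3.0, 4.0]})

out = (
    df.select(pl.all().implode())  # one row; every column becomes a list
    .unpivot(index="id", variable_name="name", value_name="value")
    # the string work now runs on one row per wide column, not one per value
    .with_columns(
        pl.col("name")
        .str.split_exact("_", 1)
        .struct.rename_fields(["part", "num"])
    )
    .unnest("name")
    .explode("id", "value")  # blow the imploded columns back up
)
```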
def _pivot_longer_dot_value_string(
df: pl.DataFrame | pl.LazyFrame,
spec: pl.DataFrame,
index: ColumnNameOrSelector,
variable_name: str,
) -> pl.DataFrame | pl.LazyFrame:
"""
Fast path for `.value`; avoids the implode/explode approach.
"""
spec = spec.group_by(variable_name)
spec = spec.agg(pl.all())
expressions = []
for names, fields, header in zip(
spec.get_column(".name").to_list(),
spec.get_column(".value").to_list(),
spec.get_column(variable_name).to_list(),
):
expression = pl.struct(names).struct.rename_fields(names=fields)
expression = expression.alias(header)
expressions.append(expression)
expressions = [*index, *expressions]
df = (
df.select(expressions)
.unpivot(index=index, variable_name=variable_name, value_name=".value")
.unnest(".value")
)
return df


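The new `_pivot_longer_dot_value_string` fast path packs each group of source columns into a struct, renames the struct fields to the `.value` labels, then does a single unpivot plus unnest. A minimal sketch of the same pattern with invented data:

```python
import polars as pl

# invented data
df = pl.DataFrame(
    {
        "id": [1, 2],
        "x_mean": [1.0, 2.0],
        "y_mean": [3.0, 4.0],
        "x_sd": [0.1, 0.2],
        "y_sd": [0.3, 0.4],
    }
)

expressions = [
    pl.col("id"),
    # one struct per value of the pivoted variable; the struct fields are
    # the eventual output column names
    pl.struct("x_mean", "y_mean").struct.rename_fields(["x", "y"]).alias("mean"),
    pl.struct("x_sd", "y_sd").struct.rename_fields(["x", "y"]).alias("sd"),
]

out = (
    df.select(expressions)
    .unpivot(index="id", variable_name="stat", value_name=".value")
    .unnest(".value")
)
# columns: id, stat ("mean"/"sd"), x, y (no implode/explode round trip)
```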
def _pivot_longer_dot_value(
df: pl.DataFrame | pl.LazyFrame,
spec: pl.DataFrame,
@@ -621,7 +690,7 @@ def _pivot_longer_dot_value(
) -> pl.DataFrame | pl.LazyFrame:
"""
flip polars Frame to long form,
if names_sep and .value in names_to.
if .value in names_to.
"""
spec = spec.group_by(variable_name)
spec = spec.agg(pl.all())
@@ -634,25 +703,31 @@ def _pivot_longer_dot_value(
expressions.append(expression)
expressions = [*index, *expressions]
spec = spec.get_column(variable_name)
if dot_value_only:
outcome = (
df.select(expressions)
.unpivot(
index=index, variable_name=variable_name, value_name=".value"
)
.select(pl.exclude(variable_name))
.unnest(".value")
)
return outcome

outcome = (
df.select(expressions)
.select(pl.all().implode())
.unpivot(index=index, variable_name=variable_name, value_name=".value")
.with_columns(spec)
)

if dot_value_only:
columns = [
label for label in outcome.columns if label != variable_name
]
outcome = outcome.explode(columns).unnest(".value")
outcome = outcome.select(pl.exclude(variable_name))
return outcome
outcome = outcome.unnest(variable_name)
if names_transform is not None:
outcome = outcome.with_columns(names_transform)
columns = [
label for label in outcome.columns if label not in spec.struct.fields
label
for label in outcome.collect_schema().names()
if label not in spec.struct.fields
]
outcome = outcome.explode(columns)
outcome = outcome.unnest(".value")
@@ -710,17 +785,17 @@ def _data_checks_pivot_longer(
check("values_to", values_to, [str])

if (index is None) and (column_names is None):
column_names = df.columns
column_names = df.collect_schema().names()
index = []
elif (index is None) and (column_names is not None):
column_names = df.select(column_names).columns
index = df.select(pl.exclude(column_names)).columns
column_names = df.select(column_names).collect_schema().names()
index = df.select(pl.exclude(column_names)).collect_schema().names()
elif (index is not None) and (column_names is None):
index = df.select(index).columns
column_names = df.select(pl.exclude(index)).columns
index = df.select(index).collect_schema().names()
column_names = df.select(pl.exclude(index)).collect_schema().names()
else:
index = df.select(index).columns
column_names = df.select(column_names).columns
index = df.select(index).collect_schema().names()
column_names = df.select(column_names).collect_schema().names()

return (
df,
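Finally, the `_data_checks_pivot_longer` hunk above resolves `index` and `column_names` from each other against the schema. A brief sketch of the "only index given" branch, on invented data:

```python
import polars as pl

# invented data
df = pl.DataFrame(
    {"Species": ["setosa"], "Sepal.Length": [5.1], "Sepal.Width": [3.5]}
)

index = df.select("Species").collect_schema().names()
column_names = df.select(pl.exclude(index)).collect_schema().names()
# index -> ["Species"]; column_names -> ["Sepal.Length", "Sepal.Width"]
```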