Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: more robust literal handling for PySpark #1880

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

from narwhals._spark_like.utils import ExprKind
from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.typing import CompliantLazyFrame
Expand Down Expand Up @@ -97,9 +98,7 @@ def select(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns, returns_scalar = parse_exprs_and_named_exprs(
self, *exprs, **named_exprs
)
new_columns, expr_kinds = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)

if not new_columns:
# return empty dataframe, like Polars does
Expand All @@ -110,19 +109,17 @@ def select(

return self._from_native_frame(spark_df)

if all(returns_scalar):
if not any(expr_kind is ExprKind.TRANSFORM for expr_kind in expr_kinds):
new_columns_list = [
col.alias(col_name) for col_name, col in new_columns.items()
]
return self._from_native_frame(self._native_frame.agg(*new_columns_list))
else:
new_columns_list = [
col.over(Window.partitionBy(F.lit(1))).alias(col_name)
if _returns_scalar
if expr_kind is ExprKind.AGGREGATION
else col.alias(col_name)
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds)
]
return self._from_native_frame(self._native_frame.select(*new_columns_list))

Expand All @@ -131,15 +128,13 @@ def with_columns(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns, returns_scalar = parse_exprs_and_named_exprs(
self, *exprs, **named_exprs
)
new_columns, expr_kinds = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)

new_columns_map = {
col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
col_name: col.over(Window.partitionBy(F.lit(1)))
if expr_kind is ExprKind.AGGREGATION
else col
for (col_name, col), expr_kind in zip(new_columns.items(), expr_kinds)
}
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

Expand Down
Loading
Loading