From 7c316f7f8e7595ae7ffbe200ee3bb6bcc8d73547 Mon Sep 17 00:00:00 2001 From: Takuya Ueshin Date: Thu, 23 Jan 2025 16:39:59 -0800 Subject: [PATCH] [SPARK-50968][PYTHON] Fix the usage of `Column.__new__` ### What changes were proposed in this pull request? Fixes the usage of `Column.__new__`. ### Why are the changes needed? Currently `Column.__new__` explicitly calls `Column.__init__`, which causes `__init__` to run twice: Python automatically calls `__init__` on the instance returned by `__new__`, so the explicit call in `__new__` is unnecessary. ```py >>> class A: ... def __new__(cls, *args, **kwargs): ... print(f"__NEW__: {args}, {kwargs}") ... obj = object.__new__(cls) ... obj.__init__(*args, **kwargs) ... return obj ... def __init__(self, *args, **kwargs): ... print(f"__INIT__: {args}, {kwargs}") ... >>> A(1,2,3, k=4) __NEW__: (1, 2, 3), {'k': 4} __INIT__: (1, 2, 3), {'k': 4} __INIT__: (1, 2, 3), {'k': 4} <__main__.A object at 0x102ccab90> >>> class B: ... def __new__(cls, *args, **kwargs): ... print(f"__NEW__: {args}, {kwargs}") ... return object.__new__(cls) ... def __init__(self, *args, **kwargs): ... print(f"__INIT__: {args}, {kwargs}") ... >>> B(1,2,3, k=4) __NEW__: (1, 2, 3), {'k': 4} __INIT__: (1, 2, 3), {'k': 4} <__main__.B object at 0x102b2b970> ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? The existing tests should pass. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49631 from ueshin/issues/SPARK-50968/column_new. 
Authored-by: Takuya Ueshin Signed-off-by: Takuya Ueshin --- python/pyspark/pandas/indexing.py | 19 ++++++++++++++----- python/pyspark/pandas/utils.py | 12 ++++++++++++ python/pyspark/sql/classic/column.py | 12 +++++------- python/pyspark/sql/column.py | 11 ++--------- python/pyspark/sql/connect/column.py | 13 ++++++------- 5 files changed, 39 insertions(+), 28 deletions(-) diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py index c93366a31e315..d1ab273ef4f5e 100644 --- a/python/pyspark/pandas/indexing.py +++ b/python/pyspark/pandas/indexing.py @@ -29,6 +29,7 @@ from pyspark.sql import functions as F, Column as PySparkColumn from pyspark.sql.types import BooleanType, LongType, DataType +from pyspark.sql.utils import is_remote from pyspark.errors import AnalysisException from pyspark import pandas as ps # noqa: F401 from pyspark.pandas._typing import Label, Name, Scalar @@ -534,11 +535,19 @@ def __getitem__(self, key: Any) -> Union["Series", "DataFrame"]: sdf = sdf.limit(sdf.count() + limit) sdf = sdf.drop(NATURAL_ORDER_COLUMN_NAME) except AnalysisException: - raise KeyError( - "[{}] don't exist in columns".format( - [col._jc.toString() for col in data_spark_columns] - ) - ) + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + cols_as_str = [ + cast(ConnectColumn, col)._expr.__repr__() for col in data_spark_columns + ] + else: + from pyspark.sql.classic.column import Column as ClassicColumn + + cols_as_str = [ + cast(ClassicColumn, col)._jc.toString() for col in data_spark_columns + ] + raise KeyError("[{}] don't exist in columns".format(cols_as_str)) internal = InternalFrame( spark_frame=sdf, diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index 111bfd4630667..e17afa026c5af 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -957,6 +957,18 @@ def spark_column_equals(left: Column, right: Column) -> bool: ) return repr(left).replace("`", "") == 
repr(right).replace("`", "") else: + from pyspark.sql.classic.column import Column as ClassicColumn + + if not isinstance(left, ClassicColumn): + raise PySparkTypeError( + errorClass="NOT_COLUMN", + messageParameters={"arg_name": "left", "arg_type": type(left).__name__}, + ) + if not isinstance(right, ClassicColumn): + raise PySparkTypeError( + errorClass="NOT_COLUMN", + messageParameters={"arg_name": "right", "arg_type": type(right).__name__}, + ) return left._jc.equals(right._jc) diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py index fe0e440203c36..161f8ba4bb7ab 100644 --- a/python/pyspark/sql/classic/column.py +++ b/python/pyspark/sql/classic/column.py @@ -177,13 +177,11 @@ def _reverse_op( @with_origin_to_class class Column(ParentColumn): - def __new__( - cls, - jc: "JavaObject", - ) -> "Column": - self = object.__new__(cls) - self.__init__(jc) # type: ignore[misc] - return self + def __new__(cls, *args: Any, **kwargs: Any) -> "Column": + return object.__new__(cls) + + def __getnewargs__(self) -> Tuple[Any, ...]: + return (self._jc,) def __init__(self, jc: "JavaObject") -> None: self._jc = jc diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index e5640dd81b1fb..a055e44564952 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -31,7 +31,6 @@ from pyspark.errors import PySparkValueError if TYPE_CHECKING: - from py4j.java_gateway import JavaObject from pyspark.sql._typing import LiteralType, DecimalLiteral, DateTimeLiteral from pyspark.sql.window import WindowSpec @@ -72,16 +71,10 @@ class Column(TableValuedFunctionArgument): # HACK ALERT!! this is to reduce the backward compatibility concern, and returns # Spark Classic Column by default. This is NOT an API, and NOT supposed to # be directly invoked. DO NOT use this constructor. 
- def __new__( - cls, - jc: "JavaObject", - ) -> "Column": + def __new__(cls, *args: Any, **kwargs: Any) -> "Column": from pyspark.sql.classic.column import Column - return Column.__new__(Column, jc) - - def __init__(self, jc: "JavaObject") -> None: - self._jc = jc + return Column.__new__(Column, *args, **kwargs) # arithmetic operators @dispatch_col_method diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index c5733801814eb..e6d58aefbf2f9 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -27,6 +27,7 @@ Any, Union, Optional, + Tuple, ) from pyspark.sql.column import Column as ParentColumn @@ -109,13 +110,11 @@ def _to_expr(v: Any) -> Expression: @with_origin_to_class(["to_plan"]) class Column(ParentColumn): - def __new__( - cls, - expr: "Expression", - ) -> "Column": - self = object.__new__(cls) - self.__init__(expr) # type: ignore[misc] - return self + def __new__(cls, *args: Any, **kwargs: Any) -> "Column": + return object.__new__(cls) + + def __getnewargs__(self) -> Tuple[Any, ...]: + return (self._expr,) def __init__(self, expr: "Expression") -> None: if not isinstance(expr, Expression):