diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py
index c93366a31e315..d1ab273ef4f5e 100644
--- a/python/pyspark/pandas/indexing.py
+++ b/python/pyspark/pandas/indexing.py
@@ -29,6 +29,7 @@
 from pyspark.sql import functions as F, Column as PySparkColumn
 from pyspark.sql.types import BooleanType, LongType, DataType
+from pyspark.sql.utils import is_remote
 from pyspark.errors import AnalysisException
 from pyspark import pandas as ps  # noqa: F401
 from pyspark.pandas._typing import Label, Name, Scalar
@@ -534,11 +535,19 @@ def __getitem__(self, key: Any) -> Union["Series", "DataFrame"]:
                     sdf = sdf.limit(sdf.count() + limit)
                 sdf = sdf.drop(NATURAL_ORDER_COLUMN_NAME)
             except AnalysisException:
-                raise KeyError(
-                    "[{}] don't exist in columns".format(
-                        [col._jc.toString() for col in data_spark_columns]
-                    )
-                )
+                if is_remote():
+                    from pyspark.sql.connect.column import Column as ConnectColumn
+
+                    cols_as_str = [
+                        cast(ConnectColumn, col)._expr.__repr__() for col in data_spark_columns
+                    ]
+                else:
+                    from pyspark.sql.classic.column import Column as ClassicColumn
+
+                    cols_as_str = [
+                        cast(ClassicColumn, col)._jc.toString() for col in data_spark_columns
+                    ]
+                raise KeyError("[{}] don't exist in columns".format(cols_as_str))

             internal = InternalFrame(
                 spark_frame=sdf,
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index 111bfd4630667..e17afa026c5af 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -957,6 +957,18 @@ def spark_column_equals(left: Column, right: Column) -> bool:
             )
         return repr(left).replace("`", "") == repr(right).replace("`", "")
     else:
+        from pyspark.sql.classic.column import Column as ClassicColumn
+
+        if not isinstance(left, ClassicColumn):
+            raise PySparkTypeError(
+                errorClass="NOT_COLUMN",
+                messageParameters={"arg_name": "left", "arg_type": type(left).__name__},
+            )
+        if not isinstance(right, ClassicColumn):
+            raise PySparkTypeError(
+                errorClass="NOT_COLUMN",
+                messageParameters={"arg_name": "right", "arg_type": type(right).__name__},
+            )
         return left._jc.equals(right._jc)
diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py
index fe0e440203c36..161f8ba4bb7ab 100644
--- a/python/pyspark/sql/classic/column.py
+++ b/python/pyspark/sql/classic/column.py
@@ -177,13 +177,11 @@ def _reverse_op(

 @with_origin_to_class
 class Column(ParentColumn):
-    def __new__(
-        cls,
-        jc: "JavaObject",
-    ) -> "Column":
-        self = object.__new__(cls)
-        self.__init__(jc)  # type: ignore[misc]
-        return self
+    def __new__(cls, *args: Any, **kwargs: Any) -> "Column":
+        return object.__new__(cls)
+
+    def __getnewargs__(self) -> Tuple[Any, ...]:
+        return (self._jc,)

     def __init__(self, jc: "JavaObject") -> None:
         self._jc = jc
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index e5640dd81b1fb..a055e44564952 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -31,7 +31,6 @@
 from pyspark.errors import PySparkValueError

 if TYPE_CHECKING:
-    from py4j.java_gateway import JavaObject
     from pyspark.sql._typing import LiteralType, DecimalLiteral, DateTimeLiteral
     from pyspark.sql.window import WindowSpec

@@ -72,16 +71,10 @@ class Column(TableValuedFunctionArgument):
     # HACK ALERT!! this is to reduce the backward compatibility concern, and returns
     # Spark Classic Column by default. This is NOT an API, and NOT supposed to
     # be directly invoked. DO NOT use this constructor.
-    def __new__(
-        cls,
-        jc: "JavaObject",
-    ) -> "Column":
+    def __new__(cls, *args: Any, **kwargs: Any) -> "Column":
         from pyspark.sql.classic.column import Column

-        return Column.__new__(Column, jc)
-
-    def __init__(self, jc: "JavaObject") -> None:
-        self._jc = jc
+        return Column.__new__(Column, *args, **kwargs)

     # arithmetic operators
     @dispatch_col_method
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index c5733801814eb..e6d58aefbf2f9 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -27,6 +27,7 @@
     Any,
     Union,
     Optional,
+    Tuple,
 )

 from pyspark.sql.column import Column as ParentColumn
@@ -109,13 +110,11 @@ def _to_expr(v: Any) -> Expression:

 @with_origin_to_class(["to_plan"])
 class Column(ParentColumn):
-    def __new__(
-        cls,
-        expr: "Expression",
-    ) -> "Column":
-        self = object.__new__(cls)
-        self.__init__(expr)  # type: ignore[misc]
-        return self
+    def __new__(cls, *args: Any, **kwargs: Any) -> "Column":
+        return object.__new__(cls)
+
+    def __getnewargs__(self) -> Tuple[Any, ...]:
+        return (self._expr,)

     def __init__(self, expr: "Expression") -> None:
         if not isinstance(expr, Expression):
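The `__new__`/`__getnewargs__` pairing introduced in classic/column.py and connect/column.py follows the standard pickle protocol: `__getnewargs__` returns the positional arguments that are replayed into `__new__` when an instance is rebuilt, so `__new__` must tolerate arbitrary arguments. A minimal, self-contained sketch of that pattern (a toy class for illustration, not pyspark code):

import pickle


class Wrapped:
    def __new__(cls, *args, **kwargs):
        # Allocation only; argument handling is left to __init__.
        return object.__new__(cls)

    def __getnewargs__(self):
        # Arguments replayed into __new__ when the object is unpickled.
        return (self._expr,)

    def __init__(self, expr):
        self._expr = expr


w = Wrapped("a + 1")
assert pickle.loads(pickle.dumps(w))._expr == "a + 1"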