From 4a35a99801844ae59139d88bb2225704f335feac Mon Sep 17 00:00:00 2001 From: Richard Date: Sat, 4 Apr 2020 10:44:33 -0400 Subject: [PATCH 1/3] CLN: Add/refine type hints to some functions in core.dtypes.cast - maybe_cast_result obj is now "Series" - Added type hint and assert to maybe_cast_to_extension_array - Added generic type to _GroupBy and GroupBy in core.groupby.groupby - Updated missed "try_cast" references to "maybe_cast" --- pandas/core/dtypes/cast.py | 20 +++++++++++++------- pandas/core/groupby/generic.py | 6 +++--- pandas/core/groupby/groupby.py | 11 ++++++++--- pandas/core/groupby/ops.py | 2 +- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index da9646aa8c46f..767b65c885759 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -3,6 +3,7 @@ """ from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Type import numpy as np @@ -63,6 +64,7 @@ ABCDataFrame, ABCDatetimeArray, ABCDatetimeIndex, + ABCExtensionArray, ABCPeriodArray, ABCPeriodIndex, ABCSeries, @@ -70,6 +72,10 @@ from pandas.core.dtypes.inference import is_list_like from pandas.core.dtypes.missing import isna, notna +if TYPE_CHECKING: + from pandas import Series + from pandas.api.extensions import ExtensionArray + _int8_max = np.iinfo(np.int8).max _int16_max = np.iinfo(np.int16).max _int32_max = np.iinfo(np.int32).max @@ -246,9 +252,7 @@ def trans(x): return result -def maybe_cast_result( - result, obj: ABCSeries, numeric_only: bool = False, how: str = "" -): +def maybe_cast_result(result, obj: "Series", numeric_only: bool = False, how: str = ""): """ Try casting result to a different type if appropriate @@ -256,8 +260,8 @@ def maybe_cast_result( ---------- result : array-like Result to cast. - obj : ABCSeries - Input series from which result was calculated. + obj : Series + Input Series from which result was calculated. numeric_only : bool, default False Whether to cast only numerics or datetimes as well. how : str, default "" @@ -313,13 +317,13 @@ def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj: return d.get((dtype, how), dtype) -def maybe_cast_to_extension_array(cls, obj, dtype=None): +def maybe_cast_to_extension_array(cls: Type["ExtensionArray"], obj, dtype=None): """ Call to `_from_sequence` that returns the object unchanged on Exception. Parameters ---------- - cls : ExtensionArray subclass + cls : class, subclass of ExtensionArray obj : arraylike Values to pass to cls._from_sequence dtype : ExtensionDtype, optional @@ -329,6 +333,8 @@ def maybe_cast_to_extension_array(cls, obj, dtype=None): ExtensionArray or obj """ assert isinstance(cls, type), f"must pass a type: {cls}" + assertion_msg = f"must pass a subclass of ExtensionArray: {cls}" + assert issubclass(cls, ABCExtensionArray), assertion_msg try: result = cls._from_sequence(obj, dtype=dtype) except Exception: diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 093c925acbc49..88580f6ebb3ed 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -151,7 +151,7 @@ def pinner(cls): @pin_whitelisted_properties(Series, base.series_apply_whitelist) -class SeriesGroupBy(GroupBy): +class SeriesGroupBy(GroupBy[Series]): _apply_whitelist = base.series_apply_whitelist def _iterate_slices(self) -> Iterable[Series]: @@ -815,7 +815,7 @@ def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None): @pin_whitelisted_properties(DataFrame, base.dataframe_apply_whitelist) -class DataFrameGroupBy(GroupBy): +class DataFrameGroupBy(GroupBy[DataFrame]): _apply_whitelist = base.dataframe_apply_whitelist @@ -1462,7 +1462,7 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame: for i, _ in enumerate(result.columns): res = algorithms.take_1d(result.iloc[:, i].values, ids) # TODO: we have no test cases that get here with EA dtypes; - # try_cast may not be needed if EAs never get here + # maybe_cast_result may not be needed if EAs never get here if cast: res = maybe_cast_result(res, obj.iloc[:, i], how=func_nm) output.append(res) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index ebdb0062491be..38acef6edb5cb 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -17,6 +17,7 @@ class providing the base-class of operations. Callable, Dict, FrozenSet, + Generic, Hashable, Iterable, List, @@ -24,6 +25,7 @@ class providing the base-class of operations. Optional, Tuple, Type, + TypeVar, Union, ) @@ -353,13 +355,16 @@ def _group_selection_context(groupby): ] -class _GroupBy(PandasObject, SelectionMixin): +_GroupByT = TypeVar("_GroupByT", bound="NDFrame") + + +class _GroupBy(Generic[_GroupByT], PandasObject, SelectionMixin): _group_selection = None _apply_whitelist: FrozenSet[str] = frozenset() def __init__( self, - obj: NDFrame, + obj: _GroupByT, keys: Optional[_KeysArgType] = None, axis: int = 0, level=None, @@ -995,7 +1000,7 @@ def _apply_filter(self, indices, dropna): return filtered -class GroupBy(_GroupBy): +class GroupBy(_GroupBy[_GroupByT]): """ Class for grouping and aggregating relational data. diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 742de397956c0..8d535374a083f 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -682,7 +682,7 @@ def _aggregate_series_pure_python(self, obj: Series, func): assert result is not None result = lib.maybe_convert_objects(result, try_float=0) - # TODO: try_cast back to EA? + # TODO: maybe_cast_to_extension_array? return result, counts From 51ca0dc2195500123892cd7dd0274f76da36ddba Mon Sep 17 00:00:00 2001 From: Richard Date: Sat, 4 Apr 2020 15:31:44 -0400 Subject: [PATCH 2/3] Added noqa for linter --- pandas/core/dtypes/cast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 767b65c885759..6cd0948f9beb8 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -74,7 +74,7 @@ if TYPE_CHECKING: from pandas import Series - from pandas.api.extensions import ExtensionArray + from pandas.core.arrays import ExtensionArray # noqa: F401 _int8_max = np.iinfo(np.int8).max _int16_max = np.iinfo(np.int16).max From 52bd752941575a5f03969084a63868277c7ddaa8 Mon Sep 17 00:00:00 2001 From: Richard Date: Sun, 5 Apr 2020 10:19:31 -0400 Subject: [PATCH 3/3] Reworked generic typing of GroupBy and _GroupBy --- pandas/core/groupby/groupby.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 38acef6edb5cb..4d19f6a79ff87 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -355,16 +355,13 @@ def _group_selection_context(groupby): ] -_GroupByT = TypeVar("_GroupByT", bound="NDFrame") - - -class _GroupBy(Generic[_GroupByT], PandasObject, SelectionMixin): +class _GroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]): _group_selection = None _apply_whitelist: FrozenSet[str] = frozenset() def __init__( self, - obj: _GroupByT, + obj: FrameOrSeries, keys: Optional[_KeysArgType] = None, axis: int = 0, level=None, @@ -1000,7 +997,11 @@ def _apply_filter(self, indices, dropna): return filtered -class GroupBy(_GroupBy[_GroupByT]): +# To track operations that expand dimensions, like ohlc +OutputFrameOrSeries = TypeVar("OutputFrameOrSeries", bound=NDFrame) + + +class GroupBy(_GroupBy[FrameOrSeries]): """ Class for grouping and aggregating relational data. @@ -2425,8 +2426,8 @@ def tail(self, n=5): return self._selected_obj[mask] def _reindex_output( - self, output: FrameOrSeries, fill_value: Scalar = np.NaN - ) -> FrameOrSeries: + self, output: OutputFrameOrSeries, fill_value: Scalar = np.NaN + ) -> OutputFrameOrSeries: """ If we have categorical groupers, then we might want to make sure that we have a fully re-indexed output to the levels. This means expanding