Skip to content

Commit

Permalink
ENH: Add include_groups parameter into groupby.apply (#841)
Browse files Browse the repository at this point in the history
  • Loading branch information
hainaweiben authored Jan 13, 2025
1 parent 081d5f9 commit d458b8b
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions python/xorbits/_mars/dataframe/groupby/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class GroupByApply(
maybe_agg = BoolField("maybe_agg", default=None)
logic_key = StringField("logic_key", default=None)
func_key = AnyField("func_key", default=None)
include_groups = BoolField("include_groups", default=None)

def __init__(self, output_types=None, **kw):
super().__init__(_output_types=output_types, **kw)
Expand Down Expand Up @@ -96,7 +97,10 @@ def execute(cls, ctx, op):
# cudf's groupby apply does not accept keyword arguments.
applied = in_data.apply(func, *op.args)
else:
applied = in_data.apply(func, *op.args, **op.kwds)
kwargs = op.kwds.copy()
if op.include_groups is not None:
kwargs["include_groups"] = op.include_groups
applied = in_data.apply(func, *op.args, **kwargs)

if isinstance(applied, pd.DataFrame):
# when there is only one group, pandas tends to return a DataFrame, while
Expand Down Expand Up @@ -289,6 +293,7 @@ def groupby_apply(
name=None,
index=None,
skip_infer=None,
include_groups=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -331,6 +336,11 @@ def groupby_apply(
skip_infer: bool, default False
Whether to skip inferring dtypes when dtypes or output_type is not specified.
include_groups: bool, default None
Whether to include the grouping columns in the data passed to `func`.
If None, the argument is not forwarded and the pandas default applies:
currently True for backwards compatibility, changing to False in a
future version.
args, kwargs : tuple and dict
Optional positional and keyword arguments to pass to `func`.
Expand Down Expand Up @@ -370,5 +380,11 @@ def groupby_apply(

dtypes = make_dtypes(dtypes)
dtype = make_dtype(dtype)
op = GroupByApply(func=func, args=args, kwds=kwargs, output_types=output_types)
op = GroupByApply(
func=func,
args=args,
kwds=kwargs,
output_types=output_types,
include_groups=include_groups,
)
return op(groupby, dtypes=dtypes, dtype=dtype, name=name, index=index)

0 comments on commit d458b8b

Please sign in to comment.