diff --git a/src/tea_tasting/aggr.py b/src/tea_tasting/aggr.py index d03380c..9b10794 100644 --- a/src/tea_tasting/aggr.py +++ b/src/tea_tasting/aggr.py @@ -284,24 +284,24 @@ def read_aggregates( demean_cols = tuple({*var_cols, *itertools.chain(*cov_cols)}) if len(demean_cols) > 0: demean_expr = { - _DEMEAN.format(col): data[col] - data[col].mean() # type: ignore + _DEMEAN.format(col): data[col] - data[col].cast("float").mean() # type: ignore for col in demean_cols } grouped_data = data.group_by(group_col) if group_col is not None else data # type: ignore data = grouped_data.mutate(**demean_expr) # type: ignore count_expr = {_COUNT: data.count()} if has_count else {} - mean_expr = {_MEAN.format(col): data[col].mean() for col in mean_cols} # type: ignore + mean_expr = {_MEAN.format(col): data[col].cast("float").mean() for col in mean_cols} # type: ignore var_expr = { _VAR.format(col): ( data[_DEMEAN.format(col)] * data[_DEMEAN.format(col)] - ).sum().cast("float") / (data.count() - 1) # type: ignore + ).sum() / (data.count() - 1) # type: ignore for col in var_cols } cov_expr = { _COV.format(left, right): ( data[_DEMEAN.format(left)] * data[_DEMEAN.format(right)] - ).sum().cast("float") / (data.count() - 1) # type: ignore + ).sum() / (data.count() - 1) # type: ignore for left, right in cov_cols }