Skip to content

Commit

Permalink
Merge pull request #93 from e10v/dev
Browse files Browse the repository at this point in the history
Cast to float before aggregation
  • Loading branch information
e10v authored Sep 15, 2024
2 parents 27e9f31 + f66ca6b commit 39803e7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/tea_tasting/aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,24 +284,24 @@ def read_aggregates(
demean_cols = tuple({*var_cols, *itertools.chain(*cov_cols)})
if len(demean_cols) > 0:
demean_expr = {
_DEMEAN.format(col): data[col] - data[col].mean() # type: ignore
_DEMEAN.format(col): data[col] - data[col].cast("float").mean() # type: ignore
for col in demean_cols
}
grouped_data = data.group_by(group_col) if group_col is not None else data # type: ignore
data = grouped_data.mutate(**demean_expr) # type: ignore

count_expr = {_COUNT: data.count()} if has_count else {}
mean_expr = {_MEAN.format(col): data[col].mean() for col in mean_cols} # type: ignore
mean_expr = {_MEAN.format(col): data[col].cast("float").mean() for col in mean_cols} # type: ignore
var_expr = {
_VAR.format(col): (
data[_DEMEAN.format(col)] * data[_DEMEAN.format(col)]
).sum().cast("float") / (data.count() - 1) # type: ignore
).sum() / (data.count() - 1) # type: ignore
for col in var_cols
}
cov_expr = {
_COV.format(left, right): (
data[_DEMEAN.format(left)] * data[_DEMEAN.format(right)]
).sum().cast("float") / (data.count() - 1) # type: ignore
).sum() / (data.count() - 1) # type: ignore
for left, right in cov_cols
}

Expand Down

0 comments on commit 39803e7

Please sign in to comment.