Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tests for V1.2.0 #8

Open
wants to merge 3 commits into
base: v1.1.2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion toraniko/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def winsorize_group(group: pl.DataFrame) -> pl.DataFrame:
return group

try:
return df.lazy().group_by(group_col).map_groups(winsorize_group, schema=df.collect_schema())
result = df.lazy().group_by(group_col).map_groups(winsorize_group, schema=df.collect_schema())
return result if isinstance(df, pl.LazyFrame) else result.collect()
except AttributeError as e:
raise TypeError(
"`df` must be a Polars DataFrame or LazyFrame, but it's missing `group_by`, `map_groups` "
Expand Down
11 changes: 9 additions & 2 deletions toraniko/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,15 @@ def estimate_factor_returns(
returns_df = (
returns_df.lazy()
.join(mkt_cap_df.lazy(), on=[date_col, symbol_col])
.join(sector_df.lazy(), on=symbol_col)
.join(sector_df.lazy(), on=[date_col, symbol_col])
.join(style_df.lazy(), on=[date_col, symbol_col])
).collect()
# split the conditional winsorization branch into two functions, so we don't have a conditional
# needlessly evaluated on each time period's iteration in the `.map_groups`
if winsor_factor is not None:

def _estimate_factor_returns(data):
r = winsorize(data[asset_returns_col].to_numpy())
r = winsorize(data[asset_returns_col].to_numpy(), percentile=winsor_factor)
fac, eps = factor_returns_cs(
r,
data[mkt_cap_col].to_numpy(),
Expand Down Expand Up @@ -224,3 +224,10 @@ def _estimate_factor_returns(data):
f"`returns_df` must have columns '{date_col}', '{symbol_col}' and '{asset_returns_col}'; "
f"`mkt_cap_df` must have '{date_col}', '{symbol_col}' and '{mkt_cap_col}' columns"
) from e
except BaseException as e:
# This exception is not handled in Polars https://github.com/pola-rs/polars/issues/7704
# Exception raised in groupby.apply(UDF) causes panic
raise ValueError(
f"`returns_df` must have columns '{date_col}', '{symbol_col}' and '{asset_returns_col}'; "
f"`mkt_cap_df` must have '{date_col}', '{symbol_col}' and '{mkt_cap_col}' columns"
) from e
26 changes: 13 additions & 13 deletions toraniko/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,22 @@ def test_reproducibility(sample_data):
def model_sample_data():
dates = ["2021-01-01", "2021-01-02", "2021-01-03"]
symbols = ["AAPL", "MSFT", "GOOGL"]
returns_data = {"date": dates * 3, "symbol": symbols * 3, "asset_returns": np.random.randn(9)}
mkt_cap_data = {"date": dates * 3, "symbol": symbols * 3, "market_cap": np.random.rand(9) * 1000}
date_symbol_pairs = [(date, symbol) for date in dates for symbol in symbols]
dates_list, symbols_list = zip(*date_symbol_pairs)
returns_data = {"date": dates_list, "symbol": symbols_list, "asset_returns": np.random.randn(9)}
mkt_cap_data = {"date": dates_list, "symbol": symbols_list, "market_cap": np.random.rand(9) * 1000}
sector_data = {
"date": dates * 3,
"symbol": symbols * 3,
"date": dates_list,
"symbol": symbols_list,
"Tech": [1, 0, 0] * 3,
"Finance": [0, 1, 0] * 3,
"Consumer": [0, 0, 1] * 3,
}
style_data = {
"date": dates * 3,
"symbol": symbols * 3,
"Value": [1.2, 1.3, 1.4] * 3,
"Growth": [0.8, 0.9, 1.0] * 3,
"Size": [1.6, 2.5, -0.6] * 3,
}
"Consumer": [0, 0, 1] * 3}
style_data = {"date": dates_list,
"symbol": symbols_list,
"Value": [1.2, 1.3, 1.4] * 3,
"Growth": [0.8, 0.9, 1.0] * 3,
"Size": [1.6, 2.5, -0.6] * 3}

return (pl.DataFrame(returns_data), pl.DataFrame(mkt_cap_data), pl.DataFrame(sector_data), pl.DataFrame(style_data))


Expand Down