-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
173 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
"""Helpers for submission creation.""" | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
def create_submission_frame(
    forecast_date: str | None = None,
    bikes_preds=None,
    energy_preds=None,
    no2_preds=None,
):
    """Assemble the submission dataframe for the bikes, energy and no2 targets.

    Parameters
    ----------
    forecast_date : str or None
        Forecast date in the format "YYYY-MM-DD".
        If None, the current date from ``pd.Timestamp.now()`` is used.
    bikes_preds : array-like or None
        Predictions for bikes. Must have shape (6, 5).
        If None, it will be filled with NaNs.
    energy_preds : array-like or None
        Predictions for energy. Must have shape (6, 5).
        If None, it will be filled with NaNs.
    no2_preds : array-like or None
        Predictions for no2. Must have shape (6, 5).
        If None, it will be filled with NaNs.

    Returns
    -------
    pandas.DataFrame
        One row per (target, horizon) pair, 18 rows in total, with the
        columns ``forecast_date``, ``target``, ``horizon`` followed by the
        five quantile columns.
    """
    if forecast_date is None:
        # Fall back to today's date.
        forecast_date = pd.Timestamp.now().strftime("%Y-%m-%d")

    quantile_cols = ["q0.025", "q0.25", "q0.5", "q0.75", "q0.975"]

    # Horizons per target, in the order the rows appear in the submission.
    horizons_by_target = {
        "bikes": ["1 day", "2 day", "3 day", "4 day", "5 day", "6 day"],
        "energy": ["36 hour", "40 hour", "44 hour", "60 hour", "64 hour", "68 hour"],
        "no2": ["36 hour", "40 hour", "44 hour", "60 hour", "64 hour", "68 hour"],
    }

    # Every target is validated against the same (horizons x quantiles) shape;
    # missing predictions are replaced by a NaN block of that shape.
    expected_rows = len(horizons_by_target["bikes"])
    expected_cols = len(quantile_cols)
    preds_by_target = {
        "bikes": check_predictions(
            bikes_preds, n_rows=expected_rows, n_cols=expected_cols
        ),
        "energy": check_predictions(
            energy_preds, n_rows=expected_rows, n_cols=expected_cols
        ),
        "no2": check_predictions(
            no2_preds, n_rows=expected_rows, n_cols=expected_cols
        ),
    }

    # Identifier columns first: one small frame per target, stacked vertically.
    id_frames = [
        pd.DataFrame(
            {
                "forecast_date": forecast_date,
                "target": target,
                "horizon": horizons,
            }
        )
        for target, horizons in horizons_by_target.items()
    ]
    submission = pd.concat(id_frames, ignore_index=True)

    # Write each target's quantile predictions into its own rows.
    for target, preds in preds_by_target.items():
        submission.loc[submission["target"] == target, quantile_cols] = preds

    return submission
|
||
|
||
def check_predictions(preds, n_rows, n_cols):
    """Validate a block of predictions, or build a NaN placeholder for it.

    Parameters
    ----------
    preds : array-like or None
        Predictions to check. Anything ``np.asarray`` can convert
        (NumPy array, nested list, ...) is accepted.
    n_rows : int
        Number of rows for the predictions.
    n_cols : int
        Number of columns for the predictions.

    Returns
    -------
    numpy.ndarray
        The predictions as a 2D NumPy array of shape ``(n_rows, n_cols)``,
        or an array of NaNs of that shape if ``preds`` is None.
        A NumPy array passed in is returned as-is (no copy).

    Raises
    ------
    AssertionError
        If the predictions do not have shape ``(n_rows, n_cols)``.
    """
    expected_shape = (n_rows, n_cols)
    if preds is None:
        return np.full(expected_shape, np.nan)

    # Accept the documented "array-like" inputs (e.g. nested lists), which
    # have no ``.shape`` attribute; this is a no-op for existing ndarrays.
    preds = np.asarray(preds)

    if preds.shape != expected_shape:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # kept so existing callers see the same failure.
        raise AssertionError("Invalid shape for predictions.")
    return preds
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""Tests for create_submission module.""" | ||
|
||
from io import StringIO | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from probafcst.utils.create_submission import create_submission_frame | ||
|
||
|
||
def test_create_submission_frame():
    """Compare create_submission_frame output against a CSV fixture.

    Bikes and energy predictions are supplied explicitly; no2 is passed as
    None, so its rows are expected to contain missing values only.
    """
    # Inputs: one row per horizon, one column per quantile.
    bikes = np.array(
        [
            [2496, 4676, 5552, 6444, 7408],
            [2107, 4596, 5532, 6294, 7658],
            [1992, 3357, 4068, 4826, 5828],
            [873, 1568, 2178, 2528, 3335],
            [1886, 4614, 5376, 6193, 7222],
            [2270, 4786, 5678, 6274, 7463],
        ]
    )
    energy = np.array(
        [
            [52887, 59429, 61920, 65836, 71181],
            [48102, 54604, 57213, 62027, 68767],
            [49212, 51915, 55298, 59425, 62749],
            [47175, 49336, 51944, 56183, 60777],
            [42864, 45934, 47995, 53107, 59378],
            [43070, 45698, 48347, 52965, 56667],
        ]
    )

    result = create_submission_frame(
        forecast_date="2024-10-23",
        bikes_preds=bikes,
        energy_preds=energy,
        no2_preds=None,
    )

    # Expected submission, written as CSV text and parsed by pandas.
    expected_csv = """
forecast_date,target,horizon,q0.025,q0.25,q0.5,q0.75,q0.975
2024-10-23,bikes,1 day,2496,4676,5552,6444,7408
2024-10-23,bikes,2 day,2107,4596,5532,6294,7658
2024-10-23,bikes,3 day,1992,3357,4068,4826,5828
2024-10-23,bikes,4 day, 873,1568,2178,2528,3335
2024-10-23,bikes,5 day,1886,4614,5376,6193,7222
2024-10-23,bikes,6 day,2270,4786,5678,6274,7463
2024-10-23,energy,36 hour,52887,59429,61920,65836,71181
2024-10-23,energy,40 hour,48102,54604,57213,62027,68767
2024-10-23,energy,44 hour,49212,51915,55298,59425,62749
2024-10-23,energy,60 hour,47175,49336,51944,56183,60777
2024-10-23,energy,64 hour,42864,45934,47995,53107,59378
2024-10-23,energy,68 hour,43070,45698,48347,52965,56667
2024-10-23,no2,36 hour,NA,NA,NA,NA,NA
2024-10-23,no2,40 hour,NA,NA,NA,NA,NA
2024-10-23,no2,44 hour,NA,NA,NA,NA,NA
2024-10-23,no2,60 hour,NA,NA,NA,NA,NA
2024-10-23,no2,64 hour,NA,NA,NA,NA,NA
2024-10-23,no2,68 hour,NA,NA,NA,NA,NA
"""
    expected = pd.read_csv(StringIO(expected_csv), sep=",")

    pd.testing.assert_frame_equal(result, expected)
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.