Skip to content

Commit

Permalink
fix covid support
Browse files Browse the repository at this point in the history
  • Loading branch information
elray1 committed Nov 20, 2024
1 parent 872f956 commit 8323af1
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/idmodels/gbqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import lightgbm as lgb
import numpy as np
import pandas as pd
from iddata.loader import FluDataLoader
from iddata.loader import DiseaseDataLoader
from tqdm.autonotebook import tqdm

from idmodels.preprocess import create_features_and_targets
Expand Down Expand Up @@ -31,7 +31,7 @@ def run(self, run_config):
ilinet_kwargs = {"scale_to_positive": False}
flusurvnet_kwargs = {"burden_adj": False}

fdl = FluDataLoader()
fdl = DiseaseDataLoader()
df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date, disease=run_config.disease},
ilinet_kwargs=ilinet_kwargs,
flusurvnet_kwargs=flusurvnet_kwargs,
Expand All @@ -41,14 +41,20 @@ def run(self, run_config):
df = df.loc[df["location"].isin(run_config.locations)]

# augment data with features and target values
if run_config.disease == "flu":
init_feats = ["inc_trans_cs", "season_week", "log_pop"]
elif run_config.disease == "covid":
init_feats = ["inc_trans_cs", "log_pop"]

df, feat_names = create_features_and_targets(
df = df,
incl_level_feats=self.model_config.incl_level_feats,
max_horizon=run_config.max_horizon,
curr_feat_names=["inc_trans_cs", "season_week", "log_pop"])
curr_feat_names=init_feats)

# keep only rows that are in-season
df = df.query("season_week >= 5 and season_week <= 45")
if run_config.disease == "flu":
df = df.query("season_week >= 5 and season_week <= 45")

# "test set" df used to generate look-ahead predictions
df_test = df.loc[df.wk_end_date == df.wk_end_date.max()] \
Expand Down

0 comments on commit 8323af1

Please sign in to comment.