Skip to content

Commit

Permalink
fix_base_model
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint committed Nov 19, 2024
1 parent c529ced commit b2c7691
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
" optimizer_kwargs: Union[Dict, None] = None,\n",
" lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None] = None,\n",
" lr_scheduler_kwargs: Union[Dict, None] = None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs,\n",
" ):\n",
" super().__init__()\n",
Expand Down Expand Up @@ -364,6 +365,7 @@
"\n",
" # DataModule arguments\n",
" self.num_workers_loader = num_workers_loader\n",
" self.dataloader_kwargs = dataloader_kwargs\n",
" self.drop_last_loader = drop_last_loader\n",
" # used by on_validation_epoch_end hook\n",
" self.validation_step_outputs: List = []\n",
Expand Down
2 changes: 2 additions & 0 deletions neuralforecast/common/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
optimizer_kwargs: Union[Dict, None] = None,
lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None] = None,
lr_scheduler_kwargs: Union[Dict, None] = None,
dataloader_kwargs=None,
**trainer_kwargs,
):
super().__init__()
Expand Down Expand Up @@ -352,6 +353,7 @@ def __init__(

# DataModule arguments
self.num_workers_loader = num_workers_loader
self.dataloader_kwargs = dataloader_kwargs
self.drop_last_loader = drop_last_loader
# used by on_validation_epoch_end hook
self.validation_step_outputs: List = []
Expand Down

0 comments on commit b2c7691

Please sign in to comment.