feat: Deprecate activation functions for GRU (#1198)
marcopeix authored Nov 21, 2024
1 parent 642ced4 commit 15f061f
Showing 2 changed files with 36 additions and 9 deletions.
26 changes: 22 additions & 4 deletions nbs/models.gru.ipynb
@@ -1,5 +1,14 @@
{
"cells": [
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"%set_env PYTORCH_ENABLE_MPS_FALLBACK=1"
+]
+},
{
"cell_type": "code",
"execution_count": null,
@@ -70,6 +79,7 @@
"outputs": [],
"source": [
"#| export\n",
"import warnings\n",
"from typing import Optional\n",
"\n",
"import torch\n",
@@ -91,7 +101,7 @@
" \"\"\" GRU\n",
"\n",
" Multi Layer Recurrent Network with Gated Units (GRU), and\n",
" MLP decoder. The network has `tanh` or `relu` non-linearities, it is trained \n",
" MLP decoder. The network has non-linear activation functions, it is trained \n",
" using ADAM stochastic gradient descent. The network accepts static, historic \n",
" and future exogenous data, flattens the inputs.\n",
"\n",
@@ -101,7 +111,7 @@
" `inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.<br>\n",
" `encoder_n_layers`: int=2, number of layers for the GRU.<br>\n",
" `encoder_hidden_size`: int=200, units for the GRU's hidden state size.<br>\n",
" `encoder_activation`: str=`tanh`, type of GRU activation from `tanh` or `relu`.<br>\n",
" `encoder_activation`: Optional[str]=None, Deprecated. Activation function in GRU is frozen in PyTorch.<br>\n",
" `encoder_bias`: bool=True, whether or not to use biases b_ih, b_hh within GRU units.<br>\n",
" `encoder_dropout`: float=0., dropout regularization applied to GRU outputs.<br>\n",
" `context_size`: int=10, size of context vector for each timestamp on the forecasting window.<br>\n",
@@ -143,7 +153,7 @@
" inference_input_size: int = -1,\n",
" encoder_n_layers: int = 2,\n",
" encoder_hidden_size: int = 200,\n",
" encoder_activation: str = 'tanh',\n",
" encoder_activation: Optional[str] = None,\n",
" encoder_bias: bool = True,\n",
" encoder_dropout: float = 0.,\n",
" context_size: int = 10,\n",
@@ -199,6 +209,14 @@
" **trainer_kwargs\n",
" )\n",
"\n",
" if encoder_activation is not None:\n",
" warnings.warn(\n",
" \"The 'encoder_activation' argument is deprecated and will be removed in \"\n",
" \"future versions. The activation function in GRU is frozen in PyTorch and \"\n",
" \"it cannot be modified.\",\n",
" DeprecationWarning,\n",
" )\n",
"\n",
" # RNN\n",
" self.encoder_n_layers = encoder_n_layers\n",
" self.encoder_hidden_size = encoder_hidden_size\n",
@@ -322,7 +340,7 @@
"import matplotlib.pyplot as plt\n",
"\n",
"from neuralforecast import NeuralForecast\n",
"from neuralforecast.models import GRU\n",
"# from neuralforecast.models import GRU\n",
"from neuralforecast.losses.pytorch import DistributionLoss\n",
"from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
"\n",
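
For context on why the argument is deprecated: unlike `torch.nn.RNN`, PyTorch's `torch.nn.GRU` takes no `nonlinearity` argument. Its gate activations (sigmoid for the update and reset gates, tanh for the candidate state) are fixed by the GRU definition, so there is nothing for `encoder_activation` to configure. A minimal sketch of the difference (layer sizes are arbitrary):

    import torch

    # torch.nn.RNN lets the caller pick the recurrent non-linearity...
    rnn = torch.nn.RNN(input_size=8, hidden_size=16, nonlinearity="relu")

    # ...but torch.nn.GRU accepts no such argument: its sigmoid/tanh gates are fixed.
    gru = torch.nn.GRU(input_size=8, hidden_size=16)

    x = torch.randn(5, 3, 8)  # (seq_len, batch, input_size)
    out, h = gru(x)           # works; there is no activation choice to make
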
19 changes: 14 additions & 5 deletions neuralforecast/models/gru.py
@@ -3,7 +3,8 @@
# %% auto 0
__all__ = ['GRU']

-# %% ../../nbs/models.gru.ipynb 6
+# %% ../../nbs/models.gru.ipynb 7
+import warnings
from typing import Optional

import torch
@@ -13,12 +14,12 @@
from ..common._base_recurrent import BaseRecurrent
from ..common._modules import MLP

-# %% ../../nbs/models.gru.ipynb 7
+# %% ../../nbs/models.gru.ipynb 8
class GRU(BaseRecurrent):
"""GRU
Multi Layer Recurrent Network with Gated Units (GRU), and
-MLP decoder. The network has `tanh` or `relu` non-linearities, it is trained
+MLP decoder. The network has non-linear activation functions, it is trained
using ADAM stochastic gradient descent. The network accepts static, historic
and future exogenous data, flattens the inputs.
@@ -28,7 +29,7 @@ class GRU(BaseRecurrent):
`inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.<br>
`encoder_n_layers`: int=2, number of layers for the GRU.<br>
`encoder_hidden_size`: int=200, units for the GRU's hidden state size.<br>
-`encoder_activation`: str=`tanh`, type of GRU activation from `tanh` or `relu`.<br>
+`encoder_activation`: Optional[str]=None, Deprecated. Activation function in GRU is frozen in PyTorch.<br>
`encoder_bias`: bool=True, whether or not to use biases b_ih, b_hh within GRU units.<br>
`encoder_dropout`: float=0., dropout regularization applied to GRU outputs.<br>
`context_size`: int=10, size of context vector for each timestamp on the forecasting window.<br>
@@ -72,7 +73,7 @@ def __init__(
inference_input_size: int = -1,
encoder_n_layers: int = 2,
encoder_hidden_size: int = 200,
encoder_activation: str = "tanh",
encoder_activation: Optional[str] = None,
encoder_bias: bool = True,
encoder_dropout: float = 0.0,
context_size: int = 10,
@@ -129,6 +130,14 @@ def __init__(
**trainer_kwargs
)

+if encoder_activation is not None:
+    warnings.warn(
+        "The 'encoder_activation' argument is deprecated and will be removed in "
+        "future versions. The activation function in GRU is frozen in PyTorch and "
+        "it cannot be modified.",
+        DeprecationWarning,
+    )

# RNN
self.encoder_n_layers = encoder_n_layers
self.encoder_hidden_size = encoder_hidden_size
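
With this change, passing any value for `encoder_activation` still constructs the model but emits a `DeprecationWarning`; omitting it (the new default of None) stays silent. A minimal sketch of the new behavior, assuming the usual required arguments (the `h` and `input_size` values here are illustrative):

    import warnings
    from neuralforecast.models import GRU

    # Deprecated argument: still accepted, but now triggers a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model = GRU(h=12, input_size=24, encoder_activation="tanh")
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # The new default (None) emits no warning.
    model = GRU(h=12, input_size=24)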
