diff --git a/nbs/models.gru.ipynb b/nbs/models.gru.ipynb
index 7cb14f21..7f0608a5 100644
--- a/nbs/models.gru.ipynb
+++ b/nbs/models.gru.ipynb
@@ -1,5 +1,14 @@
{
"cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%set_env PYTORCH_ENABLE_MPS_FALLBACK=1"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -70,6 +79,7 @@
"outputs": [],
"source": [
"#| export\n",
+ "import warnings\n",
"from typing import Optional\n",
"\n",
"import torch\n",
@@ -91,7 +101,7 @@
" \"\"\" GRU\n",
"\n",
" Multi Layer Recurrent Network with Gated Units (GRU), and\n",
- " MLP decoder. The network has `tanh` or `relu` non-linearities, it is trained \n",
+ " MLP decoder. The network has non-linear activation functions, it is trained \n",
" using ADAM stochastic gradient descent. The network accepts static, historic \n",
" and future exogenous data, flattens the inputs.\n",
"\n",
@@ -101,7 +111,7 @@
" `inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.
\n",
" `encoder_n_layers`: int=2, number of layers for the GRU.
\n",
" `encoder_hidden_size`: int=200, units for the GRU's hidden state size.
\n",
- " `encoder_activation`: str=`tanh`, type of GRU activation from `tanh` or `relu`.
\n",
+ " `encoder_activation`: Optional[str]=None, Deprecated. Activation function in GRU is frozen in PyTorch.
\n",
" `encoder_bias`: bool=True, whether or not to use biases b_ih, b_hh within GRU units.
\n",
" `encoder_dropout`: float=0., dropout regularization applied to GRU outputs.
\n",
" `context_size`: int=10, size of context vector for each timestamp on the forecasting window.
\n",
@@ -143,7 +153,7 @@
" inference_input_size: int = -1,\n",
" encoder_n_layers: int = 2,\n",
" encoder_hidden_size: int = 200,\n",
- " encoder_activation: str = 'tanh',\n",
+ " encoder_activation: Optional[str] = None,\n",
" encoder_bias: bool = True,\n",
" encoder_dropout: float = 0.,\n",
" context_size: int = 10,\n",
@@ -199,6 +209,14 @@
" **trainer_kwargs\n",
" )\n",
"\n",
+ " if encoder_activation is not None:\n",
+ " warnings.warn(\n",
+ " \"The 'encoder_activation' argument is deprecated and will be removed in \"\n",
+ " \"future versions. The activation function in GRU is frozen in PyTorch and \"\n",
+ " \"it cannot be modified.\",\n",
+ " DeprecationWarning,\n",
+ " )\n",
+ "\n",
" # RNN\n",
" self.encoder_n_layers = encoder_n_layers\n",
" self.encoder_hidden_size = encoder_hidden_size\n",
@@ -322,7 +340,7 @@
"import matplotlib.pyplot as plt\n",
"\n",
"from neuralforecast import NeuralForecast\n",
- "from neuralforecast.models import GRU\n",
+ "# from neuralforecast.models import GRU\n",
"from neuralforecast.losses.pytorch import DistributionLoss\n",
"from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
"\n",
diff --git a/neuralforecast/models/gru.py b/neuralforecast/models/gru.py
index 900eac16..9a6d9232 100644
--- a/neuralforecast/models/gru.py
+++ b/neuralforecast/models/gru.py
@@ -3,7 +3,8 @@
# %% auto 0
__all__ = ['GRU']
-# %% ../../nbs/models.gru.ipynb 6
+# %% ../../nbs/models.gru.ipynb 7
+import warnings
from typing import Optional
import torch
@@ -13,12 +14,12 @@
from ..common._base_recurrent import BaseRecurrent
from ..common._modules import MLP
-# %% ../../nbs/models.gru.ipynb 7
+# %% ../../nbs/models.gru.ipynb 8
class GRU(BaseRecurrent):
"""GRU
Multi Layer Recurrent Network with Gated Units (GRU), and
- MLP decoder. The network has `tanh` or `relu` non-linearities, it is trained
+ MLP decoder. The network has non-linear activation functions, it is trained
using ADAM stochastic gradient descent. The network accepts static, historic
and future exogenous data, flattens the inputs.
@@ -28,7 +29,7 @@ class GRU(BaseRecurrent):
`inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.
`encoder_n_layers`: int=2, number of layers for the GRU.
`encoder_hidden_size`: int=200, units for the GRU's hidden state size.
- `encoder_activation`: str=`tanh`, type of GRU activation from `tanh` or `relu`.
+ `encoder_activation`: Optional[str]=None, Deprecated. Activation function in GRU is frozen in PyTorch.
`encoder_bias`: bool=True, whether or not to use biases b_ih, b_hh within GRU units.
`encoder_dropout`: float=0., dropout regularization applied to GRU outputs.
`context_size`: int=10, size of context vector for each timestamp on the forecasting window.
@@ -72,7 +73,7 @@ def __init__(
inference_input_size: int = -1,
encoder_n_layers: int = 2,
encoder_hidden_size: int = 200,
- encoder_activation: str = "tanh",
+ encoder_activation: Optional[str] = None,
encoder_bias: bool = True,
encoder_dropout: float = 0.0,
context_size: int = 10,
@@ -129,6 +130,14 @@ def __init__(
**trainer_kwargs
)
+ if encoder_activation is not None:
+ warnings.warn(
+ "The 'encoder_activation' argument is deprecated and will be removed in "
+ "future versions. The activation function in GRU is frozen in PyTorch and "
+ "it cannot be modified.",
+ DeprecationWarning,
+ )
+
# RNN
self.encoder_n_layers = encoder_n_layers
self.encoder_hidden_size = encoder_hidden_size