From 170f9734b70532e3bb89be0242d75257b06392e4 Mon Sep 17 00:00:00 2001 From: Olivier Sprangers <45119856+elephaint@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:07:48 +0200 Subject: [PATCH 1/2] [FIX] MLPMultivariate incorrect static_exog parsing (#1170) --- nbs/models.mlpmultivariate.ipynb | 8 +++++++- neuralforecast/models/mlpmultivariate.py | 6 +++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/nbs/models.mlpmultivariate.ipynb b/nbs/models.mlpmultivariate.ipynb index abfefc1e6..ad1a08f47 100644 --- a/nbs/models.mlpmultivariate.ipynb +++ b/nbs/models.mlpmultivariate.ipynb @@ -210,7 +210,11 @@ " x = torch.cat(( x, futr_exog.reshape(batch_size, -1) ), dim=1)\n", "\n", " if self.stat_exog_size > 0:\n", - " x = torch.cat(( x, stat_exog.reshape(batch_size, -1) ), dim=1)\n", + " stat_exog = stat_exog.reshape(-1) # [N, S] -> [N * S]\n", + " stat_exog = stat_exog.unsqueeze(0)\\\n", + " .repeat(batch_size, \n", + " 1) # [N * S] -> [B, N * S] \n", + " x = torch.cat((x, stat_exog), dim=1)\n", "\n", " for layer in self.mlp:\n", " x = torch.relu(layer(x))\n", @@ -362,6 +366,8 @@ "model = MLPMultivariate(h=12, \n", " input_size=24,\n", " n_series=2,\n", + " stat_exog_list=['airline1'],\n", + " futr_exog_list=['trend'], \n", " loss = MAE(),\n", " scaler_type='robust',\n", " learning_rate=1e-3,\n", diff --git a/neuralforecast/models/mlpmultivariate.py b/neuralforecast/models/mlpmultivariate.py index 19cb15eea..4682174e0 100644 --- a/neuralforecast/models/mlpmultivariate.py +++ b/neuralforecast/models/mlpmultivariate.py @@ -160,7 +160,11 @@ def forward(self, windows_batch): x = torch.cat((x, futr_exog.reshape(batch_size, -1)), dim=1) if self.stat_exog_size > 0: - x = torch.cat((x, stat_exog.reshape(batch_size, -1)), dim=1) + stat_exog = stat_exog.reshape(-1) # [N, S] -> [N * S] + stat_exog = stat_exog.unsqueeze(0).repeat( + batch_size, 1 + ) # [N * S] -> [B, N * S] + x = torch.cat((x, stat_exog), dim=1) for layer in self.mlp: x = torch.relu(layer(x)) From 137e3a88220a84ec5f7cce04429be68c5e62bbe0 Mon Sep 17 00:00:00 2001 From: AzulGarza Date: Wed, 9 Oct 2024 23:57:48 -0700 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=92=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .all-contributorsrc | 6 +++--- README.md | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index 75d075f0f..b0c0d1d76 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -6,10 +6,10 @@ "commit": false, "contributors": [ { - "login": "FedericoGarza", - "name": "fede", + "login": "AzulGarza", + "name": "azul", "avatar_url": "https://avatars.githubusercontent.com/u/10517170?v=4", - "profile": "https://github.com/FedericoGarza", + "profile": "https://github.com/AzulGarza", "contributions": [ "code", "maintenance" diff --git a/README.md b/README.md index 6039ea086..d132c8a1b 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d - +
Azul
fede

💻 🚧
azul
azul

💻 🚧
Cristian Challu
Cristian Challu

💻 🚧
José Morales
José Morales

💻 🚧
mergenthaler
mergenthaler

📖 💻