Skip to content

Commit

Permalink
Merge branch 'main' into 1169-tft-tftfeature_importances-returns-erro…
Browse files Browse the repository at this point in the history
…r-when-there-is-no-futr_exog_list
  • Loading branch information
elephaint authored Oct 10, 2024
2 parents 6d749b6 + 137e3a8 commit e2a56e3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
6 changes: 3 additions & 3 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
"commit": false,
"contributors": [
{
"login": "FedericoGarza",
"name": "fede",
"login": "AzulGarza",
"name": "azul",
"avatar_url": "https://avatars.githubusercontent.com/u/10517170?v=4",
"profile": "https://github.com/FedericoGarza",
"profile": "https://github.com/AzulGarza",
"contributions": [
"code",
"maintenance"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<table>
<tbody>
<tr>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/AzulGarza"><img src="https://avatars.githubusercontent.com/u/10517170?v=4?s=100" width="100px;" alt="Azul"/><br /><sub><b>fede</b></sub></a><br /><a href="https://github.com/Nixtla/neuralforecast/commits?author=AzulGarza" title="Code">💻</a> <a href="#maintenance-AzulGarza" title="Maintenance">🚧</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/AzulGarza"><img src="https://avatars.githubusercontent.com/u/10517170?v=4?s=100" width="100px;" alt="azul"/><br /><sub><b>azul</b></sub></a><br /><a href="https://github.com/Nixtla/neuralforecast/commits?author=AzulGarza" title="Code">💻</a> <a href="#maintenance-AzulGarza" title="Maintenance">🚧</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/cchallu"><img src="https://avatars.githubusercontent.com/u/31133398?v=4?s=100" width="100px;" alt="Cristian Challu"/><br /><sub><b>Cristian Challu</b></sub></a><br /><a href="https://github.com/Nixtla/neuralforecast/commits?author=cchallu" title="Code">💻</a> <a href="#maintenance-cchallu" title="Maintenance">🚧</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/jmoralez"><img src="https://avatars.githubusercontent.com/u/8473587?v=4?s=100" width="100px;" alt="José Morales"/><br /><sub><b>José Morales</b></sub></a><br /><a href="https://github.com/Nixtla/neuralforecast/commits?author=jmoralez" title="Code">💻</a> <a href="#maintenance-jmoralez" title="Maintenance">🚧</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/mergenthaler"><img src="https://avatars.githubusercontent.com/u/4086186?v=4?s=100" width="100px;" alt="mergenthaler"/><br /><sub><b>mergenthaler</b></sub></a><br /><a href="https://github.com/Nixtla/neuralforecast/commits?author=mergenthaler" title="Documentation">📖</a> <a href="https://github.com/Nixtla/neuralforecast/commits?author=mergenthaler" title="Code">💻</a></td>
Expand Down
8 changes: 7 additions & 1 deletion nbs/models.mlpmultivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,11 @@
" x = torch.cat(( x, futr_exog.reshape(batch_size, -1) ), dim=1)\n",
"\n",
" if self.stat_exog_size > 0:\n",
" x = torch.cat(( x, stat_exog.reshape(batch_size, -1) ), dim=1)\n",
" stat_exog = stat_exog.reshape(-1) # [N, S] -> [N * S]\n",
" stat_exog = stat_exog.unsqueeze(0)\\\n",
" .repeat(batch_size, \n",
" 1) # [N * S] -> [B, N * S] \n",
" x = torch.cat((x, stat_exog), dim=1)\n",
"\n",
" for layer in self.mlp:\n",
" x = torch.relu(layer(x))\n",
Expand Down Expand Up @@ -362,6 +366,8 @@
"model = MLPMultivariate(h=12, \n",
" input_size=24,\n",
" n_series=2,\n",
" stat_exog_list=['airline1'],\n",
" futr_exog_list=['trend'], \n",
" loss = MAE(),\n",
" scaler_type='robust',\n",
" learning_rate=1e-3,\n",
Expand Down
6 changes: 5 additions & 1 deletion neuralforecast/models/mlpmultivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ def forward(self, windows_batch):
x = torch.cat((x, futr_exog.reshape(batch_size, -1)), dim=1)

if self.stat_exog_size > 0:
x = torch.cat((x, stat_exog.reshape(batch_size, -1)), dim=1)
stat_exog = stat_exog.reshape(-1) # [N, S] -> [N * S]
stat_exog = stat_exog.unsqueeze(0).repeat(
batch_size, 1
) # [N * S] -> [B, N * S]
x = torch.cat((x, stat_exog), dim=1)

for layer in self.mlp:
x = torch.relu(layer(x))
Expand Down

0 comments on commit e2a56e3

Please sign in to comment.