Skip to content

Commit

Permalink
Merge pull request #98 from ibm-granite/ttm_cleanup
Browse files Browse the repository at this point in the history
Clean up TTM utils and update notebooks
  • Loading branch information
wgifford authored Aug 12, 2024
2 parents c9bbacf + 9d9e5e5 commit 7ca28ee
Show file tree
Hide file tree
Showing 38 changed files with 8,760 additions and 2,686 deletions.
21 changes: 8 additions & 13 deletions notebooks/hfdemo/patch_tsmixer_blog.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -96,28 +96,27 @@
"source": [
"# Standard\n",
"import os\n",
"import random\n",
"\n",
"# supress some warnings\n",
"import warnings\n",
"\n",
"import pandas as pd\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
" PatchTSMixerConfig,\n",
" PatchTSMixerForPrediction,\n",
" set_seed,\n",
" Trainer,\n",
" TrainingArguments,\n",
" set_seed,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
"from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n",
"from tsfm_public.toolkit.util import select_by_index\n",
"\n",
"# supress some warnings\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\", module=\"torch\")"
]
Expand Down Expand Up @@ -916,9 +915,7 @@
],
"source": [
"print(\"Loading pretrained model\")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"patchtsmixer_4/electricity/model/pretrain/\"\n",
")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer_4/electricity/model/pretrain/\")\n",
"print(\"Done\")"
]
},
Expand Down Expand Up @@ -1323,9 +1320,7 @@
],
"source": [
"# Reload the model\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"patchtsmixer_4/electricity/model/pretrain/\"\n",
")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer_4/electricity/model/pretrain/\")\n",
"finetune_forecast_trainer = Trainer(\n",
" model=finetune_forecast_model,\n",
" args=finetune_forecast_args,\n",
Expand Down
12 changes: 5 additions & 7 deletions notebooks/hfdemo/patch_tsmixer_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
],
"source": [
"# Standard\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
Expand All @@ -43,9 +46,6 @@
" Trainer,\n",
" TrainingArguments,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
Expand Down Expand Up @@ -321,9 +321,7 @@
],
"source": [
"print(\"Loading pretrained model\")\n",
"inference_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"ibm-granite/granite-timeseries-patchtsmixer\"\n",
")\n",
"inference_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"ibm-granite/granite-timeseries-patchtsmixer\")\n",
"print(\"Done\")"
]
},
Expand Down
15 changes: 6 additions & 9 deletions notebooks/hfdemo/patch_tsmixer_transfer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
Expand All @@ -57,9 +61,6 @@
" Trainer,\n",
" TrainingArguments,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
Expand Down Expand Up @@ -923,9 +924,7 @@
],
"source": [
"print(\"Loading pretrained model\")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"patchtsmixer/electricity/model/pretrain/\"\n",
")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n",
"print(\"Done\")"
]
},
Expand Down Expand Up @@ -1415,9 +1414,7 @@
],
"source": [
"# Reload the model\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
" \"patchtsmixer/electricity/model/pretrain/\"\n",
")\n",
"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n",
"finetune_forecast_trainer = Trainer(\n",
" model=finetune_forecast_model,\n",
" args=finetune_forecast_args,\n",
Expand Down
12 changes: 5 additions & 7 deletions notebooks/hfdemo/patch_tst_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
"outputs": [],
"source": [
"# Standard\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
Expand All @@ -30,9 +33,6 @@
" Trainer,\n",
" TrainingArguments,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
Expand Down Expand Up @@ -309,9 +309,7 @@
],
"source": [
"print(\"Loading pretrained model\")\n",
"inference_forecast_model = PatchTSTForPrediction.from_pretrained(\n",
" \"ibm-granite/granite-timeseries-patchtst\"\n",
")\n",
"inference_forecast_model = PatchTSTForPrediction.from_pretrained(\"ibm-granite/granite-timeseries-patchtst\")\n",
"print(\"Done\")"
]
},
Expand Down
7 changes: 4 additions & 3 deletions notebooks/hfdemo/patch_tst_transfer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# Third Party\n",
"from transformers import (\n",
" EarlyStoppingCallback,\n",
Expand All @@ -43,9 +47,6 @@
" Trainer,\n",
" TrainingArguments,\n",
")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"# First Party\n",
"from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
Expand Down
Loading

0 comments on commit 7ca28ee

Please sign in to comment.