Support prediction_filter_length when plotting #137

Open
lycheesodaa opened this issue Sep 16, 2024 · 1 comment

lycheesodaa commented Sep 16, 2024

Thanks for the amazing work so far.

It's a minor issue, but I'm running into an error when plotting with `plot_predictions()` after passing `prediction_filter_length=48` when instantiating the model. I'm following the example notebook and using exactly the same zero-shot forecasting code, but I get this:

```
Traceback (most recent call last):
  File "home/granite-tsfm/run_demand.py", line 86, in <module>
    zeroshot_eval(
  File "home/granite-tsfm/run_demand.py", line 78, in zeroshot_eval
    plot_predictions(
  File "home/granite-tsfm/tsfm_public/toolkit/visualization.py", line 379, in plot_predictions
    axs[i].plot(ts_y, y, label="True", linestyle="-", color="blue", linewidth=2)
  File "home/granite-tsfm/venv/lib/python3.11/site-packages/matplotlib/axes/_axes.py", line 1779, in plot
    lines = [*self._get_lines(self, *args, data=data, **kwargs)]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "home/granite-tsfm/venv/lib/python3.11/site-packages/matplotlib/axes/_base.py", line 296, in __call__
    yield from self._plot_args(
               ^^^^^^^^^^^^^^^^
  File "home/granite-tsfm/venv/lib/python3.11/site-packages/matplotlib/axes/_base.py", line 486, in _plot_args
    raise ValueError(f"x and y must have same first dimension, but "
ValueError: x and y must have same first dimension, but have shapes (144,) and (192,)
```
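The two shapes in the error can be reproduced without the model. The minimal NumPy sketch below mirrors how the time axis and the "True" series are built. It uses `plot_context=96` and a 96-step `future_values` window only because those values are consistent with the 144 vs 192 shapes in the traceback; they are not read from the notebook. It also assumes that the unedited line took the whole future window for the channel.

```python
import numpy as np

def true_series_shapes(plot_context: int, prediction_length: int, future_len: int):
    """Return (time-axis shape, true-series shape) as built before the fix.

    Assumption: future_len is the dataset's full horizon, while prediction_length
    is the shortened horizon produced by prediction_filter_length.
    """
    past = np.zeros(plot_context)   # stands in for past_values[-plot_context:, channel]
    future = np.zeros(future_len)   # stands in for the unsliced future_values[:, channel]
    ts_y = np.arange(plot_context + prediction_length)  # x-values for the "True" line
    y = np.concatenate((past, future), axis=0)          # y-values for the "True" line
    return ts_y.shape, y.shape

print(true_series_shapes(plot_context=96, prediction_length=48, future_len=96))
# ((144,), (192,))
```

When the model is not filtered (`prediction_length == future_len`), the two lengths agree, which would explain why the unmodified notebook plots fine.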

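It seems that `plot_predictions()` does not yet support a shortened forecast horizon. The time axis for the true values is built from `prediction_length`, but the ground-truth window taken from `future_values` is not truncated to match. Editing the following line so that it slices `future_values` to the first `prediction_length` steps seems to fix the issue: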

```python
else:
    batch = dset[index]
    ts_y_hat = np.arange(plot_context, plot_context + prediction_length)
    y_hat = predictions_subset[i]

    ts_y = np.arange(plot_context + prediction_length)
    # <- edited line 369: keep only the first prediction_length ground-truth steps
    y = batch["future_values"][:prediction_length, channel].squeeze().numpy()
    x = batch["past_values"][-plot_context:, channel].squeeze().numpy()
    y = np.concatenate((x, y), axis=0)
    border = plot_context
    plot_title = f"Example {indices[i]}"
```
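For illustration, here is the same construction as a small standalone function with the truncation applied. The name `build_true_series` and its signature are hypothetical and not part of `tsfm_public`. The function only shows two things: slicing the future window to `prediction_length` makes the x and y lengths agree, and the slice is a no-op when the future window already has exactly `prediction_length` steps.

```python
import numpy as np

def build_true_series(past: np.ndarray, future: np.ndarray, plot_context: int, prediction_length: int):
    """Hypothetical helper: build the (x, y) pair for the "True" line of one channel.

    past:   1-D history for one channel (at least plot_context long)
    future: 1-D ground truth for one channel; may be longer than prediction_length
            when the model was created with prediction_filter_length.
    """
    x_hist = past[-plot_context:]
    y_future = future[:prediction_length]  # drop ground-truth steps the model did not predict
    y = np.concatenate((x_hist, y_future), axis=0)
    ts_y = np.arange(plot_context + prediction_length)
    return ts_y, y

past = np.arange(200, dtype=float)
future = np.arange(96, dtype=float)
ts_y, y = build_true_series(past, future, plot_context=96, prediction_length=48)
print(ts_y.shape, y.shape)  # (144,) (144,)
```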

I have only tested the case where `dset` and `model` are passed to the function; the other cases may well work fine.

Edit: passing a string for the `channel` argument also doesn't seem to work in this case, but I presume that's still an unimplemented feature?
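On the `channel` question, one way a string could be supported is to map it to a positional index before the `[..., channel]` indexing. The helper below is purely hypothetical and is not how `tsfm_public` handles channels. It assumes the caller already has the ordered list of target column names that the dataset was built with; the sketch does not look that list up itself.

```python
from typing import Sequence, Union

def resolve_channel(channel: Union[int, str], column_names: Sequence[str]) -> int:
    """Hypothetical: convert a channel given by name or by index into an integer index."""
    names = list(column_names)
    if isinstance(channel, str):
        if channel not in names:
            raise ValueError(f"Unknown channel {channel!r}; expected one of {names}")
        return names.index(channel)
    if not 0 <= channel < len(names):
        raise IndexError(f"Channel index {channel} out of range for {len(names)} channels")
    return channel

cols = ["demand", "temperature"]  # hypothetical column names
print(resolve_channel("temperature", cols))  # 1
print(resolve_channel(0, cols))              # 0
```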

@wgifford (Collaborator)

@lycheesodaa Thanks for the issue -- we will look into it.
