Install error "ModuleNotFoundError: No module named 'jax'" #172

Open
linpingta opened this issue Oct 28, 2024 · 2 comments

Comments

@linpingta

Hi guys,

If I install the library with "pip install timesfm" and run the example code from https://huggingface.co/google/timesfm-1.0-200m:

import timesfm

tfm = timesfm.TimesFm(
    context_len=14,
    horizon_len=7,
    input_patch_len=32,
    output_patch_len=128,
    num_layers=20,
    model_dims=1280
)
tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

it fails on import with this error:

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/timesfm/timesfm_base.py:27
     23 import pandas as pd
     25 from utilsforecast.processing import make_future_dataframe
---> 27 from . import xreg_lib
     29 Category = xreg_lib.Category
     30 XRegMode = xreg_lib.XRegMode

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/timesfm/xreg_lib.py:20
     17 import math
     18 from typing import Any, Iterable, Literal, Mapping, Sequence
---> 20 import jax
     21 import jax.numpy as jnp
     22 import numpy as np

ModuleNotFoundError: No module named 'jax'

I then installed jax manually, but that leads to a different error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 3
      1 import timesfm
----> 3 tfm = timesfm.TimesFm(
      4     context_len=14,
      5     horizon_len=7,
      6     input_patch_len=32,
      7     output_patch_len=128,
      8     num_layers=20,
      9     model_dims=1280
     10 )
     11 tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

TypeError: TimesFmBase.__init__() got an unexpected keyword argument 'context_len'

Could you help me figure out what is going wrong? Thanks.

@aleksmaksimovic

aleksmaksimovic commented Oct 28, 2024

Create a new venv with Python 3.10.x and install it this way:

pip install timesfm[pax]
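
After reinstalling, it may help to confirm that the modules named in the traceback can actually be found in the new environment before importing timesfm again. The snippet below is a minimal sketch using only the standard library; the helper name `missing_modules` is made up for illustration, and it assumes (not verified here) that the pax extra is what pulls in jax.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment.

    find_spec only locates the module; it does not import it, so this is
    safe to run even when importing the package itself would fail.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # "jax" is the module the traceback reports as missing.
    print(missing_modules(["jax", "timesfm"]))
```

An empty list means both modules are visible to the interpreter you ran the check with; if "jax" is still listed, the install probably went into a different environment.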

@linpingta
Author

Thanks @aleksmaksimovic. I think the example at https://huggingface.co/google/timesfm-1.0-200m may need an update: TimesFm does not accept these parameters directly as keyword arguments; they should be wrapped in TimesFmHparams instead.

From code:

class TimesFmBase:
  """Base TimesFM forecast API for inference.

  This class is the scaffolding for calling TimesFM forecast. To properly use:
    1. Create an instance with the correct hyperparameters of a TimesFM model.
    2. Call `load_from_checkpoint` to load a compatible checkpoint.
    3. Call `forecast` for inference.
  """

  def _logging(self, s):
    print(s)

  def __post_init__(self) -> None:
    """Additional initialization for subclasses before checkpoint loading."""
    pass

  def __init__(self, hparams: TimesFmHparams,
               checkpoint: TimesFmCheckpoint) -> None:
    """Initializes the TimesFM forecast API.

Correct me if I am wrong. Thanks.
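
For anyone landing here, this is a rough, untested sketch of what the wrapped call might look like, based on the `__init__(hparams, checkpoint)` signature quoted above. The hyperparameter values are copied unchanged from the failing call in this issue; the function name `build_timesfm` is hypothetical, and the `huggingface_repo_id` argument to `TimesFmCheckpoint` is an assumption that should be checked against the installed version.

```python
# Hyperparameter values copied as-is from the failing call in this issue.
ISSUE_HPARAMS = dict(
    context_len=14,
    horizon_len=7,
    input_patch_len=32,
    output_patch_len=128,
    num_layers=20,
    model_dims=1280,
)


def build_timesfm(repo_id="google/timesfm-1.0-200m"):
    """Untested sketch: pass the settings via TimesFmHparams / TimesFmCheckpoint."""
    # Imported lazily so this file loads even without a working timesfm install.
    import timesfm

    hparams = timesfm.TimesFmHparams(**ISSUE_HPARAMS)
    # Assumption: TimesFmCheckpoint takes the Hugging Face repo via this keyword.
    checkpoint = timesfm.TimesFmCheckpoint(huggingface_repo_id=repo_id)
    return timesfm.TimesFm(hparams=hparams, checkpoint=checkpoint)
```

Calling `build_timesfm()` would replace both the constructor call and the separate `load_from_checkpoint` call from the original example, assuming the checkpoint is loaded during construction in this version.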
