
Can the checkpoint be downloaded and loaded locally? #5

Open
keochoi opened this issue May 29, 2024 · 5 comments

Comments

@keochoi

keochoi commented May 29, 2024

Can google/timesfm-1.0-200m be downloaded from Hugging Face in advance and then loaded from a local path?

tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m") # How should this line be changed?

@lhw828
Owner

lhw828 commented May 29, 2024

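Yes. Downloading a model from the Hugging Face Hub in advance and then loading it locally is common practice, especially if you use the model many times or your network connection is unreliable. The general steps are: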

Download the model locally
Manual download: open the model's page on the Hugging Face Hub (search for google/timesfm-1.0-200m) and download the files from the "Files and versions" tab, keeping the folder structure intact, including the checkpoints/ folder.

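Using the CLI: alternatively, use the Hugging Face command-line tool huggingface-cli, which is installed with the huggingface_hub package, to download the whole repository:

```bash
huggingface-cli login  # optional: only needed for gated or private repositories
huggingface-cli download google/timesfm-1.0-200m --local-dir ./models/timesfm-1.0-200m  # download into ./models/timesfm-1.0-200m
```

Here --local-dir sets the local directory the files are written to; change it as needed. (The separate --cache-dir option only moves the Hugging Face cache, which stores files under models--google--timesfm-1.0-200m/snapshots/<revision>/ rather than in a plain folder.)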
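The same download can also be done from Python with huggingface_hub.snapshot_download, which accepts repo_id and local_dir. The sketch below is illustrative: the helper names download_timesfm and checkpoints_subdir are hypothetical, and it assumes, as in the path used later in this thread, that the checkpoint sits in a checkpoints subfolder of the downloaded repository.

```python
import os


def checkpoints_subdir(repo_dir: str) -> str:
    """Return the folder to pass as load_from_checkpoint(checkpoint_path=...).

    Assumption: the checkpoint lives in a "checkpoints" subfolder of the repo.
    """
    return os.path.join(repo_dir, "checkpoints")


def download_timesfm(local_dir: str = "./models/timesfm-1.0-200m",
                     repo_id: str = "google/timesfm-1.0-200m") -> str:
    """Download the whole model repository into local_dir; return its checkpoints folder."""
    # Imported lazily so the path helper above works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    # With local_dir set, snapshot_download places the files there and returns that path.
    repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)
    return checkpoints_subdir(repo_dir)
```

Run download_timesfm() once on a machine with network access, then copy ./models/timesfm-1.0-200m to the machine that needs to load the model offline.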

Load the model from the local path
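Once the download finishes, point the loading code at the local path. TimesFM 1.0 is a JAX (PAX) checkpoint rather than a transformers model, so transformers' AutoModel.from_pretrained does not apply; use the timesfm package and pass checkpoint_path instead of repo_id. Assuming the repository was downloaded to ./models/timesfm-1.0-200m, the checkpoint is in its checkpoints subfolder:

```python
import timesfm

# Local folder holding the checkpoint (the "checkpoints" subfolder of the downloaded repo)
local_model_path = "./models/timesfm-1.0-200m/checkpoints"

tfm = timesfm.TimesFm(
    context_len=512,   # example value; set to your context length (a multiple of input_patch_len)
    horizon_len=128,   # example value; set to your forecast horizon
    input_patch_len=32,
    output_patch_len=128,
    num_layers=20,
    model_dims=1280,
    backend="cpu",
)
# checkpoint_path replaces repo_id, so the local files are loaded
tfm.load_from_checkpoint(checkpoint_path=local_model_path)
```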

@lhw828
Owner

lhw828 commented May 29, 2024

This reply was generated by AI.

@zhaokui001

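> Can google/timesfm-1.0-200m be downloaded from Hugging Face in advance and loaded locally?
>
> tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m") # How should this line be changed?

It should be changed to:

```python
local_model_path = "/home/dedong/huggingface/timesfm/checkpoints"
tfm.load_from_checkpoint(checkpoint_path=local_model_path)
```

After downloading the model locally, loading it this way works: the model loads successfully and can run predictions.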
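A common failure with local loading is pointing checkpoint_path at the wrong folder (for example the repository root instead of its checkpoints subfolder). A small pre-check can catch that before TimesFM tries to restore. This is a sketch: resolve_local_checkpoint is a hypothetical helper, and it assumes the downloaded repository keeps a checkpoints/checkpoint_<step>/ layout.

```python
import os


def resolve_local_checkpoint(repo_dir: str) -> str:
    """Return <repo_dir>/checkpoints after a basic sanity check of its contents.

    Assumption: the download keeps the layout checkpoints/checkpoint_<step>/.
    """
    ckpt_dir = os.path.join(repo_dir, "checkpoints")
    if not os.path.isdir(ckpt_dir):
        raise FileNotFoundError(f"{ckpt_dir} not found; is the download complete?")
    # Look for at least one checkpoint_<step> subfolder.
    steps = [
        name for name in os.listdir(ckpt_dir)
        if name.startswith("checkpoint_") and os.path.isdir(os.path.join(ckpt_dir, name))
    ]
    if not steps:
        raise FileNotFoundError(f"no checkpoint_<step> folder inside {ckpt_dir}")
    return ckpt_dir
```

With the path from the comment above, this would be used as tfm.load_from_checkpoint(checkpoint_path=resolve_local_checkpoint("/home/dedong/huggingface/timesfm")).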

@ham114

ham114 commented Sep 14, 2024


```
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured CheckpointManager using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume train_state is unpadded.
ERROR:absl:For checkpoint version > 1.0, we require users to provide
train_state_unpadded_shape_dtype_struct during checkpoint
saving/restoring, to avoid potential silent bugs when loading
checkpoints to incompatible unpadded shapes of TrainState.
Restored checkpoint in 3.01 seconds.
Jitting decoding.
Jitted decoding in 52.32 seconds.
```

How can these errors and warnings be resolved, and do they affect usage?

@lhw828
Owner

lhw828 commented Sep 19, 2024


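These can be ignored: as the log shows, the checkpoint is still restored ("Restored checkpoint in 3.01 seconds.") and decoding is jitted afterwards.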
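If the messages are just noise, one option is a standard-library logging filter on the logger named absl (the name shown in each "WARNING:absl:" line) that drops only these specific messages while leaving other warnings visible. This is a sketch, not part of TimesFM: the class and function names are hypothetical, and the substrings are copied from the log above.

```python
import logging

# Substrings of the checkpoint-restore messages from the log above that were
# reported in this thread as safe to ignore.
BENIGN_ABSL_MESSAGES = (
    "No registered CheckpointArgs found",
    "Configured CheckpointManager using deprecated legacy API",
    "train_state_unpadded_shape_dtype_struct",
)


class DropBenignAbslMessages(logging.Filter):
    """Reject records whose message contains a known benign substring."""

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        return not any(text in message for text in BENIGN_ABSL_MESSAGES)


def quiet_absl_checkpoint_messages() -> None:
    """Attach the filter to the "absl" logger; other absl messages still pass."""
    logging.getLogger("absl").addFilter(DropBenignAbslMessages())
```

Call quiet_absl_checkpoint_messages() after import timesfm and before tfm.load_from_checkpoint(...), so that absl has most likely already created its own absl logger object. A filter is used instead of logging.getLogger("absl").setLevel(logging.CRITICAL) so that unrelated absl warnings are not hidden as well.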
