-
Notifications
You must be signed in to change notification settings - Fork 631
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Handle shared layers in save_torch_state_dict
+ add save_torch_model
#2373
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
The main helper of the `serialization` module takes a state dictionary as input (e.g. a mapping between layer names and related tensors), splits it into several shards while creating a proper index in the process and save everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the logic to split the state dictionary to [`split_torch_state_dict_into_shards`]. | ||
The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only `torch` framework is supported. | ||
|
||
If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of a `nn.Module`, you can use [`save_torch_state_dict`] which provides the same features. This is useful for example if you want to apply custom logic to the state dict before saving it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see the point of mentioning this but also I think for the Torch community, it's fairly standard practice to ship the model classes and their state dictionaries (i.e., the parameters) separately unlike TensorFlow/Keras, for example.
) | ||
|
||
|
||
def get_tensor_size(tensor: "tf.Tensor") -> int: | ||
def get_tf_storage_size(tensor: "tf.Tensor") -> int: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have all the equivalent torch methods for TensorFlow? Or is that not necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet no. Let's build for torch first and then expand to TF after if needed. For now for TF we have the logic to split a state dict into shards but nothing to save to disk.
"metadata": {**state_dict_split.metadata, **metadata}, | ||
"weight_map": state_dict_split.tensor_to_filename, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should there be any sanity check on the additional metadata
if not already done?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
metadata
is at the discretion of the frameworks that will use it (transformers/diffusers/accelerate). In practice, I don't think it'll be much used. In any case, we can't really do sanity check since we are supposed to accept anything that is jsonable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understanding the full scope of the PR is still a little farfetched for me but I left some clarification questions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much needed change! The API w/ contiguous looks ok to me.
Thanks for the changes! Let's see if it breaks in the wild but from a quick check it looks good.
Thanks for the reviews! Let's ship it yes 😄 |
* Use extended path on Windows when downloading to local dir Change the path of the local dir to an extended path by prepending "\\?\" to the absolute path, when the absolute path is longer than 255 characters on Windows. Also fixed a small typo. * Use extended path on Windows when downloading to local dir Change the path of the local dir to an extended path by prepending "\\?\" to the absolute path, when the absolute path is longer than 255 characters on Windows. Also fixed a small typo. * Move path handling to `get_local_download_paths()` for robustness On Windows we check the length of `lock_path` and if it is longer than 255 characters we prepend the `\\?\` prefix to all paths if it does not already exist. We only need to check the length of `lock_path` because it is guaranteed to be the longest path. * `safetensors[torch]` (#2371) * Fix token=False not respected in file download (#2386) * Fix token=False not respected in file download * lint * Handle shared layers in `save_torch_state_dict` + add `save_torch_model` (#2373) * Handle shared layers in save_torch_state_dict + save_torch_model + some helpers * fix pytest rerun * more reruns * Support `expand` parameter in `xxx_info` and `list_xxxs` (model/dataset/Space) (#2333) * First draft to support `expand` parameter for models * add expand support for dataset * add expand support for Space * Use extended path on Windows when downloading to local dir Change the path of the local dir to an extended path by prepending "\\?\" to the absolute path, when the absolute path is longer than 255 characters on Windows. Also fixed a small typo. * Move path handling to `get_local_download_paths()` for robustness On Windows we check the length of `lock_path` and if it is longer than 255 characters we prepend the `\\?\` prefix to all paths if it does not already exist. We only need to check the length of `lock_path` because it is guaranteed to be the longest path. * Use extended path on Windows when downloading to local dir Change the path of the local dir to an extended path by prepending "\\?\" to the absolute path, when the absolute path is longer than 255 characters on Windows. Also fixed a small typo. * Removed old path handling * Reorder path check; add tests * Skip test if opn Windows The test now shows up a `skipped` if executed on a non-Windows machine Co-authored-by: Lucain <[email protected]> * Fix indentation for test_local_folder.py * Fix code style --------- Co-authored-by: Lucain <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: Lucain <[email protected]>
Partially resolve #2065.
Follow-up PR after #2314.
In #2314, we introduce
save_torch_state_dict
. This new PR:save_torch_model
to directly save a torchnn.Module
get_tf_storage_size
/get_torch_storage_size
and make them public + documentedA last follow-up PR should had
load_torch_state_dict
/load_torch_model
helpers as well to correctly reload those files, including the shared layers.I'm pinging transformers/accelerate/diffusers cores maintainers for visibility as well. Feel free to comment if someone should be done differently.