Skip to content

Commit

Permalink
use weights_only to True (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Sep 19, 2024
1 parent f13ac8c commit d3a3978
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion shimmer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def migrate_model(ckpt_path: str | PathLike, **torch_load_kwargs):
torch_load_kwargs: additional args given to torch.load.
"""
ckpt_path = Path(ckpt_path)
ckpt = torch.load(ckpt_path, **torch_load_kwargs)
default_torch_kwargs: dict[str, Any] = {"weights_only": True}
default_torch_kwargs.update(torch_load_kwargs)
ckpt = torch.load(ckpt_path, **default_torch_kwargs)
new_ckpt, done_migrations = migrate_from_folder(ckpt, MIGRATION_DIR)
done_migration_log = ", ".join(map(lambda x: x.name, done_migrations))
print(f"Migrating: {done_migration_log}")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ckpt_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def test_ckpt_migration_2_domains():
old_ckpt_path = here / "data" / "old_gw_2_domains.ckpt"
old_ckpt = torch.load(old_ckpt_path)
old_ckpt = torch.load(old_ckpt_path, weights_only=True)
new_ckpt, done_migrations = migrate_from_folder(old_ckpt, MIGRATION_DIR)

old_keys = set(old_ckpt["state_dict"].keys())
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_ckpt_migration_2_domains():

def test_ckpt_migration_gw():
old_ckpt_path = here / "data" / "old_gw.ckpt"
old_ckpt = torch.load(old_ckpt_path)
old_ckpt = torch.load(old_ckpt_path, weights_only=True)
new_ckpt, done_migrations = migrate_from_folder(old_ckpt, MIGRATION_DIR)

old_keys = set(old_ckpt["state_dict"].keys())
Expand Down

0 comments on commit d3a3978

Please sign in to comment.