From 5cd3c832530a4ce255c7dc3d181652fa896dc1ac Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Wed, 4 Sep 2024 09:30:26 +0000 Subject: [PATCH] use weights_only to True --- shimmer/utils.py | 4 +++- tests/test_ckpt_migrations.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/shimmer/utils.py b/shimmer/utils.py index dfdc39b3..4fff6531 100644 --- a/shimmer/utils.py +++ b/shimmer/utils.py @@ -71,7 +71,9 @@ def migrate_model(ckpt_path: str | PathLike, **torch_load_kwargs): torch_load_kwargs: additional args given to torch.load. """ ckpt_path = Path(ckpt_path) - ckpt = torch.load(ckpt_path, **torch_load_kwargs) + default_torch_kwargs: dict[str, Any] = {"weights_only": True} + default_torch_kwargs.update(torch_load_kwargs) + ckpt = torch.load(ckpt_path, **default_torch_kwargs) new_ckpt, done_migrations = migrate_from_folder(ckpt, MIGRATION_DIR) done_migration_log = ", ".join(map(lambda x: x.name, done_migrations)) print(f"Migrating: {done_migration_log}") diff --git a/tests/test_ckpt_migrations.py b/tests/test_ckpt_migrations.py index 1b90b134..09cb98fc 100644 --- a/tests/test_ckpt_migrations.py +++ b/tests/test_ckpt_migrations.py @@ -14,7 +14,7 @@ def test_ckpt_migration_2_domains(): old_ckpt_path = here / "data" / "old_gw_2_domains.ckpt" - old_ckpt = torch.load(old_ckpt_path) + old_ckpt = torch.load(old_ckpt_path, weights_only=True) new_ckpt, done_migrations = migrate_from_folder(old_ckpt, MIGRATION_DIR) old_keys = set(old_ckpt["state_dict"].keys()) @@ -74,7 +74,7 @@ def test_ckpt_migration_2_domains(): def test_ckpt_migration_gw(): old_ckpt_path = here / "data" / "old_gw.ckpt" - old_ckpt = torch.load(old_ckpt_path) + old_ckpt = torch.load(old_ckpt_path, weights_only=True) new_ckpt, done_migrations = migrate_from_folder(old_ckpt, MIGRATION_DIR) old_keys = set(old_ckpt["state_dict"].keys())