
Commit dce4640

Merge branch 'main' into sagi/finalize_tutorials
Sagi Polaczek committed Nov 17, 2024
2 parents: 99d3346 + cd29140
Showing 1 changed file with 47 additions and 16 deletions.
mammal/model.py: 63 changes (47 additions, 16 deletions)
@@ -53,6 +53,12 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "MammalConfig":
         config = cls(**config_dict)
         return config
 
+    @classmethod
+    def get_deprecated_arguments(cls) -> list[str]:
+        """Deprecated arguments, listed here to support backward compatibility."""
+        deprecated_arguments = ["load_weights", "t5_pretrained_name"]
+        return deprecated_arguments
+
     def override(self, config_overrides: dict[str, Any]) -> None:
         """
         Override an existing (loaded) configuration
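
For context, a minimal sketch (not part of the commit) of how this hook is consumed: `from_pretrained`, in the hunk below, iterates over the returned names and strips any matching attribute from a config object recovered from an old checkpoint. The stand-in config class and the stale field value here are invented for illustration.

```python
# Minimal sketch of the deprecated-argument cleanup; the config class and
# the "t5-base" value are invented stand-ins, not taken from the repo.
class OldCheckpointConfig:
    @classmethod
    def get_deprecated_arguments(cls) -> list[str]:
        return ["load_weights", "t5_pretrained_name"]

config = OldCheckpointConfig()
config.t5_pretrained_name = "t5-base"  # stale field restored from an old checkpoint

for deprecated_arg in OldCheckpointConfig.get_deprecated_arguments():
    # Remove the deprecated arg if it exists, mirroring the loop in the diff below
    if hasattr(config, deprecated_arg):
        delattr(config, deprecated_arg)

assert not hasattr(config, "t5_pretrained_name")
```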
@@ -391,14 +397,35 @@ def from_pretrained(
             ) from e
 
         if pretrained_model_name_or_path.endswith(".ckpt"):
+            pl_ckpt_dict = None
             print(f"`.ckpt` file was located. {pretrained_model_name_or_path=}")
             if config is None:
                 # Try to get the config from the `.ckpt` parent directory
                 pretrained_model_dirpath = os.path.dirname(
                     pretrained_model_name_or_path
                 )
                 config = os.path.join(pretrained_model_dirpath, CONFIG_NAME)
 
+                if not os.path.exists(config):
+                    print(
+                        "Couldn't find a `config.json` file in the checkpoint's parent directory. Trying to fetch the config from the Lightning checkpoint itself."
+                    )
+                    pl_ckpt_dict = torch.load(
+                        pretrained_model_name_or_path,
+                        map_location="cpu",
+                        weights_only=False,
+                    )
+                    config = pl_ckpt_dict["config"]
+                    for deprecated_arg in MammalConfig.get_deprecated_arguments():
+                        # Remove the deprecated arg if it exists
+                        if hasattr(config, deprecated_arg):
+                            print(
+                                f"Found deprecated argument '{deprecated_arg}'. Deleting it!"
+                            )
+                            delattr(config, deprecated_arg)
+
             if isinstance(config, str):
                 # A config path was given or was located in the parent directory
                 with open(config, encoding="utf-8") as f:
                     config = json.load(f)
                 config = MammalConfig.from_dict(config)
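
Read together, the `.ckpt` branch now resolves a config in three steps: an explicit config argument wins, then a `config.json` next to the checkpoint, and finally the config pickled inside the Lightning checkpoint itself. A hedged sketch of that lookup order, where `resolve_config` is a hypothetical helper and not part of the commit:

```python
# Hypothetical helper summarizing the lookup order above; not part of the commit.
import json
import os

import torch

CONFIG_NAME = "config.json"  # same constant name the diff references

def resolve_config(pretrained_model_name_or_path: str, config=None):
    pl_ckpt_dict = None
    if config is None:
        # 1) Prefer a config.json sitting next to the .ckpt file
        config = os.path.join(
            os.path.dirname(pretrained_model_name_or_path), CONFIG_NAME
        )
        if not os.path.exists(config):
            # 2) Fall back to the config stored inside the Lightning checkpoint
            pl_ckpt_dict = torch.load(
                pretrained_model_name_or_path, map_location="cpu", weights_only=False
            )
            config = pl_ckpt_dict["config"]
    if isinstance(config, str):
        # 3) Whatever path we ended up with, parse it as JSON
        with open(config, encoding="utf-8") as f:
            config = json.load(f)
    return config, pl_ckpt_dict
```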
@@ -408,25 +435,29 @@ def from_pretrained(
             config.override(config_overrides)
         model = cls(config)
 
-        pl_ckpt_dict = torch.load(
-            pretrained_model_name_or_path, map_location="cpu", weights_only=True
-        )
-        state_dict = pl_ckpt_dict["state_dict"]
-        lightning_model_prefix = "_model."
-        state_dict = {
-            (
-                key[len(lightning_model_prefix) :]
-                if key.startswith(lightning_model_prefix)
-                else key
-            ): value
-            for key, value in state_dict.items()
-        }
-
-        if config.random_weights:
+        if hasattr(config, "random_weights") and config.random_weights:
             print(
                 "Warning! You are loading random weights! To disable this, set 'random_weights' to False in the config."
             )
+        else:
+            if pl_ckpt_dict is None:
+                # The checkpoint wasn't loaded above to fetch the config; load it now to get the weights
+                pl_ckpt_dict = torch.load(
+                    pretrained_model_name_or_path,
+                    map_location="cpu",
+                    weights_only=True,
+                )
+
+            state_dict = pl_ckpt_dict["state_dict"]
+            lightning_model_prefix = "_model."
+            state_dict = {
+                (
+                    key[len(lightning_model_prefix) :]
+                    if key.startswith(lightning_model_prefix)
+                    else key
+                ): value
+                for key, value in state_dict.items()
+            }
+            # Inject the weights into the model instance
+            model.load_state_dict(state_dict, strict=strict)
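
The dictionary comprehension above strips Lightning's `_model.` wrapper prefix so the checkpoint keys line up with the bare model's parameter names. A standalone illustration with invented keys:

```python
# Invented checkpoint keys, used only to show the prefix stripping above.
state_dict = {
    "_model.encoder.weight": 1.0,  # saved through the Lightning wrapper
    "head.bias": 2.0,              # already matches the bare model
}
lightning_model_prefix = "_model."
state_dict = {
    (
        key[len(lightning_model_prefix) :]
        if key.startswith(lightning_model_prefix)
        else key
    ): value
    for key, value in state_dict.items()
}
assert set(state_dict) == {"encoder.weight", "head.bias"}
```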

@@ -450,7 +481,7 @@ def from_pretrained(
                 pretrained_model_name_or_path, SAFETENSORS_SINGLE_FILE
             )
 
-            if config.random_weights:
+            if hasattr(config, "random_weights") and config.random_weights:
                 print(
                     "Warning! You are using random weights! To disable this, set 'random_weights' to False in the config."
                 )
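
Taken together, the changes let an old `.ckpt` load even when no `config.json` sits beside it and its config predates the `random_weights` field, since `hasattr` now guards that check. A hypothetical call, with the checkpoint path invented and the `Mammal` import assumed from the repo layout:

```python
# Hypothetical usage; the path is invented and the import is assumed.
from mammal.model import Mammal

model = Mammal.from_pretrained("checkpoints/last.ckpt")
```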