Support for safetensor checkpoints #1028

Draft · wants to merge 1 commit into base: main
79 changes: 54 additions & 25 deletions MaxText/llama_or_mistral_ckpt.py
@@ -40,18 +40,20 @@

os.environ["JAX_PLATFORMS"] = "cpu"

import numpy as np
from flax.training import train_state
import jax
from jax import tree
from flax.training import train_state
import torch
import numpy as np
import psutil
from safetensors import safe_open
import torch
from tqdm import tqdm

import max_logging
from train import save_checkpoint
import checkpointing

CHECKPOINT_TYPES = ("pth", "safetensors")

MODEL_PARAMS_DICT = {
"llama2-70b": {
@@ -141,7 +143,8 @@
},
}

SIMULATED_CPU_DEVICES_COUNT = 16
# We get errors with > 1 that we don't yet understand
SIMULATED_CPU_DEVICES_COUNT = 1
Collaborator:
I recall this fails on llama2-70b, so Anisha added 16 as a workaround.
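
For context, this constant feeds the `--xla_force_host_platform_device_count` XLA flag set near the bottom of the script. A minimal standalone sketch (not part of this diff) of what a value of 16 does on a CPU-only JAX install:

```python
# Sketch only: simulate multiple CPU devices in a JAX process, as this script does
# via os.environ. With the previous value of 16, jax.devices() reports 16 CPU devices;
# this PR pins the count to 1 until the failure with larger values is understood.
import os

os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

import jax  # must be imported after the flags are set

print(jax.devices())  # expect 16 CPU devices
```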



def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
@@ -151,12 +154,12 @@ def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
"norm.weight": "model.norm.weight",
"output.weight": "lm_head.weight",
# MOE model
f"layers.{layer_idx}.attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight",
f"layers.{layer_idx}.ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight",
f"layers.{layer_idx}.attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight",
f"layers.{layer_idx}.attention.wk.weight": f"model.layers.{layer_idx}.self_attn.k_proj.weight",
f"layers.{layer_idx}.attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight",
f"layers.{layer_idx}.attention.wo.weight": f"model.layers.{layer_idx}.self_attn.o_proj.weight",
f"layers.{layer_idx}.attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight",
f"layers.{layer_idx}.attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight",
f"layers.{layer_idx}.attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight",
f"layers.{layer_idx}.ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight",
f"layers.{layer_idx}.feed_forward.gate.weight": f"model.layers.{layer_idx}.block_sparse_moe.gate.weight",
f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w1.weight": f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight",
f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w2.weight": f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight",
@@ -197,7 +200,31 @@ def permute_to_match_maxtext_rope(arr):
return np.concatenate((evens, odds), axis=arr.ndim - 1)


def convert_to_jax_weights(base_model_path, model_size):
def load_pth_checkpoint(ckpt_paths):
  chkpt_vars_raw = {}
  for i, ckpt_path in enumerate(ckpt_paths):
    max_logging.log(f"Loading checkpoint path {i+1} of {len(ckpt_paths)} ...")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    chkpt_vars_raw[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint
  chkpt_vars_sorted = [chkpt_vars_raw[i] for i in sorted(list(chkpt_vars_raw.keys()))]
  # map weight names if they use HuggingFace instead of PyTorch convention
  chkpt_vars = [_HFNamespaceMapper(var) for var in chkpt_vars_sorted]
  return chkpt_vars


def load_safetensors_checkpoint(ckpt_paths):
  chkpt_vars_raw = {}
  for i, ckpt_path in enumerate(ckpt_paths):
    max_logging.log(f"Loading checkpoint path {i+1} of {len(ckpt_paths)} ...")
    with safe_open(ckpt_path, framework="pt") as f:
      for k in f.keys():
        assert k not in chkpt_vars_raw
        chkpt_vars_raw[k] = f.get_tensor(k)
  chkpt_vars = [_HFNamespaceMapper(chkpt_vars_raw)]
  return chkpt_vars


def convert_to_jax_weights(base_model_path, model_size, checkpoint_type):
"""
Function to convert the checkpoint at base_model_path into Orbax checkpoint
for MaxText and output jax_weights ready for MaxText
@@ -219,15 +246,12 @@

max_logging.log(f"Loading the base model from {base_model_path}")
# Skip any hidden files for checkpoints
ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.pth"))
chkpt_vars = {}
for i, ckpt_path in enumerate(ckpt_paths):
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
checkpoint = torch.load(ckpt_path, map_location="cpu")
chkpt_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint
chkpt_vars = [chkpt_vars[i] for i in sorted(list(chkpt_vars.keys()))]
# map weight names if they use HuggingFace instead of PyTorch convention
chkpt_vars = [_HFNamespaceMapper(var) for var in chkpt_vars]
ckpt_paths = sorted(pathlib.Path(base_model_path).glob(f"[!.]*.{checkpoint_type}"))
if checkpoint_type == "safetensors":
chkpt_vars = load_safetensors_checkpoint(ckpt_paths)
else:
assert checkpoint_type == "pth"
chkpt_vars = load_pth_checkpoint(ckpt_paths)

logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

@@ -256,9 +280,8 @@ def convert_to_jax_weights(base_model_path, model_size):

# logits dense #################################################
max_logging.log("Processing logits dense")
logits_dense = np.concatenate(
[var["output.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0
).transpose()[:, :vocab_size]
logits_dense = np.concatenate([var["output.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0).transpose()
assert logits_dense.shape[1] == vocab_size
jax_weights["decoder"]["logits_dense"]["kernel"] = logits_dense

logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
@@ -268,10 +291,10 @@ def convert_to_jax_weights(base_model_path, model_size):
if model_size[:6] == "llama3":
token_embedder = np.concatenate([var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0)
else:
token_embedder = np.concatenate(
[var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=1
)[:vocab_size, :]
token_embedder = np.concatenate([var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=1)
assert token_embedder.shape[0] == vocab_size
jax_weights["token_embedder"]["embedding"] = token_embedder

logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

# self attention ###############################################
@@ -532,12 +555,18 @@ def checkpoint_device_put(arr):
parser.add_argument("--base-model-path", type=str, required=True)
parser.add_argument("--maxtext-model-path", type=str, required=True)
parser.add_argument("--model-size", type=str, required=True)
parser.add_argument("--checkpoint-type", type=str, required=True)

args = parser.parse_args()

if args.model_size not in MODEL_PARAMS_DICT:
raise NotImplementedError

if args.checkpoint_type not in CHECKPOINT_TYPES:
raise NotImplementedError

os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}"

save_jax_weights_to_checkpoint(args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size))
save_jax_weights_to_checkpoint(
args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size, args.checkpoint_type)
)
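
For anyone trying the branch: the converter is now invoked with the new `--checkpoint-type safetensors` flag alongside the existing `--base-model-path`, `--maxtext-model-path`, and `--model-size` arguments. Before converting, a quick sketch for inspecting a HuggingFace-style safetensors shard, using the same `safe_open` API the new loader relies on (the shard name below is a placeholder):

```python
# Sketch: list the tensors in one safetensors shard without reading the whole file.
# safe_open loads lazily, so keys() is cheap and get_tensor() materializes a single
# weight on demand; with framework="pt" the result is a torch.Tensor.
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
  keys = sorted(f.keys())
  print(f"{len(keys)} tensors, e.g. {keys[:3]}")  # HuggingFace-convention names
  first = f.get_tensor(keys[0])
  print(first.shape, first.dtype)
```

Listing the keys is also a cheap check that the shard uses HuggingFace-convention names, which `_HFNamespaceMapper` translates to the PyTorch-convention names the rest of the converter indexes with (for example `output.weight` and `tok_embeddings.weight`).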