From 2c38c090064fb7d0317ab192e0757d73d338c6ee Mon Sep 17 00:00:00 2001
From: Rich James
Date: Tue, 12 Nov 2024 19:16:42 +0000
Subject: [PATCH] Support for safetensors checkpoints

---
 MaxText/llama_or_mistral_ckpt.py | 79 ++++++++++++++++++++++----------
 1 file changed, 54 insertions(+), 25 deletions(-)

diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py
index fadcbffa3..c5cc193dd 100644
--- a/MaxText/llama_or_mistral_ckpt.py
+++ b/MaxText/llama_or_mistral_ckpt.py
@@ -40,18 +40,20 @@
 os.environ["JAX_PLATFORMS"] = "cpu"
 
-import numpy as np
+from flax.training import train_state
 import jax
 from jax import tree
-from flax.training import train_state
-import torch
+import numpy as np
 import psutil
+from safetensors import safe_open
+import torch
 from tqdm import tqdm
 
 import max_logging
 from train import save_checkpoint
 import checkpointing
 
+CHECKPOINT_TYPES = ("pth", "safetensors")
 
 MODEL_PARAMS_DICT = {
     "llama2-70b": {
@@ -141,7 +143,8 @@
     },
 }
 
-SIMULATED_CPU_DEVICES_COUNT = 16
+# Values greater than 1 currently cause errors that we have not yet diagnosed
+SIMULATED_CPU_DEVICES_COUNT = 1
 
 
 def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
@@ -151,12 +154,12 @@ def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
       "norm.weight": "model.norm.weight",
       "output.weight": "lm_head.weight",
       # MOE model
-      f"layers.{layer_idx}.attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight",
-      f"layers.{layer_idx}.ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight",
-      f"layers.{layer_idx}.attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight",
       f"layers.{layer_idx}.attention.wk.weight": f"model.layers.{layer_idx}.self_attn.k_proj.weight",
-      f"layers.{layer_idx}.attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight",
       f"layers.{layer_idx}.attention.wo.weight": f"model.layers.{layer_idx}.self_attn.o_proj.weight",
+      f"layers.{layer_idx}.attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight",
+      f"layers.{layer_idx}.attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight",
+      f"layers.{layer_idx}.attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight",
+      f"layers.{layer_idx}.ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight",
       f"layers.{layer_idx}.feed_forward.gate.weight": f"model.layers.{layer_idx}.block_sparse_moe.gate.weight",
       f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w1.weight": f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight",
       f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w2.weight": f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight",
@@ -197,7 +200,31 @@ def permute_to_match_maxtext_rope(arr):
   return np.concatenate((evens, odds), axis=arr.ndim - 1)
 
 
-def convert_to_jax_weights(base_model_path, model_size):
+def load_pth_checkpoint(ckpt_paths):
+  chkpt_vars_raw = {}
+  for i, ckpt_path in enumerate(ckpt_paths):
+    max_logging.log(f"Loading checkpoint path {i+1} of {len(ckpt_paths)} ...")
+    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    chkpt_vars_raw[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint
+  chkpt_vars_sorted = [chkpt_vars_raw[i] for i in sorted(list(chkpt_vars_raw.keys()))]
+  # map weight names if they use HuggingFace instead of PyTorch convention
+  chkpt_vars = [_HFNamespaceMapper(var) for var in chkpt_vars_sorted]
+  return chkpt_vars
+
+
+def load_safetensors_checkpoint(ckpt_paths):
+  chkpt_vars_raw = {}
+  for i, ckpt_path in enumerate(ckpt_paths):
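+    # Each .safetensors shard is a flat mapping from tensor name to tensor; merge every shard into one dict (names must not collide).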
+    max_logging.log(f"Loading checkpoint path {i+1} of {len(ckpt_paths)} ...")
+    with safe_open(ckpt_path, framework="pt") as f:
+      for k in f.keys():
+        assert k not in chkpt_vars_raw
+        chkpt_vars_raw[k] = f.get_tensor(k)
+  chkpt_vars = [_HFNamespaceMapper(chkpt_vars_raw)]
+  return chkpt_vars
+
+
+def convert_to_jax_weights(base_model_path, model_size, checkpoint_type):
   """
   Function to convert the checkpoint at base_model_path into Orbax checkpoint
   for MaxText and output jax_weights ready for MaxText
@@ -219,15 +246,12 @@ def convert_to_jax_weights(base_model_path, model_size):
   max_logging.log(f"Loading the base model from {base_model_path}")
 
   # Skip any hidden files for checkpoints
-  ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.pth"))
-  chkpt_vars = {}
-  for i, ckpt_path in enumerate(ckpt_paths):
-    max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
-    checkpoint = torch.load(ckpt_path, map_location="cpu")
-    chkpt_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint
-  chkpt_vars = [chkpt_vars[i] for i in sorted(list(chkpt_vars.keys()))]
-  # map weight names if they use HuggingFace instead of PyTorch convention
-  chkpt_vars = [_HFNamespaceMapper(var) for var in chkpt_vars]
+  ckpt_paths = sorted(pathlib.Path(base_model_path).glob(f"[!.]*.{checkpoint_type}"))
+  if checkpoint_type == "safetensors":
+    chkpt_vars = load_safetensors_checkpoint(ckpt_paths)
+  else:
+    assert checkpoint_type == "pth"
+    chkpt_vars = load_pth_checkpoint(ckpt_paths)
 
   logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
@@ -256,9 +280,8 @@
   # logits dense #################################################
   max_logging.log("Processing logits dense")
-  logits_dense = np.concatenate(
-      [var["output.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0
-  ).transpose()[:, :vocab_size]
+  logits_dense = np.concatenate([var["output.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0).transpose()
+  assert logits_dense.shape[1] == vocab_size
   jax_weights["decoder"]["logits_dense"]["kernel"] = logits_dense
 
   logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
@@ -268,10 +291,10 @@
   if model_size[:6] == "llama3":
     token_embedder = np.concatenate([var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0)
   else:
-    token_embedder = np.concatenate(
-        [var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=1
-    )[:vocab_size, :]
+    token_embedder = np.concatenate([var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=1)
+    assert token_embedder.shape[0] == vocab_size
   jax_weights["token_embedder"]["embedding"] = token_embedder
+
   logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
 
   # self attention ###############################################
@@ -532,12 +555,18 @@ def checkpoint_device_put(arr):
   parser.add_argument("--base-model-path", type=str, required=True)
   parser.add_argument("--maxtext-model-path", type=str, required=True)
   parser.add_argument("--model-size", type=str, required=True)
+  parser.add_argument("--checkpoint-type", type=str, required=True)
   args = parser.parse_args()
 
   if args.model_size not in MODEL_PARAMS_DICT:
     raise NotImplementedError
+  if args.checkpoint_type not in CHECKPOINT_TYPES:
+    raise NotImplementedError
+
   os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}"
 
-  save_jax_weights_to_checkpoint(args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size))
+  save_jax_weights_to_checkpoint(
+      args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size, args.checkpoint_type)
+  )
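
Example invocation with the new flag (illustrative only: the input and output paths below are placeholders, the model size must be a key in MODEL_PARAMS_DICT, and --checkpoint-type must be one of CHECKPOINT_TYPES):

  python3 MaxText/llama_or_mistral_ckpt.py \
    --base-model-path /path/to/hf-llama2-70b \
    --maxtext-model-path gs://my-bucket/llama2-70b-maxtext \
    --model-size llama2-70b \
    --checkpoint-type safetensors

A directory of .pth shards is converted the same way with --checkpoint-type pth.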