Merge branch 'trintamaki/internvit' into 'main'
InternViT support for NVLM

See merge request ADLR/megatron-lm!2295
trintamaki committed Nov 9, 2024
2 parents 32fc18a + 95ea6e5 commit c2e9fb5
Showing 11 changed files with 672 additions and 16 deletions.
45 changes: 45 additions & 0 deletions examples/multimodal/config.py
@@ -60,6 +60,21 @@ def get_language_model_config(config):
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 14336
    elif config.language_model_type == "yi-34b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = False
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 20480
    else:
        raise ValueError(f"unknown language model type {config.language_model_type}")

    return config

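A minimal usage sketch for the new Yi-34B branch (assumptions: examples/multimodal/config.py is importable as config, and a bare SimpleNamespace stands in for the Megatron config object the trainer normally passes; parts of get_language_model_config outside this hunk may touch fields not shown here):

from types import SimpleNamespace

from config import get_language_model_config  # examples/multimodal/config.py

cfg = SimpleNamespace(language_model_type="yi-34b")
cfg = get_language_model_config(cfg)
# The branch above selects SiLU with gated (SwiGLU) MLPs and an FFN hidden size of 20480.
assert cfg.gated_linear_unit and cfg.ffn_hidden_size == 20480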
@@ -107,6 +122,30 @@ def get_vision_model_config(config, apply_query_key_layer_scaling):
        config.apply_rope_fusion = False
        config.qk_layernorm = False
        config.layernorm_epsilon = 1e-6
    elif config.vision_model_type == "internvit":
        config.num_layers = 45
        config.num_attention_heads = 32 # Padded for TP=8.
        config.num_query_groups = 32 # Padded for TP=8.
        config.kv_channels = 128
        config.add_bias_linear = True
        config.add_qkv_bias = False
        config.hidden_size = 3200
        config.hidden_dropout = 0.0
        config.attention_dropout = 0.0
        config.ffn_hidden_size = 12800
        config.gated_linear_unit = False
        config.activation_func = torch.nn.functional.gelu
        config.layernorm_zero_centered_gamma = False
        config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        config.bias_activation_fusion = False
        config.bias_dropout_fusion = False
        config.attention_softmax_in_fp32 = True
        config.normalization = 'RMSNorm'
        config.layernorm_epsilon = 1e-6
        config.apply_rope_fusion = False
    else:
        raise ValueError(f"unknown vision model type {config.vision_model_type}")


    return config

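The head counts above differ from InternViT-6B's actual 25 attention heads because 25 does not split evenly across 8 tensor-parallel ranks; the converter below pads the attention weights with dummy all-zero heads. A plain-Python sketch of the arithmetic behind the padded value of 32:

real_heads, head_dim, tp = 25, 128, 8              # InternViT-6B: hidden_size 3200 = 25 * 128
padded_heads = -(-real_heads // tp) * tp           # round up to a multiple of tp -> 32
heads_per_rank = padded_heads // tp                # 4 heads per tensor-parallel rank
qkv_rows_per_rank = 3 * heads_per_rank * head_dim  # 1536 rows of the padded linear_qkv weight on each rank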
@@ -128,6 +167,12 @@ def get_vision_projection_config(config, hidden_size):
elif config.language_model_type == "mistral_7b":
config.ffn_hidden_size = 14336
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "yi-34b":
config.ffn_hidden_size = 20480
config.normalization = 'LayerNorm'
config.activation_func = torch.nn.functional.gelu
else:
raise ValueError(f"unknown language model type {config.language_model_type}")

return config

27 changes: 17 additions & 10 deletions examples/multimodal/image_processing.py
@@ -7,13 +7,18 @@
from torchvision.transforms import Compose, RandAugment, RandomResizedCrop, Resize, ToPILImage


# Imagenet's mean and std.
pixel_mean = [123.675, 116.28, 103.53]
pixel_std = [58.395, 57.12, 57.375]

# Reshape for broadcasting.
pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)
pixel_mean_clip = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std_clip = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

pixel_mean_siglip = torch.Tensor([127.5, 127.5, 127.5]).view(-1, 1, 1)
pixel_std_siglip = torch.Tensor([127.5, 127.5, 127.5]).view(-1, 1, 1)

pixel_statistics = {
"clip": (pixel_mean_clip, pixel_std_clip),
"siglip": (pixel_mean_siglip, pixel_std_siglip),
"internvit": (pixel_mean_clip, pixel_std_clip),
}


def convert_to_rgb(image):
@@ -36,12 +41,14 @@ def _transform_test(img_h, img_w):
    ])


def standardize_image(img):
def standardize_image(img, mean, std):
    """Standardize image pixel values."""
    return (torch.Tensor(np.array(img)).permute(2, 0, 1) - pixel_mean) / pixel_std
    return (torch.Tensor(np.array(img)).permute(2, 0, 1) - mean) / std


def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False):
def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, vision_model_type="clip"):
    pixel_mean, pixel_std = pixel_statistics[vision_model_type]

    if use_tiling:
        assert img_h == img_w, "dynamic tiling expects equal tile height and width"
        imgs = dynamic_preprocess(img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail)
@@ -60,7 +67,7 @@ def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, u
        img = visual_transform(img)

        # Standardize pixel values.
        img = standardize_image(img)
        img = standardize_image(img, pixel_mean, pixel_std)

        # Pad to target image size.
        delta_h, delta_w = img_h - scaled_h, img_w - scaled_w
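A short usage sketch of the extended interface (assumptions: a placeholder image path and hypothetical tiling settings; 448x448 tiles match the InternViT-448px checkpoint targeted by this merge request):

from PIL import Image

from image_processing import get_visual_transform  # examples/multimodal/image_processing.py

img = Image.open("example.jpg")  # placeholder path
# "internvit" reuses the CLIP/ImageNet pixel statistics registered above.
tiles = get_visual_transform(
    img, img_h=448, img_w=448, use_tiling=True, max_num_tiles=6,
    use_thumbnail=True, vision_model_type="internvit",
)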
5 changes: 4 additions & 1 deletion examples/multimodal/model.py
@@ -37,7 +37,7 @@ def model_provider(

    num_image_embeddings = get_num_image_embeddings(
        args.img_h, args.img_w, args.patch_dim, args.vision_model_type,
        args.disable_vision_class_token, 1
        args.disable_vision_class_token, 1, args.pixel_shuffle,
    )
    old_seq_length = args.seq_length
    args.seq_length = args.encoder_seq_length = num_image_embeddings
@@ -92,6 +92,9 @@ def model_provider(
            vision_transformer_layer_spec = get_layer_spec(
                is_vit=True, normalization=vision_config.normalization
            )
    elif vision_model_type == "internvit":
        from nvlm.internvit import get_internvit_layer_spec
        vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te)
    else:
        raise RuntimeError("unsupported vision model type", vision_model_type)

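For a rough sense of the sequence length set above (assumptions not shown in this diff: InternViT-448px uses 14-pixel patches, and NVLM-style pixel shuffle with a 0.5 downscale factor cuts the token count by 4x):

img_h = img_w = 448
patch_dim = 14
tokens_per_tile = (img_h // patch_dim) * (img_w // patch_dim)  # 32 * 32 = 1024
tokens_with_pixel_shuffle = tokens_per_tile // 4               # 256, before class-token and tile bookkeeping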
162 changes: 162 additions & 0 deletions examples/multimodal/model_converter/internvit_converter.py
@@ -0,0 +1,162 @@
import argparse
import os

import torch
from transformers import AutoModel


def convert(model_name, output_path, tensor_parallel_size, use_te):
    """Convert InternViT HF checkpoint to mcore."""
    hf_model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    hf_state_dict = hf_model.state_dict()
    new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]

    hidden_size = 3200
    num_heads = 25
    dim = 128

    # Index map that rearranges the HF QKV weight rows from [all Q heads, all K heads, all V heads]
    # into megatron's per-head interleaved [q_0, k_0, v_0, q_1, k_1, v_1, ...] layout.
    order = torch.ones(3 * hidden_size).long()

    for j in range(num_heads):
        for i in range(dim):
            order[i + dim*3*j] = j*dim+i
            order[dim + i + dim*3*j] = j*dim+i+num_heads*dim
            order[dim*2 + i + dim*3*j] = j*dim+i+num_heads*dim*2

    for name, tensor in hf_state_dict.items():
        # Map parameter names to ones used in megatron.
        new_name = ""
        new_tensor = tensor

        # This is used for chunking some tensors to target tensor parallel size.
        chunk_dim = None

        if "embeddings.class_embedding" in name:
            new_name = "class_token"
        elif "embeddings.patch_embedding.weight" in name:
            new_name = "conv1.weight"
        elif "embeddings.patch_embedding.bias" in name:
            new_name = "conv1.bias"
        elif "embeddings.position_embedding" in name:
            new_name = "position_embeddings.weight"
            new_tensor = new_tensor.squeeze(0)
        elif "encoder.layers" in name:
            layer_idx = name.split(".")[2]

            base = f"decoder.layers.{layer_idx}"

            head_dim = 128

            if tensor_parallel_size == 1:
                num_padded_heads = 25
            elif tensor_parallel_size == 8:
                # Note: 25 is not divisible by 8 and we don't currently support uneven heads split with tensor parallelism.
                # So we pad with dummy all-zero heads. Please use a nice even number of attention heads in your model.
                num_padded_heads = 32
            else:
                raise NotImplementedError("invalid tensor parallel size value:", tensor_parallel_size)

            if "ls1" in name:
                new_name = f"{base}.ls1"
            elif "ls2" in name:
                new_name = f"{base}.ls2"
            elif "attn.qkv.weight" in name:
                # Reorder rows into the interleaved per-head QKV layout and zero-pad up to num_padded_heads.
                new_name = f"{base}.self_attention.linear_qkv.weight"
                num_tensors = 3
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0], :] = new_tensor[order]
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.q_norm.weight" in name:
                new_name = f"{base}.self_attention.q_layernorm.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.k_norm.weight" in name:
                new_name = f"{base}.self_attention.k_layernorm.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.proj.weight" in name:
                new_name = f"{base}.self_attention.linear_proj.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:, :new_tensor.shape[-1]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 1
            elif "attn.proj.bias" in name:
                new_name = f"{base}.self_attention.linear_proj.bias"
            elif "mlp.fc1.weight" in name:
                new_name = f"{base}.mlp.linear_fc1.weight"
                chunk_dim = 0
            elif "mlp.fc1.bias" in name:
                new_name = f"{base}.mlp.linear_fc1.bias"
                chunk_dim = 0
            elif "mlp.fc2.weight" in name:
                new_name = f"{base}.mlp.linear_fc2.weight"
                chunk_dim = 1
            elif "mlp.fc2.bias" in name:
                new_name = f"{base}.mlp.linear_fc2.bias"
            elif "norm1" in name:
                new_name = f"{base}.input_layernorm.weight"
            elif "norm2" in name:
                new_name = f"{base}.pre_mlp_layernorm.weight"
            else:
                raise RuntimeError("unexpected transformer layer name", name)
        else:
            raise RuntimeError("unexpected layer name", name)

        assert new_name != "", f"unexpected layer name {name}"

        # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
        extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
        is_extra_state_layer = any([l in new_name for l in extra_state_layers])
        if use_te and is_extra_state_layer:
            layer = new_name.split(".")[-2]
            if layer in extra_state_layers:
                extra_state_name = (
                    new_name[: new_name.rfind(".") + 1] + "_extra_state"
                )  # Replace the weight name.
                for i in range(tensor_parallel_size):
                    new_state_dicts[i]["model"][extra_state_name] = None

        if chunk_dim is None:
            new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
        else:
            new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)

        for i in range(tensor_parallel_size):
            new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()

    for i in range(tensor_parallel_size):
        output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}")
        os.makedirs(output_dir_tp, exist_ok=True)
        output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
        torch.save(new_state_dicts[i], output_path_tp)
        print("saved file", output_path_tp)

    print("done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter")
    parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace")
    parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.")
    parser.add_argument("--use-te", action="store_true", default=True)
    parser.add_argument("--tensor-parallel-size", type=int, required=True)

    args = parser.parse_args()

    convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te)
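A typical invocation, going by the arguments defined above (the output directory is a placeholder; per the conversion code, only tensor parallel sizes 1 and 8 are handled, and the result lands under <output-dir>/iter_0000001/mp_rank_0X/model_optim_rng.pt):

# Roughly equivalent to:
#   python internvit_converter.py --output-dir /path/to/internvit_mcore --tensor-parallel-size 8 --use-te
from internvit_converter import convert  # examples/multimodal/model_converter/internvit_converter.py

convert(
    model_name="OpenGVLab/InternViT-6B-448px-V1-5",  # the default --model-name above
    output_path="/path/to/internvit_mcore",          # placeholder output directory
    tensor_parallel_size=8,
    use_te=True,
)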